diff --git a/src/App.tsx b/src/App.tsx index ca9c559..bb1fc8f 100644 --- a/src/App.tsx +++ b/src/App.tsx @@ -18,7 +18,8 @@ import { EdgeMouseHandler } from '@xyflow/react'; import '@xyflow/react/dist/style.css'; -import { AppNode, AppEdge, NodeData, EdgeData, NodeDataUpdate, EdgeDataUpdate, Settings, initialSettings, SettingsContext } from './types/graph'; +import { AppNode, AppEdge, NodeData, EdgeData, NodeDataUpdate, EdgeDataUpdate } from './types/graph'; +import {Settings, initialSettings, SettingsContext} from './types/settings' import CustomNode from './components/CustomNode'; import NodeEditor from './components/NodeEditor'; import EdgeEditor from './components/EdgeEditor' diff --git a/src/components/CustomNode.tsx b/src/components/CustomNode.tsx index 6885f91..5fca8c6 100644 --- a/src/components/CustomNode.tsx +++ b/src/components/CustomNode.tsx @@ -1,9 +1,9 @@ import { ReactNode, useContext } from 'react'; import { Handle, Position, NodeProps, useReactFlow} from '@xyflow/react'; import { AppNode, AppEdge, AppGraph, NodeData } from '../types/graph'; -import {Settings, SettingsContext} from '../types/settings' +import {Settings, SettingsContext, SubnetInfo} from '../types/settings' import StringBuilder from '../utils/StringBuilder'; - +import { Result } from '../utils/Result'; import { CIDR, IPUtils } from '../utils/iputils'; import { tryDerivePublicKey } from '../utils/wireguardConfig' import './CustomNode.css'; @@ -17,26 +17,6 @@ class ConfigResult { ) {} } -// function mapAddressToCIDR(nodeIds : string[], getAddress: GetAddress) : Record { -// const nodeIdToCIDR : Record = {}; -// for(const nodeId of nodeIds) { -// const address = getAddress(nodeId); -// if(!address) continue; -// const result = IPUtils.parse(address); -// const cidr = result.cidr; -// if(!cidr) { -// throw new Error("节点地址无效"); -// } -// if(cidr.version === 'IPv4') { -// cidr.mask = 32; -// } else if(cidr.version === 'IPv6') { -// cidr.mask = 128; -// } -// nodeIdToCIDR[nodeId] = cidr; -// } -// return nodeIdToCIDR; -// } - function generateInterfaceConfig(settings: Settings, data: NodeData) : StringBuilder { const address = [data.ipv4Address, data.ipv6Address].flatMap(p => p ? [p] : []).join(', '); const config = new StringBuilder(); @@ -52,8 +32,65 @@ function generateInterfaceConfig(settings: Settings, data: NodeData) : StringBui return config; } -function generateConfig(settings: Settings, data: NodeData, graph: AppGraph) : ConfigResult { - const config = generateInterfaceConfig(settings, data); +function getFromChildRouterMapper(subnet: SubnetInfo, graph: AppGraph, nodeId: string, disallowCIDRs: CIDR[]) : Result> { + const subGraph = graph.getConnectedSubgraph(subnet.nodes.map(kv => kv.nodeId)); + + if(!subGraph) + return Result.Error(`图不连通,无法生存子网路由:${subnet.subnet.toString()}`); + + + const fromChildMapper: Map = new Map(); + const visited: Set = new Set(); + visited.add(nodeId); + const queue : string[] = []; + const children = subGraph.getChildren(nodeId); + children.forEach(child => { + fromChildMapper.set(child, [child]); + visited.add(child); + queue.push(child); + }); + + const allNodes = []; + while(queue.length > 0) { + const curr = queue.shift()!; + const fromChild = fromChildMapper.get(curr); + if(!fromChild) continue; + allNodes.push(curr); + const nexts = subGraph.getChildren(curr); + for(const next of nexts) { + if(visited.has(next)) continue; + fromChild.push(next); + visited.add(next); + queue.push(next); + } + } + + const allCIDRs = allNodes.flatMap((nodeId: string) => { + const cidr = subnet.nodes.find(n => n.nodeId === nodeId)?.cidr; + return cidr ? [cidr] : []; + }); + + const result: Map = new Map(); + for (const [fromChild, nodes] of fromChildMapper) { + const targetCIDRs = nodes.flatMap((nodeId: string) => { + const cidr = subnet.nodes.find(n => n.nodeId === nodeId)?.cidr; + if(!cidr || disallowCIDRs.some(disallow => disallow.contains(cidr))) return []; + return [cidr]; + }); + + const mergeResult = IPUtils.mergeCIDRs(allCIDRs, targetCIDRs); + if(!mergeResult) { + return Result.Error("无法生成路由配置"); + } + result.set(fromChild, mergeResult); + } + + return Result.Result(result); +} + + +function generateConfig(settings: Settings, data: NodeData, graph: AppGraph, subnets: SubnetInfo[]) : ConfigResult { + const config: StringBuilder = generateInterfaceConfig(settings, data); const disallowCIDRs : CIDR[] = []; if(data.disallowIPs) { @@ -67,119 +104,120 @@ function generateConfig(settings: Settings, data: NodeData, graph: AppGraph) : C } } - const belongsToEdge : Record = {[node.id]: node.id}; + const allowIPsConfig: Map = new Map(); - const queue : AppNode[] = []; - const nearEdges = getNearEdges(node); - nearEdges.forEach(edge => { - const nextNode = getNextNode(edge, node)!; - belongsToEdge[nextNode.id] = edge.id; - queue.push(nextNode); - }); - - while(queue.length > 0) { - const currentNode = queue.shift()!; - const fromEdgeId = belongsToEdge[currentNode.id]; - if(!fromEdgeId) continue; - - getNearEdges(currentNode).forEach(edge => { - const nextNode = getNextNode(edge, currentNode)!; - if(!belongsToEdge[nextNode.id]) { - belongsToEdge[nextNode.id] = fromEdgeId; - queue.push(nextNode); + for(const subnet of subnets) { + const result = getFromChildRouterMapper(subnet, graph, data.id, disallowCIDRs); + const mapper = result.result; + if(!mapper) return new ConfigResult(false, undefined, `路由生成失败:${result.errorInfo()}`); + + for(const [fromChild, cidrs] of mapper) { + if(!allowIPsConfig.has(fromChild)) allowIPsConfig.set(fromChild, []); + const allowIPs = allowIPsConfig.get(fromChild)!; + for(const cidr of cidrs) { + allowIPs.push(cidr.toString()); } - }); + } } - - const groupedByEdge: Record = {}; - const nodeIds : string[] = []; - for (const nodeId in belongsToEdge) { - const edgeId = belongsToEdge[nodeId]; - if(edgeId === nodeId) continue; // 跳过起始节点 - nodeIds.push(nodeId); - if(!edgeId) continue; - if (!groupedByEdge[edgeId]) { - groupedByEdge[edgeId] = []; - } - groupedByEdge[edgeId].push(nodeId); - } - - for(const edgeId in groupedByEdge) { - const groupNodeIds = groupedByEdge[edgeId]!; - const edge = getEdge(edgeId)!; - const nextNode = getNextNode(edge, node)!; - const nextNodeData = nextNode.data; - const publicKey = tryDerivePublicKey(nextNodeData.privateKey); + const nexts = graph.getChildren(data.id); + for(const next of nexts) { + const nextData = graph.queryNode(next)!.data; + + const edge = graph.queryEdge(data.id, next)!; + + const publicKey = tryDerivePublicKey(nextData.privateKey); if(!publicKey) return new ConfigResult(false, undefined, "无法从私钥派生公钥"); config.appendLine(""); config.appendLine("[Peer]"); - config.appendLine(`# ${nextNodeData.label}`); + config.appendLine(`# ${nextData.label}`); config.appendLine(`PublicKey = ${ publicKey}`); - if(edge.data?.isTwoWayEdge || edge.source === node.id) { + if(edge.data?.isTwoWayEdge || edge.source === data.id) { if(edge.data?.persistentKeepalive) { config.appendLine(`PersistentKeepalive = ${edge.data.persistentKeepalive}`); } - if(!nextNodeData.listenAddress) { - return new ConfigResult(false, undefined, `节点 ${nextNodeData.label} 未设置监听地址,无法生成配置`); + if(!nextData.listenAddress) { + return new ConfigResult(false, undefined, `节点 ${nextData.label} 未设置监听地址,无法生成配置`); } - const parse = IPUtils.parse(`${nextNodeData.listenAddress}/0`); + const parse = IPUtils.parse(`${nextData.listenAddress}/0`); if(!parse.cidr) { - return new ConfigResult(false, undefined, `节点 ${nextNodeData.label} 的监听地址无效`); + return new ConfigResult(false, undefined, `节点 ${nextData.label} 的监听地址无效`); } - const listenAddress = parse.cidr.version === 'IPv4' ? nextNodeData.listenAddress : `[${nextNodeData.listenAddress}]`; - const listenPort = nextNodeData.listenPort || settings.listenPort; + const listenAddress = parse.cidr.version === 'IPv4' ? nextData.listenAddress : `[${nextData.listenAddress}]`; + const listenPort = nextData.listenPort || settings.listenPort; config.appendLine(`EndPoint = ${listenAddress}:${listenPort}`); } - const subnets : Record[] = []; - - try { - subnets.push(mapAddressToCIDR(nodeIds, nodeId => node.data.ipv4Address && getNode(nodeId)?.data.ipv4Address)); - subnets.push(mapAddressToCIDR(nodeIds, nodeId => node.data.ipv6Address && getNode(nodeId)?.data.ipv6Address)); - } catch(e) { - if(e instanceof Error) { - return new ConfigResult(false, undefined, e.message); - } - } - const allowIPs : string[] = []; - for(const subnetMap of subnets) { - const allCIDRs = nodeIds.flatMap(id => subnetMap[id] ? [subnetMap[id]] : []); - const targetCIDRs = groupNodeIds.flatMap(id => { - const cidr = subnetMap[id]; - if(!cidr || disallowCIDRs.some(disallow => disallow.contains(cidr))) return []; - return [cidr]; - }); - const mergeResult = IPUtils.mergeCIDRs(allCIDRs, targetCIDRs); - if(!mergeResult) { - return new ConfigResult(false, undefined, `无法生成路由配置`); - } - mergeResult.forEach(cidr => {allowIPs.push(cidr.toString())}); - } - if(allowIPs.length > 0) { + const allowIPs = allowIPsConfig.get(next); + if(allowIPs && allowIPs.length > 0) { config.appendLine(`AllowedIPs = ${allowIPs.join(', ')}`); } } - // console.log(config.toString()); - return new ConfigResult(true, config.toString()); } +function getSubnet(nodes: AppNode[], version: 'ipv4' | 'ipv6'): Result { + const addresses: {nodeId: string, cidr: CIDR}[] = []; + + for(const node of nodes) { + const address = version === 'ipv4' ? node.data.ipv4Address : node.data.ipv6Address; + if(!address) continue; + const result = IPUtils.parse(address); + const cidr = result.cidr; + if(!cidr) return Result.Error(result.error); + addresses.push({nodeId: node.id, cidr: cidr}); + } + + if(addresses.length === 0) return Result.Result(undefined); + + const subnet: SubnetInfo = {subnet: addresses[0].cidr, nodes: []} + for(const address of addresses) { + const cidr = address.cidr; + if(!subnet.subnet.contains(cidr)) { + return Result.Error("不在同一子网下"); + } + + subnet.nodes.push({...address, cidr: new CIDR(cidr.version, cidr.binary, version === 'ipv4' ? 32 : 128)}); + } + return Result.Result(subnet); +} + + export default function CustomNode({ data, selected }: NodeProps): ReactNode { const settings = useContext(SettingsContext); - const { getNode, getEdge, getEdges } = useReactFlow(); + const { getNodes, getEdges } = useReactFlow(); const handleGenerate = (node : NodeData) => { - const result = generateConfig(settings, node, getEdges, getEdge, getNode); + const graph = new AppGraph(getNodes, getEdges); + const nodes = getNodes(); + let subnets: SubnetInfo[] = []; + + const v4 = getSubnet(nodes, 'ipv4'); + const v6 = getSubnet(nodes, 'ipv6'); + if(!v4.isValid()) { + toast.error(`ipv4子网配置有误:${v4.errorInfo()}`); + return; + } + + if(!v6.isValid()) { + toast.error(`ipv6子网配置有误:${v6.errorInfo()}`); + return; + } + + if(v4.result) subnets.push(v4.result); + if(v6.result) subnets.push(v6.result); + subnets = subnets.concat(settings.subnets); + + const result = generateConfig(settings, node, graph, subnets); if(result.success && result.config) { navigator.clipboard.writeText(result.config).then(() => { toast.success("配置已复制到剪贴板"); diff --git a/src/components/NodeEditor.tsx b/src/components/NodeEditor.tsx index 5ce6237..603e146 100644 --- a/src/components/NodeEditor.tsx +++ b/src/components/NodeEditor.tsx @@ -1,5 +1,6 @@ import { useState, ReactNode } from 'react'; -import { NodeData, Settings, NodeDataUpdate } from '../types/graph'; +import { NodeData, NodeDataUpdate } from '../types/graph'; +import { Settings } from '../types/settings'; import { generateWireGuardPrivateKey } from '../utils/wireguardConfig' import { IPUtils} from '../utils/iputils' import './FormEditor.css'; @@ -12,7 +13,7 @@ interface NodeEditorProps { onClose: () => void; } -function Validate(updateData : NodeDataUpdate, settings : Settings) : string[] { +function Validate(updateData : NodeDataUpdate) : string[] { const errors: string[] = []; const {ipv4Address, ipv6Address, mtu, listenPort} = updateData; @@ -95,7 +96,7 @@ export default function NodeEditor({ notes: notes } - const validation = Validate(updateData, settings); + const validation = Validate(updateData); setErrors(validation); if(validation.length > 0) { return ; diff --git a/src/types/graph.ts b/src/types/graph.ts index 19d2080..d1483b3 100644 --- a/src/types/graph.ts +++ b/src/types/graph.ts @@ -40,8 +40,8 @@ export class AppGraph { private readonly _getNextNodeIds = new Map(); constructor( + public readonly getNodes: GetNodes, public readonly getEdges: GetEdges, - public readonly getNodes: GetNodes ) { const getNextNodeIds = this._getNextNodeIds; const nodeIds = getNodes().map(node => node.id); @@ -57,9 +57,21 @@ export class AppGraph { } } + queryNode(nodeId: string): AppNode | undefined { + return this.getNodes().find(n => n.id === nodeId); + } - private static checkConnected(graph: AppGraph): boolean { - const nodes = graph.getNodes(); + queryEdge(source: string, target: string): AppEdge | undefined { + return this.getEdges().find(e => + (e.source === source && e.target === target) || (e.target === source && e.source === target)); + } + + getChildren(nodeId: string) : string[] { + return this._getNextNodeIds.get(nodeId) ?? []; + } + + private checkConnected(): boolean { + const nodes = this.getNodes(); if(nodes.length === 0) return true; const visited = new Set(); @@ -68,7 +80,7 @@ export class AppGraph { queue.push(first.id); visited.add(first.id); - const getNextNodeIds = graph._getNextNodeIds; + const getNextNodeIds = this._getNextNodeIds; while(queue.length > 0) { const curr = queue.shift()!; const next = getNextNodeIds.get(curr); @@ -83,7 +95,24 @@ export class AppGraph { return visited.size === nodes.length; } + getConnectedSubgraph(nodeIds: string[]): AppGraph | undefined { - const node + let completed = false; + let graph: AppGraph = this; + + while(!completed) { + completed = true; + const nodes = graph.getNodes(); + + for(const node of nodes) { + if(nodeIds.includes(node.id)) continue; + const newGraph = new AppGraph(() => this.getNodes().filter(n => n.id !== node.id), this.getEdges); + if(!newGraph.checkConnected()) continue; + graph = newGraph; + completed = false; + } + } + + return graph; } } \ No newline at end of file diff --git a/src/utils/Result.ts b/src/utils/Result.ts new file mode 100644 index 0000000..aaf776a --- /dev/null +++ b/src/utils/Result.ts @@ -0,0 +1,26 @@ +export class Result { + private constructor( + public readonly errors: string[], + public readonly result?: T, + ) {} + + isValid(): boolean { + return this.errors.length === 0; + } + + errorInfo(): string { + return this.errors.join("\n"); + } + + static Error(error: string): Result { + return new Result([error]); + } + + static Errors(errors: string[]): Result { + return new Result(errors); + } + + static Result(result: T) { + return new Result([], result); + } +} \ No newline at end of file