83 lines
3.5 KiB
Python
83 lines
3.5 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import nftables
|
|
import json
|
|
import subprocess
|
|
import time
|
|
import sys
|
|
|
|
def _nft_cmd(nft, command):
    """Run a single nft command and echo any error/warning text to stderr.

    Returns the ``(rc, output, error)`` tuple from ``nft.cmd`` unchanged so
    the caller can decide how to treat a failure (``rc != 0``).
    """
    rc, output, error = nft.cmd(command)
    if error:
        print(error, file=sys.stderr)
    return rc, output, error


def _parse_endpoints(wg_output, pubkey_port_mapping):
    """Build a ``{UDP port: peer IPv4 address}`` map from ``wg show ... endpoints``.

    Each line of *wg_output* is "<pubkey><TAB><endpoint>", where the endpoint
    is "a.b.c.d:port", "[v6addr]:port" or "(none)". Lines are skipped when:
    - the peer is not listed in *pubkey_port_mapping*;
    - the peer has no known endpoint;
    - the endpoint is IPv6, which cannot be a target in the "ip" family
      NAT table this script manages.
    """
    port_ip_mapping = {}
    for line in wg_output.splitlines():
        pubkey, tab, endpoint = line.partition("\t")
        ports = pubkey_port_mapping.get(pubkey)
        if not tab or not ports:
            # Malformed line, or a peer we are not forwarding any port to.
            continue
        host, colon, _ = endpoint.rpartition(":")
        if not colon or host.startswith("["):
            # "(none)" (no endpoint yet) or a bracketed IPv6 endpoint.
            continue
        for port in ports:
            # nft's JSON listing gives the "udp dport" value as a number, so
            # normalise here; config ports written as strings still match.
            port_ip_mapping[int(port)] = host
    return port_ip_mapping


def _parse_dnat_rule(rule):
    """Return ``(port, address)`` for a "udp dport P ... dnat to A" rule.

    *rule* is the ``"rule"`` object from an nft JSON listing. Returns None
    when the rule has no single-port dport match or no dnat statement.
    Expressions are located by kind rather than by list position, so the
    exact statement order nft emits does not matter.
    """
    port = address = None
    for expr in rule.get("expr", []):
        match = expr.get("match")
        if isinstance(match, dict):
            left = match.get("left")
            if isinstance(left, dict) and left.get("payload", {}).get("field") == "dport":
                port = match.get("right")
        elif isinstance(expr.get("dnat"), dict):
            address = expr["dnat"].get("addr")
    # Only single integer ports are produced by this script; anything else
    # (sets, ranges) is not ours to manage and would not be hashable anyway.
    if not isinstance(port, int) or address is None:
        return None
    return port, address


def _sync_dnat_rules(nft, interface, port_ip_mapping):
    """Bring the wireguard-nat prerouting DNAT rules in line with *port_ip_mapping*.

    - A rule whose target address differs is replaced in place, using its handle.
    - A mapped port with no rule gets a new rule.
    - A rule for a port absent from the mapping (e.g. the peer currently has
      no endpoint) is left untouched, so it keeps the last known address.
    """
    rc, output, _ = _nft_cmd(nft, "list ruleset")
    if rc != 0 or not output:
        return
    try:
        ruleset = json.loads(output)
    except ValueError as err:
        print("cannot parse nft ruleset: {}".format(err), file=sys.stderr)
        return

    ports_with_rule = set()
    for item in ruleset.get("nftables", []):
        rule = item.get("rule")
        if (
            not rule
            or rule.get("family") != "ip"
            or rule.get("table") != "wireguard-nat"
            or rule.get("chain") != "prerouting"
        ):
            continue
        parsed = _parse_dnat_rule(rule)
        if parsed is None or parsed[0] not in port_ip_mapping:
            continue
        port, current_ip = parsed
        # Record the port before attempting a replace so a failed replace
        # does not fall through to adding a duplicate rule below.
        ports_with_rule.add(port)
        wanted_ip = port_ip_mapping[port]
        if current_ip == wanted_ip:
            continue
        rc, _, _ = _nft_cmd(
            nft,
            "replace rule wireguard-nat prerouting handle {} iif {} udp dport {} dnat to {}".format(
                rule["handle"], interface, port, wanted_ip
            ),
        )
        if rc == 0:
            print("Changed dnat address from {} to {} for UDP port {}".format(current_ip, wanted_ip, port))

    for port, ip in port_ip_mapping.items():
        if port in ports_with_rule:
            continue
        rc, _, _ = _nft_cmd(
            nft,
            "add rule wireguard-nat prerouting iif {} udp dport {} dnat to {}".format(interface, port, ip),
        )
        if rc == 0:
            print("Added dnat rule from UDP port {} to address {}".format(port, ip))


def main():
    """Keep nftables DNAT rules pointed at the current WireGuard peer endpoints.

    Usage: ``<script> CONFIG_JSON LIBNFTABLES_DIR``

    CONFIG_JSON must provide these keys:
    - ``interface``: matched as ``iif`` for DNAT and ``oifname`` for SNAT.
    - ``interface_address``: the SNAT address.
    - ``wg_interface``: the WireGuard interface.
    - ``pubkey_port_mapping``: peer public key -> iterable of UDP ports.

    LIBNFTABLES_DIR is the directory containing ``libnftables.so.1``.

    On start, the ``wireguard-nat`` table is created or flushed, both NAT
    chains are added, and the SNAT rule is installed. Then, every 10
    seconds, the endpoints reported by ``wg show`` are read and the
    prerouting DNAT rules are updated so that each configured UDP port is
    forwarded to its peer's current address. Runs until interrupted.
    """
    if len(sys.argv) < 3:
        print("usage: {} CONFIG_JSON LIBNFTABLES_DIR".format(sys.argv[0]), file=sys.stderr)
        sys.exit(2)

    with open(sys.argv[1], "r", encoding="utf-8") as f:
        config = json.load(f)

    interface = config["interface"]
    interface_address = config["interface_address"]
    wg_interface = config["wg_interface"]
    pubkey_port_mapping = config["pubkey_port_mapping"]

    nft = nftables.Nftables(sys.argv[2] + "/libnftables.so.1")
    nft.set_json_output(True)
    # Rule handles are needed to replace DNAT rules in place.
    nft.set_handle_output(True)

    # Add nat table rules for dnat and snat, starting from an empty table.
    for command in (
        "add table wireguard-nat",
        "flush table wireguard-nat",
        "add chain wireguard-nat prerouting { type nat hook prerouting priority -100; }",
        "add chain wireguard-nat postrouting { type nat hook postrouting priority 100; }",
        "add rule wireguard-nat postrouting oifname {} snat to {}".format(interface, interface_address),
    ):
        _nft_cmd(nft, command)

    while True:
        # List WireGuard peer endpoint addresses of the VPN interface.
        result = subprocess.run(
            ["wg", "show", wg_interface, "endpoints"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        if result.returncode != 0 or result.stderr:
            print("{}: {}".format(wg_interface, result.stderr.decode()), file=sys.stderr)
        else:
            port_ip_mapping = _parse_endpoints(result.stdout.decode(), pubkey_port_mapping)
            _sync_dnat_rules(nft, interface, port_ip_mapping)
        time.sleep(10)
|
|
|
|
if __name__ == "__main__":
    # Start the endpoint-polling loop only when executed as a script,
    # never as a side effect of importing this module.
    main()
|