#!/usr/bin/env python3
"""Keep nftables DNAT rules pointed at the current WireGuard peer endpoints.

Usage: <script> CONFIG.json

The JSON config must provide three keys:
  interface            -- inbound interface matched by the DNAT rules (``iif``)
  wg_interface         -- WireGuard interface queried with ``wg show ... endpoints``
  pubkey_port_mapping  -- peer public key -> list of UDP ports to forward to
                          that peer's current endpoint address

At startup the ``ip nat`` table, its prerouting/postrouting chains and a
postrouting masquerade rule are created if missing. Then, every
POLL_INTERVAL_SECONDS, prerouting DNAT rules whose address no longer matches
the peer's endpoint are replaced, and rules are added for configured ports
that have none yet.
"""

import ipaddress
import json
import subprocess
import sys
import time

import nftables

POLL_INTERVAL_SECONDS = 10


def _report(error):
    """Print nft/wg diagnostic text to stderr if there is any."""
    if error:
        print(error, file=sys.stderr)


def _list_ruleset(nft):
    """Return the list of objects from ``nft list ruleset``, or None on failure."""
    rc, output, error = nft.cmd("list ruleset")
    _report(error)
    if rc != 0 or not output:
        return None
    try:
        return json.loads(output)["nftables"]
    except (ValueError, KeyError) as err:
        print(f"could not parse nft ruleset: {err}", file=sys.stderr)
        return None


def _nat_rules(ruleset, chain):
    """Yield the rule objects belonging to chain *chain* of table ``ip nat``."""
    for item in ruleset:
        rule = item.get("rule")
        if (
            rule is not None
            and rule.get("family") == "ip"
            and rule.get("table") == "nat"
            and rule.get("chain") == chain
        ):
            yield rule


def _ensure_nat_setup(nft):
    """Create the nat table/chains and a masquerade rule when missing.

    Returns False if the current ruleset could not be read.
    """
    for command in (
        "add table nat",
        "add chain nat prerouting { type nat hook prerouting priority -100; }",
        "add chain nat postrouting { type nat hook postrouting priority 100; }",
    ):
        _report(nft.cmd(command)[2])

    ruleset = _list_ruleset(nft)
    if ruleset is None:
        return False
    # Look at every expression, not just the first one, so an existing rule
    # such as `oif eth0 masquerade` is also recognized.
    has_masquerade = any(
        any("masquerade" in expr for expr in rule.get("expr", []))
        for rule in _nat_rules(ruleset, "postrouting")
    )
    if not has_masquerade:
        _report(nft.cmd("add rule nat postrouting masquerade")[2])
    return True


def _read_port_targets(wg_interface, pubkey_port_mapping):
    """Map each configured UDP port to its peer's current IPv4 endpoint.

    Returns None if ``wg show`` failed. Peers that are not in the mapping, or
    that have no endpoint yet, are skipped. Other non-IPv4 endpoints are skipped
    with a warning, because the ``ip`` nat table can only DNAT to IPv4.
    Ports are normalized to int so they compare equal to nft's JSON output.
    """
    result = subprocess.run(
        ["wg", "show", wg_interface, "endpoints"],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if result.stderr:
        print("{}: {}".format(wg_interface, result.stderr.decode()), file=sys.stderr)
    if result.returncode != 0:
        return None

    port_targets = {}
    for line in result.stdout.decode().splitlines():
        pubkey, _, endpoint = line.partition("\t")
        ports = pubkey_port_mapping.get(pubkey)
        # wg prints "(none)" for peers whose endpoint is not known yet.
        if not ports or endpoint in ("", "(none)"):
            continue
        # Endpoints look like "1.2.3.4:51820" or "[::1]:51820".
        host = endpoint.rpartition(":")[0].strip("[]")
        try:
            address = ipaddress.ip_address(host)
        except ValueError:
            address = None
        if address is None or address.version != 4:
            print(
                f"{wg_interface}: skipping non-IPv4 endpoint {endpoint} of peer {pubkey}",
                file=sys.stderr,
            )
            continue
        for port in ports:
            port_targets[int(port)] = str(address)
    return port_targets


def _dnat_target(rule):
    """Return (udp_dport, dnat_addr) for a ``udp dport N dnat to A`` rule, else None."""
    port = address = None
    for expr in rule.get("expr", []):
        match = expr.get("match")
        if match is not None:
            left = match.get("left")
            payload = left.get("payload") if isinstance(left, dict) else None
            if (
                isinstance(payload, dict)
                and payload.get("protocol") == "udp"
                and payload.get("field") == "dport"
            ):
                port = match.get("right")
        elif isinstance(expr.get("dnat"), dict):
            address = expr["dnat"].get("addr")
    if isinstance(port, int) and isinstance(address, str):
        return port, address
    return None


def _sync_dnat_rules(nft, interface, port_targets):
    """Replace stale prerouting DNAT rules and add rules for uncovered ports."""
    ruleset = _list_ruleset(nft)
    if ruleset is None:
        return

    missing = dict(port_targets)
    for rule in _nat_rules(ruleset, "prerouting"):
        parsed = _dnat_target(rule)
        if parsed is None:
            continue  # not a UDP-port DNAT rule; leave it untouched
        port, current = parsed
        target = port_targets.get(port)
        if target is None:
            # No currently connected peer owns this port; keep the rule as is.
            continue
        missing.pop(port, None)
        handle = rule.get("handle")
        if current == target or handle is None:
            continue
        rc, _, error = nft.cmd(
            "replace rule nat prerouting handle {} iif {} udp dport {} dnat to {}".format(
                handle, interface, port, target
            )
        )
        _report(error)
        if rc == 0:
            print("Changed dnat address from {} to {} for UDP port {}".format(current, target, port))

    for port, target in missing.items():
        rc, _, error = nft.cmd(
            "add rule nat prerouting iif {} udp dport {} dnat to {}".format(interface, port, target)
        )
        _report(error)
        if rc == 0:
            print("Added dnat rule from UDP port {} to address {}".format(port, target))


def main():
    """Load the config named by argv[1], set up NAT, then sync DNAT rules forever."""
    if len(sys.argv) < 2:
        print(f"usage: {sys.argv[0]} CONFIG.json", file=sys.stderr)
        sys.exit(2)
    with open(sys.argv[1], "r", encoding="utf-8") as f:
        config = json.load(f)
    interface = config["interface"]
    wg_interface = config["wg_interface"]
    pubkey_port_mapping = config["pubkey_port_mapping"]

    nft = nftables.Nftables()
    nft.set_json_output(True)
    # Rule handles are needed to `replace` an existing rule in place.
    nft.set_handle_output(True)

    if not _ensure_nat_setup(nft):
        sys.exit(1)

    while True:
        port_targets = _read_port_targets(wg_interface, pubkey_port_mapping)
        if port_targets is not None:
            _sync_dnat_rules(nft, interface, port_targets)
        time.sleep(POLL_INTERVAL_SECONDS)


if __name__ == "__main__":
    main()