82 lines
		
	
	
	
		
			3.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			82 lines
		
	
	
	
		
			3.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
#!/usr/bin/env python3
 | 
						|
 | 
						|
import nftables
 | 
						|
import json
 | 
						|
import subprocess
 | 
						|
import time
 | 
						|
import sys
 | 
						|
 | 
						|
# Seconds to wait between two synchronisation passes.
_POLL_INTERVAL_SECONDS = 10


def _nft_run(nft, command):
    """Run a single nft *command* through libnftables.

    Any text libnftables writes to its error buffer is echoed to stderr.

    Returns:
        ``(ok, output)``: *ok* is True when the command's return code was 0,
        *output* is the command's output buffer (JSON text, because ``main``
        enables JSON output).
    """
    rc, output, error = nft.cmd(command)
    if error:
        print(error.rstrip("\n"), file=sys.stderr)
    return rc == 0, output


def _endpoint_ipv4(endpoint):
    """Return the address part of an IPv4 ``address:port`` endpoint, else None.

    ``wg show <if> endpoints`` prints ``(none)`` for peers without an endpoint
    and ``[address]:port`` for IPv6 peers. Neither can be used by the
    ``ip``-family DNAT rules maintained here, so both yield None (as does an
    empty string).
    """
    if endpoint.startswith("["):
        return None
    address, separator, _port = endpoint.rpartition(":")
    if not separator or not address or ":" in address:
        return None
    return address


def _read_port_ip_mapping(wg_interface, pubkey_port_mapping):
    """Map every configured UDP port to its peer's current IPv4 endpoint.

    Args:
        wg_interface: WireGuard interface name passed to ``wg show``.
        pubkey_port_mapping: peer public key -> list of int UDP ports.

    Returns:
        ``{port: address}`` for peers that are configured and have a usable
        IPv4 endpoint, or None when ``wg`` could not be run or reported an
        error (the problem is printed to stderr).
    """
    try:
        result = subprocess.run(
            ["wg", "show", wg_interface, "endpoints"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    except OSError as err:  # e.g. the wg binary is missing
        print("{}: {}".format(wg_interface, err), file=sys.stderr)
        return None
    if result.stderr or result.returncode != 0:
        message = result.stderr.decode(errors="replace") or "wg exited with status {}".format(result.returncode)
        print("{}: {}".format(wg_interface, message), file=sys.stderr)
        return None

    port_ip_mapping = {}
    for line in result.stdout.decode().splitlines():
        # Each line is "<public key>\t<endpoint>".
        pubkey, _, endpoint = line.partition("\t")
        address = _endpoint_ipv4(endpoint)
        # Skip peers without a usable endpoint and peers missing from the config.
        if address is None or pubkey not in pubkey_port_mapping:
            continue
        for port in pubkey_port_mapping[pubkey]:
            port_ip_mapping[port] = address
    return port_ip_mapping


def _dnat_rule_info(item):
    """Extract ``(handle, port, address)`` from one ``list ruleset`` JSON item.

    Only rules in ``ip wireguard-nat prerouting`` are considered. The
    expression layout (``expr[1]`` = ``udp dport`` match, ``expr[2]`` =
    ``dnat``) is assumed to be the one produced by the
    ``iif ... udp dport ... dnat to ...`` rules this script adds; items that do
    not fit it yield None instead of raising.
    """
    rule = item.get("rule")
    if not isinstance(rule, dict):
        return None
    if (rule.get("family"), rule.get("table"), rule.get("chain")) != ("ip", "wireguard-nat", "prerouting"):
        return None
    try:
        expressions = rule["expr"]
        handle = rule["handle"]
        port = expressions[1]["match"]["right"]
        address = expressions[2]["dnat"]["addr"]
    except (KeyError, IndexError, TypeError):
        return None
    # Sets, ranges or other non-scalar values were not written by this script
    # (and a dict value would be unhashable as a port key).
    if not isinstance(port, int) or not isinstance(address, str):
        return None
    return handle, port, address


def _sync_dnat_rules(nft, interface, port_ip_mapping):
    """Make the prerouting DNAT rules match *port_ip_mapping*.

    Existing rules whose address differs are replaced in place (by handle);
    ports without a rule get a new one. Rules for ports absent from the
    mapping (peer without an endpoint or not configured) and duplicate rules
    for an already-handled port are left untouched.
    """
    ok, output = _nft_run(nft, "list ruleset")
    if not ok:
        return
    try:
        ruleset = json.loads(output)
    except ValueError as err:
        print("cannot parse nft ruleset JSON: {}".format(err), file=sys.stderr)
        return

    missing = dict(port_ip_mapping)  # ports that still need a rule
    for item in ruleset.get("nftables", []):
        info = _dnat_rule_info(item)
        if info is None or info[1] not in missing:
            continue
        handle, port, current_address = info
        wanted_address = missing.pop(port)
        if current_address == wanted_address:
            continue
        ok, _ = _nft_run(
            nft,
            "replace rule wireguard-nat prerouting handle {} iif {} udp dport {} dnat to {}".format(
                handle, interface, port, wanted_address
            ),
        )
        if ok:
            print("Changed dnat address from {} to {} for UDP port {}".format(current_address, wanted_address, port))

    for port, address in missing.items():
        ok, _ = _nft_run(
            nft,
            "add rule wireguard-nat prerouting iif {} udp dport {} dnat to {}".format(interface, port, address),
        )
        if ok:
            print("Added dnat rule from UDP port {} to address {}".format(port, address))


def main():
    """Keep nftables DNAT rules in sync with WireGuard peer endpoints.

    Usage: ``<script> CONFIG_JSON LIBNFTABLES_DIR``

    CONFIG_JSON is a JSON object with the keys ``interface``,
    ``interface_address``, ``wg_interface`` and ``pubkey_port_mapping``
    (peer public key -> list of UDP ports). LIBNFTABLES_DIR is the directory
    containing ``libnftables.so.1``.

    At start the ``wireguard-nat`` table is (re)created and flushed and one
    SNAT rule for ``interface`` is installed. Then, every
    ``_POLL_INTERVAL_SECONDS``, the peers' endpoints are read with ``wg show``
    and one prerouting DNAT rule per configured UDP port is added or
    rewritten to point at that peer's current IPv4 address.

    Never returns normally; exits with status 2 on a usage error and 1 when
    the initial nft setup fails.
    """
    if len(sys.argv) < 3:
        print("usage: {} CONFIG_JSON LIBNFTABLES_DIR".format(sys.argv[0]), file=sys.stderr)
        sys.exit(2)

    with open(sys.argv[1], "r", encoding="utf-8") as f:
        config = json.load(f)

    interface = config["interface"]
    interface_address = config["interface_address"]
    wg_interface = config["wg_interface"]
    # The nft JSON listing reports the dport match value as a number, so
    # normalise configured ports to int or existing rules would never match.
    pubkey_port_mapping = {
        pubkey: [int(port) for port in ports]
        for pubkey, ports in config["pubkey_port_mapping"].items()
    }

    nft = nftables.Nftables(sys.argv[2] + "/libnftables.so.1")
    nft.set_json_output(True)
    nft.set_handle_output(True)  # rule handles are needed for "replace rule"

    # Add the NAT table and chains; flushing first drops rules left over from a
    # previous run so the SNAT rule is not duplicated.
    setup_commands = (
        "add table wireguard-nat",
        "flush table wireguard-nat",
        "add chain wireguard-nat prerouting { type nat hook prerouting priority -100; }",
        "add chain wireguard-nat postrouting { type nat hook postrouting priority 100; }",
        "add rule wireguard-nat postrouting oifname {} snat to {}".format(interface, interface_address),
    )
    for command in setup_commands:
        ok, _ = _nft_run(nft, command)
        if not ok:
            print("nft setup command failed: {}".format(command), file=sys.stderr)
            sys.exit(1)

    while True:
        port_ip_mapping = _read_port_ip_mapping(wg_interface, pubkey_port_mapping)
        if port_ip_mapping is not None:
            _sync_dnat_rules(nft, interface, port_ip_mapping)
        time.sleep(_POLL_INTERVAL_SECONDS)
 | 
						|
 | 
						|
if __name__ == "__main__":
    # Run the synchroniser only when executed as a script, not on import;
    # main() returns None, so this exits with status 0 if it ever returns.
    raise SystemExit(main())
 |