Add wireguard-nat-nftables python script
This commit is contained in:
		
					parent
					
						
							
								34b8dcef9c
							
						
					
				
			
			
				commit
				
					
						ea11e41005
					
				
			
		
					 6 changed files with 152 additions and 4 deletions
				
			
		
							
								
								
									
										17
									
								
								pkgs/wireguard-nat-nftables/default.nix
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								pkgs/wireguard-nat-nftables/default.nix
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,17 @@
 | 
			
		|||
{ pkgs, ... }:
 | 
			
		||||
let
 | 
			
		||||
  nftablesWithPythonOverlay = final: prev: {
 | 
			
		||||
    nftables = (prev.nftables.override { withPython = true; });
 | 
			
		||||
  };
 | 
			
		||||
  pkgs-overlay = pkgs.extend nftablesWithPythonOverlay;
 | 
			
		||||
in 
 | 
			
		||||
pkgs-overlay.python310Packages.buildPythonApplication {
 | 
			
		||||
  pname = "wireguard-nat-nftables";
 | 
			
		||||
  version = "0.0.1";
 | 
			
		||||
 | 
			
		||||
  propagatedBuildInputs = with pkgs-overlay; [
 | 
			
		||||
    python310Packages.nftables
 | 
			
		||||
  ];
 | 
			
		||||
 | 
			
		||||
  src = ./src;
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										7
									
								
								pkgs/wireguard-nat-nftables/src/setup.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								pkgs/wireguard-nat-nftables/src/setup.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,7 @@
 | 
			
		|||
from distutils.core import setup
 | 
			
		||||
 | 
			
		||||
setup(
 | 
			
		||||
    name='wireguard-nat-nftables',
 | 
			
		||||
    version='0.0.1',
 | 
			
		||||
    scripts=['wireguard-nat-nftables.py']
 | 
			
		||||
)
 | 
			
		||||
							
								
								
									
										92
									
								
								pkgs/wireguard-nat-nftables/src/wireguard-nat-nftables.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										92
									
								
								pkgs/wireguard-nat-nftables/src/wireguard-nat-nftables.py
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,92 @@
 | 
			
		|||
#!/usr/bin/env python3
 | 
			
		||||
 | 
			
		||||
import nftables
 | 
			
		||||
import json
 | 
			
		||||
import subprocess
 | 
			
		||||
import time
 | 
			
		||||
import sys
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    f = open(sys.argv[1], "r")
 | 
			
		||||
    config = json.loads(f.read())
 | 
			
		||||
    f.close()
 | 
			
		||||
 | 
			
		||||
    interface = config["interface"]
 | 
			
		||||
    wg_interface = config["wg_interface"]
 | 
			
		||||
    pubkey_port_mapping = config["pubkey_port_mapping"]
 | 
			
		||||
 | 
			
		||||
    nft = nftables.Nftables()
 | 
			
		||||
    nft.set_json_output(True)
 | 
			
		||||
    nft.set_handle_output(True)
 | 
			
		||||
 | 
			
		||||
    # add nat table rules for dnat and snat masquerade
 | 
			
		||||
    nft.cmd("add table nat")
 | 
			
		||||
    nft.cmd("add chain nat prerouting { type nat hook prerouting priority -100; }")
 | 
			
		||||
    nft.cmd("add chain nat postrouting { type nat hook postrouting priority 100; }")
 | 
			
		||||
    
 | 
			
		||||
    # load current nftables rules
 | 
			
		||||
    rc, output, error = nft.cmd("list ruleset")
 | 
			
		||||
    if error:
 | 
			
		||||
        print(error, file=sys.stderr)
 | 
			
		||||
    nftables_output = json.loads(output)
 | 
			
		||||
 | 
			
		||||
    add_masquerade = True
 | 
			
		||||
    for item in nftables_output["nftables"]:
 | 
			
		||||
        if ("rule" in item 
 | 
			
		||||
            and item["rule"]["family"] == "ip"
 | 
			
		||||
            and item["rule"]["table"] == "nat"
 | 
			
		||||
            and item["rule"]["chain"] == "postrouting"
 | 
			
		||||
            and "masquerade" in item["rule"]["expr"][0]
 | 
			
		||||
        ):
 | 
			
		||||
            add_masquerade = False
 | 
			
		||||
            break
 | 
			
		||||
    if add_masquerade:
 | 
			
		||||
        nft.cmd("add rule nat postrouting masquerade")
 | 
			
		||||
 | 
			
		||||
    while True:
 | 
			
		||||
        # list WireGuard peer endpoint addresses of WireGuard VPN connection
 | 
			
		||||
        process = subprocess.Popen(["wg", "show", wg_interface, "endpoints"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
 | 
			
		||||
        stdout, stderr = process.communicate()
 | 
			
		||||
        lines = stdout.decode().split("\n")[:-1]
 | 
			
		||||
        if stderr:
 | 
			
		||||
            print("{}: {}".format(wg_interface, stderr.decode()), file=sys.stderr)
 | 
			
		||||
        else:
 | 
			
		||||
            # map destination port to IP
 | 
			
		||||
            port_ip_mapping = {}
 | 
			
		||||
            for line in lines:
 | 
			
		||||
                pubkey = line.split("\t")[0]
 | 
			
		||||
                ip = line.split("\t")[1].split(":")[0] # probably only works for IPv4
 | 
			
		||||
                for port in pubkey_port_mapping[pubkey]:
 | 
			
		||||
                    port_ip_mapping[port] = ip
 | 
			
		||||
 | 
			
		||||
            # load current nftables rules
 | 
			
		||||
            rc, output, error = nft.cmd("list ruleset")
 | 
			
		||||
            if error:
 | 
			
		||||
                print(error, file=sys.stderr)
 | 
			
		||||
            nftables_output = json.loads(output)
 | 
			
		||||
 | 
			
		||||
            # update existing nftable dnat rules, if the remote IP mismatches
 | 
			
		||||
            for item in nftables_output["nftables"]:
 | 
			
		||||
                if "rule" in item and item["rule"]["family"] == "ip" and item["rule"]["table"] == "nat" and item["rule"]["chain"] == "prerouting":
 | 
			
		||||
                    handle = item["rule"]["handle"]
 | 
			
		||||
                    ip = item["rule"]["expr"][2]["dnat"]["addr"]
 | 
			
		||||
                    port = item["rule"]["expr"][1]["match"]["right"]
 | 
			
		||||
                    if not ip == port_ip_mapping[port]:
 | 
			
		||||
                        rc, output, error = nft.cmd("replace rule nat prerouting handle {} iif {} udp dport {} dnat to {}".format(handle, interface, port, port_ip_mapping[port]))
 | 
			
		||||
                        if error:
 | 
			
		||||
                            eprint(error)
 | 
			
		||||
                        else:
 | 
			
		||||
                            print("Changed dnat address from {} to {} for UDP port {}".format(ip, port_ip_mapping[port], port))
 | 
			
		||||
                    port_ip_mapping.pop(port)
 | 
			
		||||
 | 
			
		||||
            # loop through all remaining ports and add needed dnat rules
 | 
			
		||||
            for port in port_ip_mapping:
 | 
			
		||||
                rc, output, error = nft.cmd("add rule nat prerouting iif {} udp dport {} dnat to {}".format(interface, port, port_ip_mapping[port]))
 | 
			
		||||
                if error:
 | 
			
		||||
                    print(error, file=sys.stderr)
 | 
			
		||||
                else:
 | 
			
		||||
                    print("Added dnat rule from UDP port {} to address {}".format(port, port_ip_mapping[port]))
 | 
			
		||||
        time.sleep(10)
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    main()
 | 
			
		||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue