nix-infra/pkgs/wireguard-nat-nftables/src/wireguard-nat-nftables.py

93 lines
3.7 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
import nftables
import json
import subprocess
import time
import sys
def main():
    """Keep nftables DNAT rules pointed at the current WireGuard peer endpoints.

    The path of a JSON config file is read from ``sys.argv[1]``. It must
    contain:

    * ``interface`` -- interface matched with ``iif`` in the DNAT rules,
    * ``wg_interface`` -- WireGuard interface queried via ``wg show``,
    * ``pubkey_port_mapping`` -- maps a peer public key to an iterable of
      UDP ports that are forwarded to that peer's current endpoint IP.

    Ensures the ``nat`` table, its prerouting/postrouting chains and a
    masquerade rule exist, then reconciles the prerouting DNAT rules with
    the peers' endpoints every 10 seconds. Never returns.
    """
    with open(sys.argv[1], "r") as f:
        config = json.load(f)
    interface = config["interface"]
    wg_interface = config["wg_interface"]
    pubkey_port_mapping = config["pubkey_port_mapping"]

    nft = nftables.Nftables()
    nft.set_json_output(True)
    # Rule handles are needed to replace existing DNAT rules in place.
    nft.set_handle_output(True)

    _ensure_nat_setup(nft)

    while True:
        port_ip_mapping = _endpoint_port_ip_mapping(wg_interface, pubkey_port_mapping)
        if port_ip_mapping is not None:
            _sync_dnat_rules(nft, interface, port_ip_mapping)
        time.sleep(10)


def _list_ruleset(nft):
    """Return the current ruleset as a list of nftables JSON objects.

    Any error text reported by nft is printed to stderr. Returns None when
    the command failed or produced no output to parse.
    """
    rc, output, error = nft.cmd("list ruleset")
    if error:
        print(error, file=sys.stderr)
    if rc != 0 or not output:
        return None
    return json.loads(output)["nftables"]


def _is_nat_rule(item, chain):
    """Return True if *item* is a rule in *chain* of the ``ip`` family ``nat`` table."""
    rule = item.get("rule")
    return (
        rule is not None
        and rule["family"] == "ip"
        and rule["table"] == "nat"
        and rule["chain"] == chain
    )


def _ensure_nat_setup(nft):
    """Create the nat table, its NAT chains and a masquerade rule if missing.

    Exits the program if the ruleset cannot be listed, because it is then
    impossible to tell whether the masquerade rule already exists.
    """
    # "add" does not fail for an existing table/chain, so re-running is harmless.
    nft.cmd("add table nat")
    nft.cmd("add chain nat prerouting { type nat hook prerouting priority -100; }")
    nft.cmd("add chain nat postrouting { type nat hook postrouting priority 100; }")

    ruleset = _list_ruleset(nft)
    if ruleset is None:
        sys.exit("wireguard-nat-nftables: could not list the nftables ruleset")

    # "add rule" always appends, so only add masquerade when no postrouting
    # rule already begins with a masquerade expression.
    has_masquerade = any(
        _is_nat_rule(item, "postrouting")
        and item["rule"]["expr"]
        and "masquerade" in item["rule"]["expr"][0]
        for item in ruleset
    )
    if not has_masquerade:
        nft.cmd("add rule nat postrouting masquerade")


def _endpoint_port_ip_mapping(wg_interface, pubkey_port_mapping):
    """Map each configured UDP port to its peer's current endpoint IP.

    Runs ``wg show <wg_interface> endpoints``, which prints one
    ``<pubkey>\\t<endpoint>`` line per peer. If ``wg`` writes anything to
    stderr, the message is printed and None is returned. Otherwise a dict
    ``{port: ip}`` is returned. Peers missing from *pubkey_port_mapping*,
    and peers without a usable endpoint (wg prints ``(none)``), are skipped.
    """
    process = subprocess.run(
        ["wg", "show", wg_interface, "endpoints"],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if process.stderr:
        print("{}: {}".format(wg_interface, process.stderr.decode()), file=sys.stderr)
        return None

    port_ip_mapping = {}
    for line in process.stdout.decode().splitlines():
        pubkey, _, endpoint = line.partition("\t")
        ports = pubkey_port_mapping.get(pubkey)
        if ports is None:
            continue
        # Endpoint is "host:port"; splitting on the LAST colon keeps IPv6
        # "[addr]:port" intact. "(none)" has no colon and yields "".
        # NOTE: the nat table is family ip, so nft will likely reject IPv6
        # targets; such errors are printed when the rule is added.
        ip = endpoint.rpartition(":")[0].strip("[]")
        if not ip:
            continue
        for port in ports:
            port_ip_mapping[port] = ip
    return port_ip_mapping


def _sync_dnat_rules(nft, interface, port_ip_mapping):
    """Point the prerouting DNAT rules at the IPs in *port_ip_mapping*.

    Existing rules whose target differs are replaced in place by handle.
    Ports without a rule yet get a new one. Rules of another shape, or for
    ports not in the mapping, are left alone. *port_ip_mapping* itself is
    not modified. Nothing is changed if the ruleset cannot be listed.
    """
    ruleset = _list_ruleset(nft)
    if ruleset is None:
        return

    pending = dict(port_ip_mapping)
    for item in ruleset:
        if not _is_nat_rule(item, "prerouting"):
            continue
        rule = item["rule"]
        try:
            # Shape produced by "iif X udp dport P dnat to IP".
            ip = rule["expr"][2]["dnat"]["addr"]
            port = rule["expr"][1]["match"]["right"]
        except (IndexError, KeyError, TypeError):
            continue
        if port not in pending:
            # Peer gone or unconfigured, or a duplicate rule for a port
            # already handled earlier in this pass.
            continue
        new_ip = pending.pop(port)
        if ip == new_ip:
            continue
        rc, output, error = nft.cmd(
            "replace rule nat prerouting handle {} iif {} udp dport {} dnat to {}".format(
                rule["handle"], interface, port, new_ip
            )
        )
        if error:
            print(error, file=sys.stderr)
        else:
            print("Changed dnat address from {} to {} for UDP port {}".format(ip, new_ip, port))

    # Whatever is left has no DNAT rule yet.
    for port, ip in pending.items():
        rc, output, error = nft.cmd(
            "add rule nat prerouting iif {} udp dport {} dnat to {}".format(interface, port, ip)
        )
        if error:
            print(error, file=sys.stderr)
        else:
            print("Added dnat rule from UDP port {} to address {}".format(port, ip))
if __name__ == "__main__":
    # Start the daemon loop only when executed as a script, not on import.
    main()