#!/usr/bin/env python3

import nftables
import json
import subprocess
import time
import sys

def _run_nft(nft, command):
    """Run one nft command, print any error text to stderr, and return True on success."""
    rc, _output, error = nft.cmd(command)
    if error:
        print(error, file=sys.stderr)
    return rc == 0


def _list_ruleset(nft):
    """Return the current ruleset as a list of libnftables JSON objects, or None if listing failed."""
    rc, output, error = nft.cmd("list ruleset")
    if error:
        print(error, file=sys.stderr)
    if rc != 0:
        return None
    if not output:
        # Nothing to parse; json.loads("") would raise.
        return []
    return json.loads(output)["nftables"]


def _nat_rules(ruleset, chain):
    """Yield the rule objects that belong to `chain` of the ip-family `nat` table."""
    for item in ruleset:
        rule = item.get("rule")
        if (
            rule is not None
            and rule.get("family") == "ip"
            and rule.get("table") == "nat"
            and rule.get("chain") == chain
        ):
            yield rule


def _has_masquerade(ruleset):
    """Return True if nat/postrouting already has a rule whose first expression is masquerade."""
    for rule in _nat_rules(ruleset, "postrouting"):
        exprs = rule.get("expr") or []
        # Guard against an empty expression list before indexing.
        if exprs and "masquerade" in exprs[0]:
            return True
    return False


def _read_endpoints(wg_interface, pubkey_port_mapping):
    """Map each configured UDP port to its peer's current IPv4 endpoint address.

    Returns None if `wg show` exits with a non-zero status.
    """
    process = subprocess.run(
        ["wg", "show", wg_interface, "endpoints"],
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    if process.stderr:
        print("{}: {}".format(wg_interface, process.stderr.decode()), file=sys.stderr)
    if process.returncode != 0:
        return None

    port_ip_mapping = {}
    for line in process.stdout.decode().splitlines():
        pubkey, _, endpoint = line.partition("\t")
        # wg prints "(none)" for a peer without a known endpoint. A leading "["
        # marks an IPv6 endpoint, which an ip-family dnat rule cannot target.
        # Neither yields a usable address, so skip both.
        if not endpoint or endpoint == "(none)" or endpoint.startswith("["):
            continue
        ip = endpoint.rpartition(":")[0]
        # Peers that are not listed in the config are ignored instead of raising KeyError.
        for port in pubkey_port_mapping.get(pubkey, []):
            # nft's JSON output reports ports as ints, so normalize config values
            # (which may be strings) to make lookups against existing rules match.
            port_ip_mapping[int(port)] = ip
    return port_ip_mapping


def _dnat_target(rule):
    """Return (udp_dport, dnat_addr) for a `udp dport N ... dnat to ADDR` rule, else None."""
    port = addr = None
    for expr in rule.get("expr") or []:
        match = expr.get("match")
        if match is not None:
            left = match.get("left")
            payload = left.get("payload") if isinstance(left, dict) else None
            if (
                isinstance(payload, dict)
                and payload.get("protocol") == "udp"
                and payload.get("field") == "dport"
            ):
                port = match.get("right")
        elif isinstance(expr.get("dnat"), dict):
            addr = expr["dnat"].get("addr")
    # Only single-port / single-address rules are managed. Sets, ranges and maps
    # appear as dicts/lists here and are left alone.
    if not isinstance(port, int) or not isinstance(addr, str):
        return None
    return port, addr


def _sync_dnat_rules(nft, interface, port_ip_mapping, ruleset):
    """Point existing prerouting dnat rules at the wanted peer IPs and add rules for unmapped ports."""
    missing = dict(port_ip_mapping)
    for rule in _nat_rules(ruleset, "prerouting"):
        target = _dnat_target(rule)
        if target is None:
            continue
        port, ip = target
        wanted_ip = port_ip_mapping.get(port)
        if wanted_ip is None:
            # The port has no current peer endpoint (or is not ours): leave the rule untouched.
            continue
        missing.pop(port, None)
        if ip != wanted_ip:
            command = "replace rule nat prerouting handle {} iif {} udp dport {} dnat to {}".format(
                rule["handle"], interface, port, wanted_ip
            )
            if _run_nft(nft, command):
                print("Changed dnat address from {} to {} for UDP port {}".format(ip, wanted_ip, port))

    for port, ip in missing.items():
        command = "add rule nat prerouting iif {} udp dport {} dnat to {}".format(interface, port, ip)
        if _run_nft(nft, command):
            print("Added dnat rule from UDP port {} to address {}".format(port, ip))


def main():
    """Keep nftables DNAT rules pointed at the current WireGuard peer endpoints.

    Reads a JSON config whose path is given as the only command-line argument.
    The config has these keys:
      - "interface": the inbound interface used in the `iif` match of the dnat rules
      - "wg_interface": the WireGuard interface passed to `wg show ... endpoints`
      - "pubkey_port_mapping": maps a peer public key to a list of UDP ports

    The function ensures the ip `nat` table, its prerouting/postrouting chains and a
    masquerade rule exist. It then polls every 10 seconds and creates or updates one
    `udp dport PORT dnat to PEER_IP` rule per configured port. It never returns.
    """
    if len(sys.argv) != 2:
        sys.exit("usage: {} CONFIG_JSON".format(sys.argv[0]))
    with open(sys.argv[1], "r") as f:
        config = json.load(f)

    interface = config["interface"]
    wg_interface = config["wg_interface"]
    pubkey_port_mapping = config["pubkey_port_mapping"]

    nft = nftables.Nftables()
    nft.set_json_output(True)
    # Rule handles are required to `replace` existing rules in place.
    nft.set_handle_output(True)

    # Add the nat table and its base chains for dnat and snat masquerade.
    # `add` does not fail if the object already exists.
    _run_nft(nft, "add table nat")
    _run_nft(nft, "add chain nat prerouting { type nat hook prerouting priority -100; }")
    _run_nft(nft, "add chain nat postrouting { type nat hook postrouting priority 100; }")

    ruleset = _list_ruleset(nft)
    if ruleset is None:
        sys.exit("failed to list nftables ruleset")
    if not _has_masquerade(ruleset):
        _run_nft(nft, "add rule nat postrouting masquerade")

    while True:
        port_ip_mapping = _read_endpoints(wg_interface, pubkey_port_mapping)
        if port_ip_mapping is not None:
            ruleset = _list_ruleset(nft)
            if ruleset is not None:
                _sync_dnat_rules(nft, interface, port_ip_mapping, ruleset)
        time.sleep(10)

if __name__ == "__main__":
    # Start the polling loop only when this file is executed as a script,
    # so importing the module does not touch nftables or WireGuard.
    main()