Add wireguard-nat-nftables python script

This commit is contained in:
fi 2023-09-17 04:50:07 +02:00
parent 667b1c256b
commit 299d04142f
Signed by: fi
SSH key fingerprint: SHA256:d+6fQoDPMbSFK95zRVflRKZLRKF4cPSQb7VIxYkhFsA
6 changed files with 152 additions and 4 deletions

View file

@ -4,5 +4,6 @@
./configuration.nix
./nginx.nix
./containers/uptime-kuma
./services.nix
];
}

View file

@ -0,0 +1,30 @@
{ pkgs, ... }:
let
wireguard-nat-nftables = import ../../../pkgs/wireguard-nat-nftables pkgs;
config = pkgs.writeText "wireguard-nat-nftables-config" (builtins.toJSON {
interface = "ens3";
wg_interface = "wg0";
pubkey_port_mapping = {
"SJ8xCRb4hWm5EnXoV4FnwgbiaxmY2wI+xzfk+3HXERg=" = [ 51827 51829 ];
"BbNeBTe6HwQuHPK+ZQXWYRZJJMPdS0h81n07omYyRl4=" = [ 51828 51830 ];
"u9h+D8XZ62ABnetBRKnf6tjs+tJwM8fQ4d6ipOCLFyE=" = [ 51821 51824 ];
};
});
in
{
systemd.services.wireguard-nat-nftables = {
description = "A python script to update nftable dnat rules based on WireGuard peer IPs";
requires = [ "wireguard-wg0.service" ];
after = [ "wireguard-wg0.service" ];
script = ''
${wireguard-nat-nftables}/bin/wireguard-nat-nftables.py ${config}
'';
serviceConfig = {
Type = "simple";
User = "root";
Group = "root";
};
};
}

View file

@ -9,7 +9,8 @@
simple-nixos-mailserver.url = "gitlab:simple-nixos-mailserver/nixos-mailserver/nixos-23.05";
};
outputs = { self, nixpkgs, nixpkgs-unstable, nixos-generators, simple-nixos-mailserver, ... }@inputs: let
outputs = { self, nixpkgs, nixpkgs-unstable, nixos-generators, simple-nixos-mailserver, ... }@inputs:
let
hosts = import ./hosts.nix inputs;
helper = import ./helper.nix inputs;
in {
@ -32,9 +33,9 @@
} // builtins.mapAttrs (helper.generateColmenaHost) hosts;
hydraJobs = {
nixConfigurations = builtins.mapAttrs (
host: helper.generateNixConfiguration host { inherit nixpkgs-unstable hosts simple-nixos-mailserver; }
) hosts;
nixConfigurations = builtins.mapAttrs (host: helper.generateNixConfiguration host {
inherit nixpkgs-unstable hosts simple-nixos-mailserver;
}) hosts;
};
# Generate a base VM image for Proxmox with `nix build .#base-proxmox`

View file

@ -0,0 +1,17 @@
{ pkgs, ... }:
let
nftablesWithPythonOverlay = final: prev: {
nftables = (prev.nftables.override { withPython = true; });
};
pkgs-overlay = pkgs.extend nftablesWithPythonOverlay;
in
pkgs-overlay.python310Packages.buildPythonApplication {
pname = "wireguard-nat-nftables";
version = "0.0.1";
propagatedBuildInputs = with pkgs-overlay; [
python310Packages.nftables
];
src = ./src;
}

View file

@ -0,0 +1,7 @@
from distutils.core import setup
setup(
name='wireguard-nat-nftables',
version='0.0.1',
scripts=['wireguard-nat-nftables.py']
)

View file

@ -0,0 +1,92 @@
#!/usr/bin/env python3
import nftables
import json
import subprocess
import time
import sys
def main():
f = open(sys.argv[1], "r")
config = json.loads(f.read())
f.close()
interface = config["interface"]
wg_interface = config["wg_interface"]
pubkey_port_mapping = config["pubkey_port_mapping"]
nft = nftables.Nftables()
nft.set_json_output(True)
nft.set_handle_output(True)
# add nat table rules for dnat and snat masquerade
nft.cmd("add table nat")
nft.cmd("add chain nat prerouting { type nat hook prerouting priority -100; }")
nft.cmd("add chain nat postrouting { type nat hook postrouting priority 100; }")
# load current nftables rules
rc, output, error = nft.cmd("list ruleset")
if error:
print(error, file=sys.stderr)
nftables_output = json.loads(output)
add_masquerade = True
for item in nftables_output["nftables"]:
if ("rule" in item
and item["rule"]["family"] == "ip"
and item["rule"]["table"] == "nat"
and item["rule"]["chain"] == "postrouting"
and "masquerade" in item["rule"]["expr"][0]
):
add_masquerade = False
break
if add_masquerade:
nft.cmd("add rule nat postrouting masquerade")
while True:
# list WireGuard peer endpoint addresses of WireGuard VPN connection
process = subprocess.Popen(["wg", "show", wg_interface, "endpoints"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = process.communicate()
lines = stdout.decode().split("\n")[:-1]
if stderr:
print("{}: {}".format(wg_interface, stderr.decode()), file=sys.stderr)
else:
# map destination port to IP
port_ip_mapping = {}
for line in lines:
pubkey = line.split("\t")[0]
ip = line.split("\t")[1].split(":")[0] # probably only works for IPv4
for port in pubkey_port_mapping[pubkey]:
port_ip_mapping[port] = ip
# load current nftables rules
rc, output, error = nft.cmd("list ruleset")
if error:
print(error, file=sys.stderr)
nftables_output = json.loads(output)
# update existing nftable dnat rules, if the remote IP mismatches
for item in nftables_output["nftables"]:
if "rule" in item and item["rule"]["family"] == "ip" and item["rule"]["table"] == "nat" and item["rule"]["chain"] == "prerouting":
handle = item["rule"]["handle"]
ip = item["rule"]["expr"][2]["dnat"]["addr"]
port = item["rule"]["expr"][1]["match"]["right"]
if not ip == port_ip_mapping[port]:
rc, output, error = nft.cmd("replace rule nat prerouting handle {} iif {} udp dport {} dnat to {}".format(handle, interface, port, port_ip_mapping[port]))
if error:
eprint(error)
else:
print("Changed dnat address from {} to {} for UDP port {}".format(ip, port_ip_mapping[port], port))
port_ip_mapping.pop(port)
# loop through all remaining ports and add needed dnat rules
for port in port_ip_mapping:
rc, output, error = nft.cmd("add rule nat prerouting iif {} udp dport {} dnat to {}".format(interface, port, port_ip_mapping[port]))
if error:
print(error, file=sys.stderr)
else:
print("Added dnat rule from UDP port {} to address {}".format(port, port_ip_mapping[port]))
time.sleep(10)
if __name__ == "__main__":
main()