diff --git a/config/hosts/valkyrie/default.nix b/config/hosts/valkyrie/default.nix index b8c16ea..68a1b85 100644 --- a/config/hosts/valkyrie/default.nix +++ b/config/hosts/valkyrie/default.nix @@ -4,5 +4,6 @@ ./configuration.nix ./nginx.nix ./containers/uptime-kuma + ./services.nix ]; } diff --git a/config/hosts/valkyrie/services.nix b/config/hosts/valkyrie/services.nix new file mode 100644 index 0000000..895865c --- /dev/null +++ b/config/hosts/valkyrie/services.nix @@ -0,0 +1,30 @@ +{ pkgs, ... }: +let + wireguard-nat-nftables = import ../../../pkgs/wireguard-nat-nftables pkgs; + config = pkgs.writeText "wireguard-nat-nftables-config" (builtins.toJSON { + interface = "ens3"; + wg_interface = "wg0"; + pubkey_port_mapping = { + "SJ8xCRb4hWm5EnXoV4FnwgbiaxmY2wI+xzfk+3HXERg=" = [ 51827 51829 ]; + "BbNeBTe6HwQuHPK+ZQXWYRZJJMPdS0h81n07omYyRl4=" = [ 51828 51830 ]; + "u9h+D8XZ62ABnetBRKnf6tjs+tJwM8fQ4d6ipOCLFyE=" = [ 51821 51824 ]; + }; + }); +in +{ + systemd.services.wireguard-nat-nftables = { + description = "A python script to update nftable dnat rules based on WireGuard peer IPs"; + requires = [ "wireguard-wg0.service" ]; + after = [ "wireguard-wg0.service" ]; + + script = '' + ${wireguard-nat-nftables}/bin/wireguard-nat-nftables.py ${config} + ''; + + serviceConfig = { + Type = "simple"; + User = "root"; + Group = "root"; + }; + }; +} diff --git a/flake.nix b/flake.nix index 4b25dcb..a9af2db 100644 --- a/flake.nix +++ b/flake.nix @@ -9,7 +9,8 @@ simple-nixos-mailserver.url = "gitlab:simple-nixos-mailserver/nixos-mailserver/nixos-23.05"; }; - outputs = { self, nixpkgs, nixpkgs-unstable, nixos-generators, simple-nixos-mailserver, ... }@inputs: let + outputs = { self, nixpkgs, nixpkgs-unstable, nixos-generators, simple-nixos-mailserver, ... }@inputs: + let hosts = import ./hosts.nix inputs; helper = import ./helper.nix inputs; in { @@ -32,9 +33,9 @@ } // builtins.mapAttrs (helper.generateColmenaHost) hosts; hydraJobs = { - nixConfigurations = builtins.mapAttrs ( - host: helper.generateNixConfiguration host { inherit nixpkgs-unstable hosts simple-nixos-mailserver; } - ) hosts; + nixConfigurations = builtins.mapAttrs (host: helper.generateNixConfiguration host { + inherit nixpkgs-unstable hosts simple-nixos-mailserver; + }) hosts; }; # Generate a base VM image for Proxmox with `nix build .#base-proxmox` diff --git a/pkgs/wireguard-nat-nftables/default.nix b/pkgs/wireguard-nat-nftables/default.nix new file mode 100644 index 0000000..4a75703 --- /dev/null +++ b/pkgs/wireguard-nat-nftables/default.nix @@ -0,0 +1,17 @@ +{ pkgs, ... }: +let + nftablesWithPythonOverlay = final: prev: { + nftables = (prev.nftables.override { withPython = true; }); + }; + pkgs-overlay = pkgs.extend nftablesWithPythonOverlay; +in +pkgs-overlay.python310Packages.buildPythonApplication { + pname = "wireguard-nat-nftables"; + version = "0.0.1"; + + propagatedBuildInputs = with pkgs-overlay; [ + python310Packages.nftables + ]; + + src = ./src; +} diff --git a/pkgs/wireguard-nat-nftables/src/setup.py b/pkgs/wireguard-nat-nftables/src/setup.py new file mode 100644 index 0000000..4bcc53c --- /dev/null +++ b/pkgs/wireguard-nat-nftables/src/setup.py @@ -0,0 +1,7 @@ +from distutils.core import setup + +setup( + name='wireguard-nat-nftables', + version='0.0.1', + scripts=['wireguard-nat-nftables.py'] +) diff --git a/pkgs/wireguard-nat-nftables/src/wireguard-nat-nftables.py b/pkgs/wireguard-nat-nftables/src/wireguard-nat-nftables.py new file mode 100644 index 0000000..a1c09c0 --- /dev/null +++ b/pkgs/wireguard-nat-nftables/src/wireguard-nat-nftables.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 + +import nftables +import json +import subprocess +import time +import sys + +def main(): + f = open(sys.argv[1], "r") + config = json.loads(f.read()) + f.close() + + interface = config["interface"] + wg_interface = config["wg_interface"] + pubkey_port_mapping = config["pubkey_port_mapping"] + + nft = nftables.Nftables() + nft.set_json_output(True) + nft.set_handle_output(True) + + # add nat table rules for dnat and snat masquerade + nft.cmd("add table nat") + nft.cmd("add chain nat prerouting { type nat hook prerouting priority -100; }") + nft.cmd("add chain nat postrouting { type nat hook postrouting priority 100; }") + + # load current nftables rules + rc, output, error = nft.cmd("list ruleset") + if error: + print(error, file=sys.stderr) + nftables_output = json.loads(output) + + add_masquerade = True + for item in nftables_output["nftables"]: + if ("rule" in item + and item["rule"]["family"] == "ip" + and item["rule"]["table"] == "nat" + and item["rule"]["chain"] == "postrouting" + and "masquerade" in item["rule"]["expr"][0] + ): + add_masquerade = False + break + if add_masquerade: + nft.cmd("add rule nat postrouting masquerade") + + while True: + # list WireGuard peer endpoint addresses of WireGuard VPN connection + process = subprocess.Popen(["wg", "show", wg_interface, "endpoints"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + lines = stdout.decode().split("\n")[:-1] + if stderr: + print("{}: {}".format(wg_interface, stderr.decode()), file=sys.stderr) + else: + # map destination port to IP + port_ip_mapping = {} + for line in lines: + pubkey = line.split("\t")[0] + ip = line.split("\t")[1].split(":")[0] # probably only works for IPv4 + for port in pubkey_port_mapping[pubkey]: + port_ip_mapping[port] = ip + + # load current nftables rules + rc, output, error = nft.cmd("list ruleset") + if error: + print(error, file=sys.stderr) + nftables_output = json.loads(output) + + # update existing nftable dnat rules, if the remote IP mismatches + for item in nftables_output["nftables"]: + if "rule" in item and item["rule"]["family"] == "ip" and item["rule"]["table"] == "nat" and item["rule"]["chain"] == "prerouting": + handle = item["rule"]["handle"] + ip = item["rule"]["expr"][2]["dnat"]["addr"] + port = item["rule"]["expr"][1]["match"]["right"] + if not ip == port_ip_mapping[port]: + rc, output, error = nft.cmd("replace rule nat prerouting handle {} iif {} udp dport {} dnat to {}".format(handle, interface, port, port_ip_mapping[port])) + if error: + eprint(error) + else: + print("Changed dnat address from {} to {} for UDP port {}".format(ip, port_ip_mapping[port], port)) + port_ip_mapping.pop(port) + + # loop through all remaining ports and add needed dnat rules + for port in port_ip_mapping: + rc, output, error = nft.cmd("add rule nat prerouting iif {} udp dport {} dnat to {}".format(interface, port, port_ip_mapping[port])) + if error: + print(error, file=sys.stderr) + else: + print("Added dnat rule from UDP port {} to address {}".format(port, port_ip_mapping[port])) + time.sleep(10) + +if __name__ == "__main__": + main()