#!/usr/bin/python3
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

# Copyright © 2017 David Sinquin
# Copyright © 2018-2019 Hugo Levy-Falk

"""
Module for nftables set management.
"""

# Dependencies: python3-netaddr, python3-requests, nftables, sudo (optional)
#
# For sudo configuration, create a file in /etc/sudoers.d/ with:
#   "<username> ALL = (root) NOPASSWD: /usr/sbin/nft"
#
# netaddr :
#  - https://pypi.python.org/pypi/netaddr/
#  - https://github.com/drkjam/netaddr/
#  - https://netaddr.readthedocs.io/en/latest/

import logging
import re
import subprocess
# `Iterable` lives in collections.abc; the `collections.Iterable` alias was
# removed in Python 3.10.
from collections.abc import Iterable
from configparser import ConfigParser

import netaddr  # MAC, IPv4, IPv6
import requests


class ExecError(Exception):
    """Simple class to indicate an error in a process execution."""
    pass


class CommandExec:
    """Simple class to start a command, logging and returning errors if any."""

    @staticmethod
    def run_check_output(command, allowed_return_codes=(0,), timeout=15):
        """
        Run a command, logging output in case of an error.

        Args:
            command: argument list, executed without a shell.
            allowed_return_codes: return codes not considered as errors.
            timeout: seconds to wait for the command.

        Returns:
            A (return_code, stdout, stderr) tuple, outputs being str.

        Raises:
            ExecError: if the command times out (it is then killed) or
                returns a code outside `allowed_return_codes`.

        Actual timeout may be twice the given value in seconds.
        """
        logging.debug("Command to be run: '%s'", "' '".join(command))
        # Using Popen as a context manager closes the pipes and reaps the
        # child on every exit path, including after a kill on timeout.
        with subprocess.Popen(
                command,
                shell=False,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                universal_newlines=True) as process:
            try:
                result = process.communicate(timeout=timeout)
                return_code = process.wait(timeout=timeout)
            except subprocess.TimeoutExpired as err:
                process.kill()
                raise ExecError('Command timed out after {}s: "{}"'.format(
                    timeout, '" "'.join(command))) from err
        if return_code not in allowed_return_codes:
            error_message = ('Error running command: "{}", return code: {}.\n'
                             'Stderr:\n{}\nStdout:\n{}'.format(
                                 '" "'.join(command), return_code, *result))
            logging.error(error_message)
            raise ExecError(error_message)
        return (return_code, *result)

    @classmethod
    def run(cls, *args, **kwargs):
        """Run a command without checking outputs; return its return code."""
        returncode, _, _ = cls.run_check_output(*args, **kwargs)
        return returncode


class Parser:
    """Parsers for commonly used formats."""

    @staticmethod
    def MAC(mac):
        """Check a MAC validity."""
        return netaddr.EUI(mac, dialect=netaddr.mac_unix_expanded)

    @staticmethod
    def IPv4(ip):
        """Check an IPv4 validity.

        Args:
            ip: a netaddr IPAddress/IPNetwork/IPRange/IPGlob (returned as
                is), a single IP address, an IP network, or a "begin-end"
                range string (returns an IPRange).
        """
        if isinstance(ip, (netaddr.IPAddress, netaddr.IPNetwork,
                           netaddr.IPRange, netaddr.IPGlob)):
            return ip
        try:
            return netaddr.IPAddress(ip, version=4)
        except netaddr.core.AddrFormatError:
            try:
                return netaddr.IPNetwork(ip, version=4)
            except netaddr.core.AddrFormatError:
                begin, end = ip.split('-')
                return netaddr.IPRange(begin, end)

    @staticmethod
    def IPv6(ip):
        """Check an IPv6 validity.

        Args:
            ip: can either be a (begin, end) tuple (returns an IPRange), a
                single IP address, an IP network, or a "begin-end" range
                string (returns an IPRange).
        """
        if isinstance(ip, tuple):
            begin, end = ip
            return Parser._ipv6_range(begin, end)
        try:
            return netaddr.IPAddress(ip, version=6)
        except (ValueError, netaddr.core.AddrFormatError):
            # netaddr's AddrFormatError is not a ValueError subclass, so it
            # must be caught explicitly for the fallbacks below to be reached.
            pass
        try:
            return netaddr.IPNetwork(ip, version=6)
        except netaddr.core.AddrFormatError:
            if isinstance(ip, str) and '-' in ip:
                begin, end = ip.split('-')
                return Parser._ipv6_range(begin, end)
            raise

    @staticmethod
    def _ipv6_range(begin, end):
        """Build an IPv6 IPRange, pinning the version on both bounds."""
        # IPRange's constructor takes no `version` argument (netaddr API),
        # so the version is enforced on the bound addresses instead.
        return netaddr.IPRange(netaddr.IPAddress(begin, version=6),
                               netaddr.IPAddress(end, version=6))

    @staticmethod
    def protocol(protocol):
        """Check a protocol validity."""
        if protocol in ('tcp', 'udp', 'icmp'):
            return protocol
        raise ValueError('Invalid protocol: "{}".'.format(protocol))

    @staticmethod
    def port_number(port):
        """Check a port validity.

        Returns the port as an int, or the original "begin-end" value for a
        port range. Valid ports are 0-65535.
        """
        try:
            port_number = int(port)
        except ValueError:
            begin, end = port.split('-')
            begin, end = int(begin), int(end)
            if 0 <= begin < end < 65536:
                return port
        else:
            if 0 <= port_number < 65536:
                return port_number
        raise ValueError('Invalid port number: "{}".'.format(port))


class NetfilterSet:
    """Manage a netfilter set using nftables."""

    TYPES = {'IPv4': 'ipv4_addr', 'IPv6': 'ipv6_addr', 'MAC': 'ether_addr',
             'protocol': 'inet_proto', 'port': 'inet_service'}

    FILTERS = {'IPv4': Parser.IPv4, 'IPv6': Parser.IPv6, 'MAC': Parser.MAC,
               'protocol': Parser.protocol, 'port': Parser.port_number}

    ADDRESS_FAMILIES = {'ip', 'ip6', 'inet', 'arp', 'bridge', 'netdev'}

    FLAGS = {'constant', 'interval', 'timeout'}

    NFT_TYPE = {'set', 'map'}

    # Parses the output of `nft -nn list set <family> <table> <name>`.
    # A.K.A. Really, I don't hate you, so please don't hate me...
    pattern = re.compile(
        r"table (?P<address_family>\w+)+ (?P<table>\w+) \{\n"
        r"\s*set (?P<name>\w+) \{\n"
        r"\s*type (?P<type>(\w+( \. )?)+)\n"
        r"(\s*flags (?P<flags>(\w+(, )?)+)\n)?"
        r"(\s*elements = \{ "
        # '-' is placed last in the class so it is a literal (range bounds
        # such as "a-b"), not a '.'-'/' character range.
        r"(?P<elements>((\n?\s*)?([\w:./-]+( \. )?)+,?)*) "
        r"\n?\s*\}\n)?"
        r"\s*\}\n"
        r"\s*\}"
    )

    def __init__(self,
                 name,
                 type_,  # e.g.: ('MAC', 'IPv4')
                 target_content=None,
                 use_sudo=True,
                 address_family='inet',  # Manage both IPv4 and IPv6.
                 table_name='filter',
                 flags=(),
                 ):
        self.name = name
        self.content = set()
        # self.type
        self.set_type(type_)
        self.filters = tuple(self.FILTERS[i] for i in self.type)
        self.set_flags(flags)
        # self.address_family
        self.set_address_family(address_family)
        self.table = table_name
        sudo = ["/usr/bin/sudo"] if use_sudo else []
        self.nft = [*sudo, "/usr/sbin/nft"]
        if target_content:
            self._target_content = self.validate_data(target_content)
        else:
            self._target_content = set()

    @property
    def target_content(self):
        """Return a copy of the validated target content."""
        return self._target_content.copy()  # Forbid in-place modification

    @target_content.setter
    def target_content(self, target_content):
        self._target_content = self.validate_data(target_content)

    def filter(self, elements):
        """Parse each element of an n-uplet with its column's parser."""
        return (self.filters[i](element) for i, element in enumerate(elements))

    def set_type(self, type_):
        """Check set type validity and store it along with a type checker."""
        for element_type in type_:
            if element_type not in self.TYPES:
                raise ValueError('Invalid type: "{}".'.format(element_type))
        self.type = type_

    def set_address_family(self, address_family='ip'):
        """Set set address_family, defaulting to "ip" like nftables."""
        if address_family not in self.ADDRESS_FAMILIES:
            raise ValueError(
                'Invalid address_family: "{}".'.format(address_family))
        self.address_family = address_family

    def set_flags(self, flags_):
        """Check set flags validity before saving them (None if no flag)."""
        flags_ = flags_ or ()
        for f in flags_:
            if f not in self.FLAGS:
                raise ValueError('Invalid flag: "{}".'.format(f))
        self.flags = set(flags_) or None

    def create_in_kernel(self):
        """Create the set, removing existing set if needed."""
        current_set = self._get_raw_netfilter(parse_elements=False)
        if current_set is None:
            self._create_new_set_in_kernel()
        elif not self._raw_has_type(current_set):
            # Delete set if it exists with wrong type.
            # NOTE(review): _get_raw_netfilter already raises ValueError on a
            # type mismatch, so this branch looks unreachable as written.
            self._delete_in_kernel()
            self._create_new_set_in_kernel()

    def _raw_has_type(self, raw_set):
        """Check the type of a dict returned by _parse_netfilter_string."""
        return self.has_type(raw_set['type'])

    def _delete_in_kernel(self, nft_type='set'):
        """Delete the set, table and set must exist."""
        CommandExec.run([
            *self.nft,
            'delete {nft_type} {addr_family} {table} {set_}'.format(
                addr_family=self.address_family,
                table=self.table,
                nft_type=nft_type,
                set_=self.name)
        ])

    def _create_new_set_in_kernel(self, nft_type='set'):
        """Create the non-existing set, creating table if needed."""
        flags_clause = (' flags {};'.format(', '.join(self.flags))
                        if self.flags else '')
        nft_command = ('add {nft_type} {addr_family} {table} {set_} '
                       '{{ type {type_} ;{flags}}}').format(
                           nft_type=nft_type,
                           addr_family=self.address_family,
                           table=self.table,
                           set_=self.name,
                           type_=self.format_type(),
                           flags=flags_clause,
                       )
        create_set = [*self.nft, nft_command]
        return_code = CommandExec.run(create_set, allowed_return_codes=(0, 1))
        if return_code == 0:
            return  # Set creation successful.
        # return_code was 1, one error was detected in the rules.
        # Attempt to create the table first.
        create_table = [*self.nft, 'add table {addr_family} {table}'.format(
            addr_family=self.address_family, table=self.table)]
        CommandExec.run(create_table)
        CommandExec.run(create_set)

    def validate_data(self, set_data):
        """
        Validate data, returning it or raising a ValueError.

        For MAC-IPv4 set, data must be an iterable of (MAC, IPv4) iterables.
        All parsing errors are collected and reported together.
        """
        set_ = set()
        errors = []
        for n_uplet in set_data:
            try:
                set_.add(tuple(self.filter(n_uplet)))
            except Exception as err:
                errors.append(err)
        if errors:
            raise ValueError(
                'Error parsing data, encountered the folowing {} errors.\n"{}"'
                .format(len(errors), '",\n"'.join(map(str, errors))))
        return set_

    def _apply_target_content(self):
        """Change netfilter content to target set."""
        current_set = self.get_netfilter_content()
        if current_set is None:
            raise ValueError('Cannot change "{}" netfilter set content: set '
                             'do not exist in "{}" "{}".'.format(
                                 self.name, self.address_family, self.table))
        to_delete = current_set - self._target_content
        to_add = self._target_content - current_set
        self._change_content(delete=to_delete, add=to_add)

    def _change_content(self, delete=None, add=None):
        """Delete then add elements (iterables of n-uplets) in the set.

        Deletions run first, like NetfilterMap._change_content, so that an
        added interval does not coexist with an overlapping one that is
        about to be removed.
        """
        for action, elements in (('delete', delete), ('add', add)):
            if not elements:
                continue
            content = ', '.join(' . '.join(str(element) for element in tuple_)
                                for tuple_ in elements)
            CommandExec.run([
                *self.nft,
                '{action} element {addr_family} {table} {set_} {{{content}}}'
                .format(action=action, addr_family=self.address_family,
                        table=self.table, set_=self.name, content=content)
            ])

    def _get_raw_netfilter(self, parse_elements=True):
        """Return a dict describing the netfilter set matching self or None.

        Raises ValueError if the listed set does not match self.
        """
        _, stdout, _ = CommandExec.run_check_output(
            [*self.nft, '-nn', 'list set {addr_family} {table} {set_}'.format(
                addr_family=self.address_family,
                table=self.table,
                set_=self.name)],
            allowed_return_codes=(0, 1)  # In case table do not exist
        )
        if not stdout:
            return None
        netfilter_set = self._parse_netfilter_string(stdout)
        if netfilter_set['name'] != self.name \
                or netfilter_set['address_family'] != self.address_family \
                or netfilter_set['table'] != self.table \
                or not self._raw_has_type(netfilter_set) \
                or netfilter_set.get('flags', set()) != self.flags:
            raise ValueError(
                'Did not get the right set, too wrong to fix. Got '
                + str(netfilter_set)
                + ("\nExpected : "
                   "\n\tname: \t{name} \t[{name_check}]"
                   "\n\taddress_family: \t{family} \t[{family_check}]"
                   "\n\ttable: \t{table} \t[{table_check}]"
                   "\n\tflags: \t{flags} \t[{flags_check}]"
                   "\n\ttypes: \t{types} \t[{types_check}]"
                   ).format(
                       name=self.name,
                       family=self.address_family,
                       table=self.table,
                       flags=self.flags,
                       types=tuple(self.TYPES[t] for t in self.type),
                       name_check=(
                           'v' if self.name == netfilter_set['name']
                           else 'x'),
                       family_check=(
                           'v' if self.address_family
                           == netfilter_set['address_family'] else 'x'),
                       table_check=(
                           'v' if self.table == netfilter_set['table']
                           else 'x'),
                       flags_check=(
                           'v' if self.flags
                           == netfilter_set.get('flags', set()) else 'x'),
                       types_check=(
                           'v' if self._raw_has_type(netfilter_set)
                           else 'x'),
                   )
            )
        if parse_elements:
            if netfilter_set['raw_content']:
                netfilter_set['content'] = self.validate_data((
                    (element.strip() for element in n_uplet.split(' . '))
                    for n_uplet in netfilter_set['raw_content'].split(',')))
            else:
                netfilter_set['content'] = set()
        return netfilter_set

    @classmethod
    def _parse_netfilter_string(cls, set_string):
        """
        Parse netfilter set definition and return set as dict.

        Do not validate content type against detected set type.
        Return a dict with 'name', 'address_family', 'table', 'type',
        'flags', 'raw_content' keys ('type' is a list of type names,
        'flags' a set or None, 'raw_content' a string or None).
        Raise ValueError in case of unexpected syntax.
        """
        match = cls.pattern.match(set_string)
        if match is None:
            raise ValueError("Malformed expression :\n" + set_string)
        values = match.groupdict()
        return {
            'address_family': values['address_family'],
            'table': values['table'],
            'name': values['name'],
            'type': values['type'].split(' . '),
            'raw_content': values['elements'],
            'flags': (set(values['flags'].split(', '))
                      if values['flags'] else None),
        }

    def get_netfilter_content(self):
        """Return current set content from netfilter."""
        netfilter_set = self._get_raw_netfilter(parse_elements=True)
        if netfilter_set is None:
            return None
        return netfilter_set['content']

    def has_type(self, type_):
        """Check if some type match the set's one."""
        return tuple(self.TYPES[t] for t in self.type) == tuple(type_)

    def manage(self):
        """Create set if needed and populate it with target content."""
        self.create_in_kernel()
        self._apply_target_content()

    def format_type(self):
        """Return the nft type expression, e.g. "ether_addr . ipv4_addr"."""
        return ' . '.join(self.TYPES[i] for i in self.type)


class NetfilterMap(NetfilterSet):
    """Manage a netfilter map (key n-uplet -> single value) using nftables.

    `type_from` holds the key types, `type_` the value type.
    """

    # Parses the output of `nft -nn list map <family> <table> <name>`.
    # A.K.A. Again, I don't hate you, so please don't hate me...
    pattern = re.compile(
        r"table (?P<address_family>\w+)+ (?P<table>\w+) \{\n"
        r"\s*map (?P<name>\w+) \{\n"
        r"\s*type (?P<type_from>(\w+( \. )?)+) : (?P<type>\w+)\n"
        r"(\s*flags (?P<flags>(\w+(, )?)+)\n)?"
        r"(\s*elements = \{ "
        # '-' is placed last in the classes so it is a literal (range bounds
        # such as "a-b"), not a '.'-'/' character range.
        r"(?P<elements>(\n?\s*([\w:./-]+( \. )?)+ : [\w:./-]+,?)*)"
        r"\n?\s*\}\n)?"
        r"\s*\}"
        r"\n\s*\}"
    )

    def __init__(self,
                 name,
                 type_,
                 type_from,
                 target_content=None,
                 use_sudo=True,
                 address_family='inet',
                 table_name='filter',
                 flags=(),
                 ):
        super().__init__(name, type_, use_sudo=use_sudo,
                         address_family=address_family,
                         table_name=table_name, flags=flags)
        self.set_type_from(type_from)
        self.key_filters = tuple(self.FILTERS[i] for i in self.type_from)
        if target_content:
            self._target_content = self.validate_data(target_content)
        else:
            self._target_content = {}

    def filter_key(self, elements):
        """Parse each element of a key n-uplet with its column's parser."""
        return (self.key_filters[i](element)
                for i, element in enumerate(elements))

    def set_type_from(self, type_):
        """Check key type validity and store it."""
        for element_type in type_:
            if element_type not in self.TYPES:
                raise ValueError('Invalid type: "{}".'.format(element_type))
        self.type_from = type_

    def _raw_has_type(self, raw_set):
        """Check the key and value types of a parsed map dict."""
        return self.has_type((raw_set['type_from'], raw_set['type']))

    def _delete_in_kernel(self):
        """Delete the map, table and map must exist."""
        super()._delete_in_kernel(nft_type='map')

    def _create_new_set_in_kernel(self):
        """Create the non-existing map, creating table if needed."""
        super()._create_new_set_in_kernel(nft_type='map')

    def validate_data(self, dict_data):
        """
        Validate data, returning it or raising a ValueError.

        Data must be a mapping from key n-uplets (one element per key type)
        to a single value of the map's value type. All parsing errors are
        collected and reported together.
        """
        set_ = {}
        errors = []
        for key in dict_data:
            try:
                set_[tuple(self.filter_key(key))] = tuple(
                    self.filter(dict_data[key]))
            except Exception as err:
                errors.append(err)
        if errors:
            raise ValueError(
                'Error parsing data, encountered the folowing {} errors.\n"{}"'
                .format(len(errors), '",\n"'.join(map(str, errors))))
        return set_

    def _apply_target_content(self):
        """Change netfilter map content to target map."""
        current_map = self.get_netfilter_content()
        if current_map is None:
            raise ValueError('Cannot change "{}" netfilter map content: map '
                             'do not exist in "{}" "{}".'.format(
                                 self.name, self.address_family, self.table))
        keys_to_delete = current_map.keys() - self._target_content.keys()
        keys_to_add = self._target_content.keys() - current_map.keys()
        # A key whose value changed is deleted then re-added.
        for k in current_map.keys() & self._target_content.keys():
            if current_map[k] != self._target_content[k]:
                keys_to_add.add(k)
                keys_to_delete.add(k)
        to_add = {k: self._target_content[k] for k in keys_to_add}
        self._change_content(delete=keys_to_delete, add=to_add)

    def _change_content(self, delete=None, add=None):
        """Delete keys (iterable) then add key -> value entries (dict)."""
        def join(values):
            return ' . '.join(str(element) for element in values)

        if delete:
            content = ', '.join(join(key) for key in delete)
            CommandExec.run([
                *self.nft,
                'delete element {addr_family} {table} {set_} {{{content}}}'
                .format(addr_family=self.address_family, table=self.table,
                        set_=self.name, content=content)
            ])
        if add:
            content = ', '.join(join(key) + ' : ' + join(value)
                                for key, value in add.items())
            CommandExec.run([
                *self.nft,
                'add element {addr_family} {table} {set_} {{{content}}}'
                .format(addr_family=self.address_family, table=self.table,
                        set_=self.name, content=content)
            ])

    def _get_raw_netfilter(self, parse_elements=True):
        """Return a dict describing the netfilter map matching self or None.

        Raises ValueError if the listed map does not match self.
        """
        _, stdout, _ = CommandExec.run_check_output(
            [*self.nft, '-nn', 'list map {addr_family} {table} {set_}'.format(
                addr_family=self.address_family,
                table=self.table,
                set_=self.name)],
            allowed_return_codes=(0, 1)  # In case table do not exist
        )
        if not stdout:
            return None
        netfilter_set = self._parse_netfilter_string(stdout)
        if netfilter_set['name'] != self.name \
                or netfilter_set['address_family'] != self.address_family \
                or netfilter_set['table'] != self.table \
                or not self._raw_has_type(netfilter_set):
            raise ValueError('Did not get the right map, too wrong to fix.')
        if parse_elements:
            if netfilter_set['raw_content']:
                raw_entries = {}
                for n_uplet in netfilter_set['raw_content'].split(','):
                    parts = n_uplet.split(' : ')
                    key = tuple(element.strip()
                                for element in parts[0].split(' . '))
                    raw_entries[key] = parts[1].strip()
                netfilter_set['content'] = self.validate_data(raw_entries)
            else:
                netfilter_set['content'] = {}
        return netfilter_set

    @classmethod
    def _parse_netfilter_string(cls, set_string):
        """
        Parse netfilter map definition and return map as dict.

        Do not validate content type against detected map type.
        Return a dict with 'name', 'address_family', 'table', 'type', 'flags'
        'raw_content' and 'type_from' keys ('type_from' is a list of key
        type names; 'raw_content' and 'flags' can be None).
        Raise ValueError in case of unexpected syntax.
        """
        match = cls.pattern.match(set_string)
        if match is None:
            raise ValueError("Malformed expression :\n" + set_string)
        values = match.groupdict()
        return {
            'address_family': values['address_family'],
            'table': values['table'],
            'name': values['name'],
            'type': values['type'],
            'type_from': values['type_from'].split(' . '),
            'raw_content': values['elements'],
            'flags': values['flags'],
        }

    def has_type(self, type_):
        """Check if a (key types, value type) pair match the map's one."""
        return tuple(self.TYPES[t] for t in self.type) == (type_[1],) and \
            tuple(self.TYPES[t] for t in self.type_from) == tuple(type_[0])

    def format_type(self):
        """Return the nft map type, e.g. "ipv4_addr : ipv4_addr"."""
        return (' . '.join(self.TYPES[i] for i in self.type_from)
                + ' : ' + ' . '.join(self.TYPES[i] for i in self.type))

    def filter(self, elements):
        """Parse a single map value with the (first) value parser."""
        return (self.filters[0](elements),)


def get_ip_iterable_from_str(ip):
    """Parse an IP glob, a CIDR network or a "begin-end" range string."""
    try:
        ret = netaddr.IPGlob(ip)
    except netaddr.core.AddrFormatError:
        try:
            ret = netaddr.IPNetwork(ip)
        except netaddr.core.AddrFormatError:
            begin, end = ip.split('-')
            ret = netaddr.IPRange(begin, end)
    return ret


class NAT:
    """Port-range based source NAT of a private range onto public IPs."""

    PROTOCOLS = (
        'tcp',
        'udp',
        'icmp'
    )

    def __init__(self,
                 name,
                 range_in,
                 range_out,
                 first_port,
                 last_port,
                 use_sudo=True
                 ):
        """Creates a NAT object for the given range of IP-Addresses.

        Args:
            name: name of the sets
            range_in: the private IP addresses, as a string accepted by
                get_ip_iterable_from_str (glob, CIDR or "begin-end")
            range_out: the public IP addresses, same format as range_in
            first_port: the first port used for the nat
            last_port: the last port used for the nat
            use_sudo: Should the nft commands be run in sudo ?
        """
        assert 0 <= first_port < last_port < 65536, (
            name + ": first_port must be lower than last_port, both being "
            "valid port numbers (0-65535)")
        self.name = name
        self.range_in = get_ip_iterable_from_str(range_in)
        self.range_out = get_ip_iterable_from_str(range_out)
        self.first_port = first_port
        self.last_port = last_port
        # NOTE(review): `+ 1` gives one extra group when the sizes divide
        # evenly; kept as is to preserve the existing port allocation.
        self.nb_private_by_public = \
            self.range_in.size // self.range_out.size + 1
        self.use_sudo = use_sudo
        sudo = ["/usr/bin/sudo"] if use_sudo else []
        self.nft = [*sudo, "/usr/sbin/nft"]

    def create_nat_rule(self, grp, ports):
        """Create a nat rule per protocol in the form :

        ip saddr @<name>_nat_port_<grp> ip protocol tcp snat ip saddr map @<name>_nat_address : <ports>
        ip saddr @<name>_nat_port_<grp> ip protocol udp snat ip saddr map @<name>_nat_address : <ports>

        (and likewise for every protocol in PROTOCOLS).
        NOTE(review): rules are added unconditionally, so calling this twice
        appears to duplicate them.

        Args:
            grp: The name of the group
            ports: The port range (str)
        """
        for protocol in self.PROTOCOLS:
            CommandExec.run([
                *self.nft,
                "add rule ip nat {name}_nat ip saddr @{name}_nat_port_{grp} ip protocol {protocol} snat ip saddr map @{name}_nat_address : {ports}".format(
                    protocol=protocol,
                    name=self.name,
                    grp=grp,
                    ports=ports
                )
            ])

    def manage(self):
        """Creates the port sets, ip map and rules.

        Returns:
            The NAT log (also printed): one "ip_out\\tports\\tip_in\\t\\n"
            line per private address.
        """
        ips = {}
        ports = [set() for _ in range(self.nb_private_by_public)]

        def port_range(i):
            """Return the "first-last" port range of the i-th group."""
            return '-'.join([
                str(int(self.first_port + i/self.nb_private_by_public * (self.last_port - self.first_port))),
                str(int(self.first_port + (i+1)/self.nb_private_by_public * (self.last_port - self.first_port)-1))
            ])

        log_lines = []
        # NOTE(review): range() excludes range_in.last and the last group
        # stops at last - 1, so the last address of range_in is never NATed.
        for ip_out, ip in zip(
                self.range_out,
                range(self.range_in.first, self.range_in.last,
                      self.nb_private_by_public)):
            if int(ip + self.nb_private_by_public) <= self.range_in.last:
                range_size = self.nb_private_by_public
            else:
                range_size = self.range_in.last - ip
            ips[(netaddr.IPRange(ip, ip + range_size - 1),)] = ip_out
            for i in range(range_size):
                ip_in = netaddr.IPAddress(ip + i)
                ports[i].add((ip_in,))
                log_lines.append('\t'.join(
                    (str(ip_out), port_range(i), str(ip_in), '\n')))
        nat_log = ''.join(log_lines)
        print(nat_log)

        ip_map = NetfilterMap(
            target_content=ips,
            type_=('IPv4',),
            name=self.name+'_nat_address',
            table_name='nat',
            flags=('interval',),
            type_from=('IPv4',),
            address_family='ip',
            use_sudo=self.use_sudo,
        )
        ip_map.manage()

        for i, grp in enumerate(ports):
            grp_set = NetfilterSet(
                name=self.name+'_nat_port_'+str(i),
                target_content=grp,
                type_=('IPv4',),
                table_name='nat',
                address_family='ip',
                use_sudo=self.use_sudo,
            )
            grp_set.manage()
            self.create_nat_rule(
                str(i),
                port_range(i)
            )
        return nat_log


class Firewall:
    """Manages the firewall using nftables."""

    @staticmethod
    def manage_sets(sets, address_family=None, table=None, use_sudo=None):
        """Create and populate each set described in `sets`.

        Each item of `sets` is a dict with 'name', 'type' and 'content' keys.
        Arguments left to None are read from config.ini.
        """
        config = ConfigParser()
        config.read('config.ini')
        # Indexing a ConfigParser looks up a *section*, so the original
        # `CONFIG['table']` could never yield a plain value. Settings are
        # read from the DEFAULT section instead.
        # NOTE(review): assumes config.ini keeps these keys in [DEFAULT].
        defaults = config[config.default_section]
        address_family = (address_family
                          or defaults.get('address_family') or 'inet')
        table = table or defaults.get('table') or 'filter'
        if use_sudo is None:
            # Fallback mirrors NetfilterSet's own use_sudo=True default.
            use_sudo = defaults.getboolean('use_sudo', fallback=True)
        for set_ in sets:
            NetfilterSet(
                name=set_['name'],
                type_=set_['type'],
                target_content=set_['content'],
                use_sudo=use_sudo,
                address_family=address_family,
                table_name=table).manage()