diff --git a/data/Dockerfiles/netfilter/server.py b/data/Dockerfiles/netfilter/server.py index 1ccc150e..c206585a 100644 --- a/data/Dockerfiles/netfilter/server.py +++ b/data/Dockerfiles/netfilter/server.py @@ -7,6 +7,7 @@ import time import atexit import signal import ipaddress +import nftables from collections import Counter from random import randint from threading import Thread @@ -43,6 +44,10 @@ quit_now = False exit_code = 0 lock = Lock() +backend = sys.argv[1] +nft = None +nft_chain_names = {} + def log(priority, message): tolog = {} tolog['time'] = int(round(time.time())) @@ -60,6 +65,17 @@ def logCrit(message): def logInfo(message): log('info', message) +#nftables +if backend == 'nftables': + logInfo('Using Nftables backend') + nft = nftables.Nftables() + nft.set_json_output(True) + nft.set_handle_output(True) + nft_chain_names = {'ip': {'filter': {'input': '', 'forward': ''}, 'nat': {'postrouting': ''} }, + 'ip6': {'filter': {'input': '', 'forward': ''}, 'nat': {'postrouting': ''} } } +else: + logInfo('Using Iptables backend') + def refreshF2boptions(): global f2boptions global quit_now @@ -115,33 +131,472 @@ def refreshF2bregex(): if r.exists('F2B_LOG'): r.rename('F2B_LOG', 'NETFILTER_LOG') +# Nftables functions +def nft_exec_dict(query: dict): + global nft + + if not query: return False + + rc, output, error = nft.json_cmd(query) + if rc != 0: + #logCrit(f"Nftables Error: {error}") + return False + + # Prevent returning False or empty string on commands that do not produce output + if rc == 0 and len(output) == 0: + return True + + return output + +def get_base_dict(): + return {'nftables': [{ 'metainfo': { 'json_schema_version': 1} } ] } + +def search_current_chains(): + global nft_chain_names + nft_chain_priority = {'ip': {'filter': {'input': 1, 'forward': 1}, 'nat': {'postrouting': 111} }, + 'ip6': {'filter': {'input': 1, 'forward': 1}, 'nat': {'postrouting': 111} } } + + # Command: 'nft list chains' + _list_opts = dict(chains='null') + _list = dict(list=_list_opts) + command = get_base_dict() + command['nftables'].append(_list) + kernel_ruleset = nft_exec_dict(command) + if kernel_ruleset: + for object in kernel_ruleset['nftables']: + chain = object.get("chain") + if not chain: + continue + + _family = chain['family'] + _table = chain['table'] + + hook = chain.get("hook") + if not hook or hook not in nft_chain_names[_family][_table]: + continue + + _hook = chain['hook'] + + priority = chain.get("prio") + if priority is None: + continue + + if priority < nft_chain_priority[_family][_table][_hook]: + # at this point, we know the chain has: + # hook and priority set + # and it has the lowest priority + nft_chain_priority[_family][_table][_hook] = priority + nft_chain_names[_family][_table][_hook] = chain['name'] + +def search_for_chain(kernel_ruleset: dict, chain_name: str): + found = False + for object in kernel_ruleset["nftables"]: + chain = object.get("chain") + if not chain: + continue + ch_name = chain.get("name") + if ch_name == chain_name: + found = True + break + return found + +def get_chain_dict(_family: str, _name: str): + # nft (add | create) chain [] + _chain_opts = dict(family = _family, + table = 'filter', + name = _name ) + + _chain = dict(chain = _chain_opts) + _add = dict(add = _chain) + final_chain = get_base_dict() + final_chain["nftables"].append(_add) + return final_chain + +def get_mailcow_jump_rule_dict(_family: str, _chain: str): + _jump_rule = get_base_dict() + _expr_opt=[] + _expr_counter = dict(family = _family, table = 'filter', packets = 0, bytes = 0) + _counter_dict = dict(counter = _expr_counter) + _expr_opt.append(_counter_dict) + + _expr_jump = dict(target = 'MAILCOW') + _jump_opts = dict(jump = _expr_jump) + + _expr_opt.append(_jump_opts) + + _rule_params = dict(family = _family, + table = 'filter', + chain = _chain, + expr = _expr_opt, + comment = "mailcow" + ) + _opts_rule = dict(rule = _rule_params) + _add_rule = dict(insert = _opts_rule) + + _jump_rule["nftables"].append(_add_rule) + + return _jump_rule + +def insert_mailcow_chains(_family: str): + nft_input_chain = nft_chain_names[_family]['filter']['input'] + nft_forward_chain = nft_chain_names[_family]['filter']['forward'] + # Command: 'nft list table filter' + _table_opts = dict(family=_family, name='filter') + _table = dict(table=_table_opts) + _list = dict(list=_table) + command = get_base_dict() + command['nftables'].append(_list) + kernel_ruleset = nft_exec_dict(command) + if kernel_ruleset: + # MAILCOW chain + if not search_for_chain(kernel_ruleset, "MAILCOW"): + cadena = get_chain_dict(_family, "MAILCOW") + if(nft_exec_dict(cadena)): + logInfo(f"MAILCOW {_family} chain created successfully.") + + input_jump_found, forward_jump_found = False, False + + for object in kernel_ruleset["nftables"]: + if not object.get("rule"): + continue + + rule = object["rule"] + if rule["chain"] == nft_input_chain: + if rule.get("comment") and rule["comment"] == "mailcow": + input_jump_found = True + if rule["chain"] == nft_forward_chain: + if rule.get("comment") and rule["comment"] == "mailcow": + forward_jump_found = True + + if not input_jump_found and nft_input_chain: + command = get_mailcow_jump_rule_dict(_family, nft_input_chain) + nft_exec_dict(command) + + if not forward_jump_found and nft_forward_chain: + command = get_mailcow_jump_rule_dict(_family, nft_forward_chain) + nft_exec_dict(command) + +def delete_nat_rule(_family:str, _chain: str, _handle:str): + delete_command = get_base_dict() + _rule_opts = dict(family = _family, + table = 'nat', + chain = _chain, + handle = _handle + ) + _rule = dict(rule = _rule_opts) + _delete = dict(delete = _rule) + delete_command["nftables"].append(_delete) + + return nft_exec_dict(delete_command) + +def snat_rule(_family: str, snat_target: str): + chain_name = nft_chain_names[_family]['nat']['postrouting'] + + # no postrouting chain, may occur if docker has ipv6 disabled. + if not chain_name: return + + # Command: nft list chain nat + _chain_opts = dict(family=_family, table='nat', name=chain_name) + _chain = dict(chain=_chain_opts) + _list = dict(list=_chain) + command = get_base_dict() + command['nftables'].append(_list) + kernel_ruleset = nft_exec_dict(command) + if not kernel_ruleset: + return + + rule_position = 0 + rule_handle = None + rule_found = False + for object in kernel_ruleset["nftables"]: + if not object.get("rule"): + continue + + rule = object["rule"] + if not rule.get("comment") or not rule["comment"] == "mailcow": + rule_position +=1 + continue + else: + rule_found = True + rule_handle = rule["handle"] + break + + if _family == "ip": + source_address = os.getenv('IPV4_NETWORK', '172.22.1') + '.0/24' + else: + source_address = os.getenv('IPV6_NETWORK', 'fd4d:6169:6c63:6f77::/64') + + tmp_addr = re.split(r'/', source_address) + dest_ip = tmp_addr[0] + dest_len = int(tmp_addr[1]) + + if rule_found: + saddr_ip = rule["expr"][0]["match"]["right"]["prefix"]["addr"] + saddr_len = int(rule["expr"][0]["match"]["right"]["prefix"]["len"]) + + daddr_ip = rule["expr"][1]["match"]["right"]["prefix"]["addr"] + daddr_len = int(rule["expr"][1]["match"]["right"]["prefix"]["len"]) + match = all(( + saddr_ip == dest_ip, + saddr_len == dest_len, + daddr_ip == dest_ip, + daddr_len == dest_len + )) + try: + if rule_position == 0: + if not match: + # Position 0 , it is a mailcow rule , but it does not have the same parameters + if delete_nat_rule(_family, chain_name, rule_handle): + logInfo(f'Remove rule for source network {saddr_ip}/{saddr_len} to SNAT target {snat_target} from POSTROUTING chain with handle {rule_handle}') + + else: + # Position > 0 and is mailcow rule + if delete_nat_rule(_family, chain_name, rule_handle): + logInfo(f'Remove rule for source network {saddr_ip}/{saddr_len} to SNAT target {snat_target} from POSTROUTING chain with handle {rule_handle}') + except: + logCrit(f"Error running SNAT on {_family}, retrying..." ) + else: + # rule not found + json_command = get_base_dict() + try: + payload_fields = dict(protocol = _family, field = "saddr") + payload_dict = dict(payload = payload_fields) + payload_fields2 = dict(protocol = _family, field = "daddr") + payload_dict2 = dict(payload = payload_fields2) + prefix_fields=dict(addr = dest_ip, len = int(dest_len)) + prefix_dict=dict(prefix = prefix_fields) + + snat_addr = dict(addr = snat_target) + snat_dict = dict(snat = snat_addr) + + expr_counter = dict(family = _family, table = "nat", packets = 0, bytes = 0) + counter_dict = dict(counter = expr_counter) + + match_fields1 = dict(op = "==", left = payload_dict, right = prefix_dict) + match_dict1 = dict(match = match_fields1) + + match_fields2 = dict(op = "!=", left = payload_dict2, right = prefix_dict ) + match_dict2 = dict(match = match_fields2) + expr_list = [ + match_dict1, + match_dict2, + counter_dict, + snat_dict + ] + rule_fields = dict(family = _family, + table = "nat", + chain = chain_name, + comment = "mailcow", + expr = expr_list + ) + rule_dict = dict(rule = rule_fields) + insert_dict = dict(insert = rule_dict) + json_command["nftables"].append(insert_dict) + if(nft_exec_dict(json_command)): + logInfo(f"Added {_family} POSTROUTING rule for source network {dest_ip} to {snat_target}") + except: + logCrit(f"Error running SNAT on {_family}, retrying...") + +def get_chain_handle(_family: str, _table: str, chain_name: str): + chain_handle = None + # Command: 'nft list chains {family}' + _chain_opts = dict(family=_family) + _chain = dict(chains=_chain_opts) + _list = dict(list=_chain) + command = get_base_dict() + command['nftables'].append(_list) + kernel_ruleset = nft_exec_dict(command) + if kernel_ruleset: + for object in kernel_ruleset["nftables"]: + if not object.get("chain"): + continue + chain = object["chain"] + if chain["family"] == _family and chain["table"] == _table and chain["name"] == chain_name: + chain_handle = chain["handle"] + break + return chain_handle + +def get_rules_handle(_family: str, _table: str, chain_name: str): + rule_handle = [] + # Command: 'nft list chain {family} {table} {chain_name}' + _chain_opts = dict(family=_family, table=_table, name=chain_name) + _chain = dict(chain=_chain_opts) + _list = dict(list=_chain) + command = get_base_dict() + command['nftables'].append(_list) + + kernel_ruleset = nft_exec_dict(command) + if kernel_ruleset: + for object in kernel_ruleset["nftables"]: + if not object.get("rule"): + continue + + rule = object["rule"] + if rule["family"] == _family and rule["table"] == _table and rule["chain"] == chain_name: + if rule.get("comment") and rule["comment"] == "mailcow": + rule_handle.append(rule["handle"]) + return rule_handle + +def get_ban_ip_dict(ipaddr: str, _family: str): + json_command = get_base_dict() + + expr_opt = [] + if re.search(r'/', ipaddr): + divided = re.split(r'/', ipaddr) + prefix_dict=dict(addr = divided[0], + len = int(divided[1]) ) + right_dict = dict(prefix = prefix_dict) + else: + right_dict = ipaddr + + payload_dict = dict(protocol = _family, field="saddr" ) + left_dict = dict(payload = payload_dict) + match_dict = dict(op = "==", left = left_dict, right = right_dict ) + match_base = dict(match = match_dict) + expr_opt.append(match_base) + + expr_counter = dict(family = _family, table = "filter", packets = 0, bytes = 0) + counter_dict = dict(counter = expr_counter) + expr_opt.append(counter_dict) + + drop_dict = dict(drop = "null") + expr_opt.append(drop_dict) + + rule_dict = dict(family = _family, table = "filter", chain = "MAILCOW", expr = expr_opt) + + base_rule = dict(rule = rule_dict) + base_dict = dict(insert = base_rule) + json_command["nftables"].append(base_dict) + + return json_command + +def get_unban_ip_dict(ipaddr:str, _family: str): + json_command = get_base_dict() + # Command: 'nft list chain {s_family} filter MAILCOW' + _chain_opts = dict(family=_family, table='filter', name='MAILCOW') + _chain = dict(chain=_chain_opts) + _list = dict(list=_chain) + command = get_base_dict() + command['nftables'].append(_list) + kernel_ruleset = nft_exec_dict(command) + rule_handle = None + if kernel_ruleset: + for object in kernel_ruleset["nftables"]: + if not object.get("rule"): + continue + + rule = object["rule"]["expr"][0]["match"] + left_opt = rule["left"]["payload"] + if not left_opt["protocol"] == _family: + continue + if not left_opt["field"] =="saddr": + continue + + # ip currently banned + rule_right = rule["right"] + if isinstance(rule_right, dict): + current_rule_ip = rule_right["prefix"]["addr"] + current_rule_len = int(rule_right["prefix"]["len"]) + else: + current_rule_ip = rule_right + current_rule_len = 32 if _family == 'ip' else 128 + + # ip to ban + if re.search(r'/', ipaddr): + divided = re.split(r'/', ipaddr) + candidate_ip = divided[0] + candidate_len = int(divided[1]) + else: + candidate_ip = ipaddr + candidate_len = 32 if _family == 'ip' else 128 + + if all((current_rule_ip == candidate_ip, + current_rule_len and candidate_len, + current_rule_len == candidate_len )): + rule_handle = object["rule"]["handle"] + break + + if rule_handle is not None: + mailcow_rule = dict(family = _family, table = "filter", chain = "MAILCOW", handle = rule_handle) + del_rule = dict(rule = mailcow_rule) + delete_rule=dict(delete = del_rule) + json_command["nftables"].append(delete_rule) + else: + return False + + return json_command + +def check_mailcow_chains(family: str, chain: str): + position = 0 + rule_found = False + chain_name = nft_chain_names[family]['filter'][chain] + + if not chain_name: return None + + _chain_opts = dict(family=family, table='filter', name=chain_name) + _chain = dict(chain=_chain_opts) + _list = dict(list=_chain) + command = get_base_dict() + command['nftables'].append(_list) + kernel_ruleset = nft_exec_dict(command) + if kernel_ruleset: + for object in kernel_ruleset["nftables"]: + if not object.get("rule"): + continue + rule = object["rule"] + if rule.get("comment") and rule["comment"] == "mailcow": + rule_found = True + break + + position+=1 + + return position if rule_found else False + +# Mailcow def mailcowChainOrder(): global lock global quit_now global exit_code + while not quit_now: time.sleep(10) with lock: - filter4_table = iptc.Table(iptc.Table.FILTER) - filter6_table = iptc.Table6(iptc.Table6.FILTER) - filter4_table.refresh() - filter6_table.refresh() - for f in [filter4_table, filter6_table]: - forward_chain = iptc.Chain(f, 'FORWARD') - input_chain = iptc.Chain(f, 'INPUT') - for chain in [forward_chain, input_chain]: - target_found = False - for position, item in enumerate(chain.rules): - if item.target.name == 'MAILCOW': - target_found = True - if position > 2: - logCrit('Error in %s chain order: MAILCOW on position %d, restarting container' % (chain.name, position)) - quit_now = True - exit_code = 2 - if not target_found: - logCrit('Error in %s chain: MAILCOW target not found, restarting container' % (chain.name)) - quit_now = True - exit_code = 2 + if backend == 'iptables': + filter4_table = iptc.Table(iptc.Table.FILTER) + filter6_table = iptc.Table6(iptc.Table6.FILTER) + filter4_table.refresh() + filter6_table.refresh() + for f in [filter4_table, filter6_table]: + forward_chain = iptc.Chain(f, 'FORWARD') + input_chain = iptc.Chain(f, 'INPUT') + for chain in [forward_chain, input_chain]: + target_found = False + for position, item in enumerate(chain.rules): + if item.target.name == 'MAILCOW': + target_found = True + if position > 2: + logCrit('Error in %s chain order: MAILCOW on position %d, restarting container' % (chain.name, position)) + quit_now = True + exit_code = 2 + if not target_found: + logCrit('Error in %s chain: MAILCOW target not found, restarting container' % (chain.name)) + quit_now = True + exit_code = 2 + else: + for family in ["ip", "ip6"]: + for chain in ['input', 'forward']: + chain_position = check_mailcow_chains(family, chain) + if chain_position is None: continue + + if chain_position is False: + logCrit('Error in %s %s chain: MAILCOW target not found, restarting container' % (family, chain)) + quit_now = True + exit_code = 2 + + if chain_position > 0: + logCrit('Error in %s %s chain order: MAILCOW on position %d, restarting container' % (family, chain, chain_position)) + quit_now = True + exit_code = 2 def ban(address): global lock @@ -190,22 +645,31 @@ def ban(address): logCrit('Banning %s for %d minutes' % (net, BAN_TIME / 60)) if type(ip) is ipaddress.IPv4Address: with lock: - chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), 'MAILCOW') - rule = iptc.Rule() - rule.src = net - target = iptc.Target(rule, "REJECT") - rule.target = target - if rule not in chain.rules: - chain.insert_rule(rule) + if backend == 'iptables': + chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), 'MAILCOW') + rule = iptc.Rule() + rule.src = net + target = iptc.Target(rule, "REJECT") + rule.target = target + if rule not in chain.rules: + chain.insert_rule(rule) + else: + ban_dict = get_ban_ip_dict(net, "ip") + nft_exec_dict(ban_dict) else: with lock: - chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), 'MAILCOW') - rule = iptc.Rule6() - rule.src = net - target = iptc.Target(rule, "REJECT") - rule.target = target - if rule not in chain.rules: - chain.insert_rule(rule) + if backend == 'iptables': + chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), 'MAILCOW') + rule = iptc.Rule6() + rule.src = net + target = iptc.Target(rule, "REJECT") + rule.target = target + if rule not in chain.rules: + chain.insert_rule(rule) + else: + ban_dict = get_ban_ip_dict(net, "ip6") + nft_exec_dict(ban_dict) + r.hset('F2B_ACTIVE_BANS', '%s' % net, cur_time + BAN_TIME) else: logWarn('%d more attempts in the next %d seconds until %s is banned' % (MAX_ATTEMPTS - bans[net]['attempts'], RETRY_WINDOW, net)) @@ -219,22 +683,35 @@ def unban(net): logInfo('Unbanning %s' % net) if type(ipaddress.ip_network(net)) is ipaddress.IPv4Network: with lock: - chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), 'MAILCOW') - rule = iptc.Rule() - rule.src = net - target = iptc.Target(rule, "REJECT") - rule.target = target - if rule in chain.rules: - chain.delete_rule(rule) + if backend == 'iptables': + chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), 'MAILCOW') + rule = iptc.Rule() + rule.src = net + target = iptc.Target(rule, "REJECT") + rule.target = target + if rule in chain.rules: + chain.delete_rule(rule) + else: + dict_unban = get_unban_ip_dict(net, "ip") + if dict_unban: + if(nft_exec_dict(dict_unban)): + logInfo(f"Unbanned ip: {net}") else: with lock: - chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), 'MAILCOW') - rule = iptc.Rule6() - rule.src = net - target = iptc.Target(rule, "REJECT") - rule.target = target - if rule in chain.rules: - chain.delete_rule(rule) + if backend == 'iptables': + chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), 'MAILCOW') + rule = iptc.Rule6() + rule.src = net + target = iptc.Target(rule, "REJECT") + rule.target = target + if rule in chain.rules: + chain.delete_rule(rule) + else: + dict_unban = get_unban_ip_dict(net, "ip6") + if dict_unban: + if(nft_exec_dict(dict_unban)): + logInfo(f"Unbanned ip6: {net}") + r.hdel('F2B_ACTIVE_BANS', '%s' % net) r.hdel('F2B_QUEUE_UNBAN', '%s' % net) if net in bans: @@ -244,34 +721,60 @@ def permBan(net, unban=False): global lock if type(ipaddress.ip_network(net, strict=False)) is ipaddress.IPv4Network: with lock: - chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), 'MAILCOW') - rule = iptc.Rule() - rule.src = net - target = iptc.Target(rule, "REJECT") - rule.target = target - if rule not in chain.rules and not unban: - logCrit('Add host/network %s to blacklist' % net) - chain.insert_rule(rule) - r.hset('F2B_PERM_BANS', '%s' % net, int(round(time.time()))) - elif rule in chain.rules and unban: - logCrit('Remove host/network %s from blacklist' % net) - chain.delete_rule(rule) - r.hdel('F2B_PERM_BANS', '%s' % net) + if backend == 'iptables': + chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), 'MAILCOW') + rule = iptc.Rule() + rule.src = net + target = iptc.Target(rule, "REJECT") + rule.target = target + if rule not in chain.rules and not unban: + logCrit('Add host/network %s to blacklist' % net) + chain.insert_rule(rule) + r.hset('F2B_PERM_BANS', '%s' % net, int(round(time.time()))) + elif rule in chain.rules and unban: + logCrit('Remove host/network %s from blacklist' % net) + chain.delete_rule(rule) + r.hdel('F2B_PERM_BANS', '%s' % net) + else: + if not unban: + ban_dict = get_ban_ip_dict(net, "ip") + if(nft_exec_dict(ban_dict)): + logCrit('Add host/network %s to blacklist' % net) + r.hset('F2B_PERM_BANS', '%s' % net, int(round(time.time()))) + elif unban: + dict_unban = get_unban_ip_dict(net, "ip") + if dict_unban: + if(nft_exec_dict(dict_unban)): + logCrit('Remove host/network %s from blacklist' % net) + r.hdel('F2B_PERM_BANS', '%s' % net) else: with lock: - chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), 'MAILCOW') - rule = iptc.Rule6() - rule.src = net - target = iptc.Target(rule, "REJECT") - rule.target = target - if rule not in chain.rules and not unban: - logCrit('Add host/network %s to blacklist' % net) - chain.insert_rule(rule) - r.hset('F2B_PERM_BANS', '%s' % net, int(round(time.time()))) - elif rule in chain.rules and unban: - logCrit('Remove host/network %s from blacklist' % net) - chain.delete_rule(rule) - r.hdel('F2B_PERM_BANS', '%s' % net) + if backend == 'iptables': + chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), 'MAILCOW') + rule = iptc.Rule6() + rule.src = net + target = iptc.Target(rule, "REJECT") + rule.target = target + if rule not in chain.rules and not unban: + logCrit('Add host/network %s to blacklist' % net) + chain.insert_rule(rule) + r.hset('F2B_PERM_BANS', '%s' % net, int(round(time.time()))) + elif rule in chain.rules and unban: + logCrit('Remove host/network %s from blacklist' % net) + chain.delete_rule(rule) + r.hdel('F2B_PERM_BANS', '%s' % net) + else: + if not unban: + ban_dict = get_ban_ip_dict(net, "ip6") + if(nft_exec_dict(ban_dict)): + logCrit('Add host/network %s to blacklist' % net) + r.hset('F2B_PERM_BANS', '%s' % net, int(round(time.time()))) + elif unban: + dict_unban = get_unban_ip_dict(net, "ip6") + if dict_unban: + if(nft_exec_dict(dict_unban)): + logCrit('Remove host/network %s from blacklist' % net) + r.hdel('F2B_PERM_BANS', '%s' % net) def quit(signum, frame): global quit_now @@ -283,26 +786,78 @@ def clear(): for net in bans.copy(): unban(net) with lock: - filter4_table = iptc.Table(iptc.Table.FILTER) - filter6_table = iptc.Table6(iptc.Table6.FILTER) - for filter_table in [filter4_table, filter6_table]: - filter_table.autocommit = False - forward_chain = iptc.Chain(filter_table, "FORWARD") - input_chain = iptc.Chain(filter_table, "INPUT") - mailcow_chain = iptc.Chain(filter_table, "MAILCOW") - if mailcow_chain in filter_table.chains: - for rule in mailcow_chain.rules: - mailcow_chain.delete_rule(rule) - for rule in forward_chain.rules: - if rule.target.name == 'MAILCOW': - forward_chain.delete_rule(rule) - for rule in input_chain.rules: - if rule.target.name == 'MAILCOW': - input_chain.delete_rule(rule) - filter_table.delete_chain("MAILCOW") - filter_table.commit() - filter_table.refresh() - filter_table.autocommit = True + if backend == 'iptables': + filter4_table = iptc.Table(iptc.Table.FILTER) + filter6_table = iptc.Table6(iptc.Table6.FILTER) + for filter_table in [filter4_table, filter6_table]: + filter_table.autocommit = False + forward_chain = iptc.Chain(filter_table, "FORWARD") + input_chain = iptc.Chain(filter_table, "INPUT") + mailcow_chain = iptc.Chain(filter_table, "MAILCOW") + if mailcow_chain in filter_table.chains: + for rule in mailcow_chain.rules: + mailcow_chain.delete_rule(rule) + for rule in forward_chain.rules: + if rule.target.name == 'MAILCOW': + forward_chain.delete_rule(rule) + for rule in input_chain.rules: + if rule.target.name == 'MAILCOW': + input_chain.delete_rule(rule) + filter_table.delete_chain("MAILCOW") + filter_table.commit() + filter_table.refresh() + filter_table.autocommit = True + else: + for _family in ["ip", "ip6"]: + is_empty_dict = True + json_command = get_base_dict() + chain_handle = get_chain_handle(_family, "filter", "MAILCOW") + # if no handle, the chain doesn't exists + if chain_handle is not None: + is_empty_dict = False + # flush chain MAILCOW + mailcow_chain = dict(family=_family, table="filter", name="MAILCOW") + mc_chain_base = dict(chain=mailcow_chain) + flush_chain = dict(flush=mc_chain_base) + json_command["nftables"].append(flush_chain) + + # remove rule in forward chain + # remove rule in input chain + chains_family = [nft_chain_names[_family]['filter']['input'], + nft_chain_names[_family]['filter']['forward'] ] + + for chain_base in chains_family: + if not chain_base: continue + + rules_handle = get_rules_handle(_family, "filter", chain_base) + if rules_handle is not None: + for r_handle in rules_handle: + is_empty_dict = False + mailcow_rule = dict(family=_family, + table="filter", + chain=chain_base, + handle=r_handle + ) + del_rule = dict(rule=mailcow_rule) + delete_rules=dict(delete=del_rule) + json_command["nftables"].append(delete_rules) + + # remove chain MAILCOW + # after delete all rules referencing this chain + if chain_handle is not None: + mc_chain_handle = dict(family=_family, + table="filter", + name="MAILCOW", + handle=chain_handle + ) + del_chain=dict(chain=mc_chain_handle) + delete_chain = dict(delete=del_chain) + json_command["nftables"].append(delete_chain) + + if is_empty_dict == False: + if(nft_exec_dict(json_command)): + logInfo(f"Clear completed: {_family}") + r.delete('F2B_ACTIVE_BANS') r.delete('F2B_PERM_BANS') pubsub.unsubscribe() @@ -354,28 +909,31 @@ def snat4(snat_target): time.sleep(10) with lock: try: - table = iptc.Table('nat') - table.refresh() - chain = iptc.Chain(table, 'POSTROUTING') - table.autocommit = False - new_rule = get_snat4_rule() - for position, rule in enumerate(chain.rules): - match = all(( - new_rule.get_src() == rule.get_src(), - new_rule.get_dst() == rule.get_dst(), - new_rule.target.parameters == rule.target.parameters, - new_rule.target.name == rule.target.name - )) - if position == 0: - if not match: - logInfo(f'Added POSTROUTING rule for source network {new_rule.src} to SNAT target {snat_target}') - chain.insert_rule(new_rule) - else: - if match: - logInfo(f'Remove rule for source network {new_rule.src} to SNAT target {snat_target} from POSTROUTING chain at position {position}') - chain.delete_rule(rule) - table.commit() - table.autocommit = True + if backend == 'iptables': + table = iptc.Table('nat') + table.refresh() + chain = iptc.Chain(table, 'POSTROUTING') + table.autocommit = False + new_rule = get_snat4_rule() + for position, rule in enumerate(chain.rules): + match = all(( + new_rule.get_src() == rule.get_src(), + new_rule.get_dst() == rule.get_dst(), + new_rule.target.parameters == rule.target.parameters, + new_rule.target.name == rule.target.name + )) + if position == 0: + if not match: + logInfo(f'Added POSTROUTING rule for source network {new_rule.src} to SNAT target {snat_target}') + chain.insert_rule(new_rule) + else: + if match: + logInfo(f'Remove rule for source network {new_rule.src} to SNAT target {snat_target} from POSTROUTING chain at position {position}') + chain.delete_rule(rule) + table.commit() + table.autocommit = True + else: + snat_rule("ip", snat_target) except: print('Error running SNAT4, retrying...') @@ -395,21 +953,31 @@ def snat6(snat_target): time.sleep(10) with lock: try: - table = iptc.Table6('nat') - table.refresh() - chain = iptc.Chain(table, 'POSTROUTING') - table.autocommit = False - if get_snat6_rule() not in chain.rules: - logInfo('Added POSTROUTING rule for source network %s to SNAT target %s' % (get_snat6_rule().src, snat_target)) - chain.insert_rule(get_snat6_rule()) + if backend == 'iptables': + table = iptc.Table6('nat') + table.refresh() + chain = iptc.Chain(table, 'POSTROUTING') + table.autocommit = False + new_rule = get_snat6_rule() + for position, rule in enumerate(chain.rules): + match = all(( + new_rule.get_src() == rule.get_src(), + new_rule.get_dst() == rule.get_dst(), + new_rule.target.parameters == rule.target.parameters, + new_rule.target.name == rule.target.name + )) + if position == 0: + if not match: + logInfo(f'Added POSTROUTING rule for source network {new_rule.src} to SNAT target {snat_target}') + chain.insert_rule(new_rule) + else: + if match: + logInfo(f'Remove rule for source network {new_rule.src} to SNAT target {snat_target} from POSTROUTING chain at position {position}') + chain.delete_rule(rule) table.commit() + table.autocommit = True else: - for position, item in enumerate(chain.rules): - if item == get_snat6_rule(): - if position != 0: - chain.delete_rule(get_snat6_rule()) - table.commit() - table.autocommit = True + snat_rule("ip6", snat_target) except: print('Error running SNAT6, retrying...') @@ -435,7 +1003,6 @@ def isIpNetwork(address): return False return True - def genNetworkList(list): resolver = dns.resolver.Resolver() hostnames = [] @@ -504,33 +1071,40 @@ def blacklistUpdate(): def initChain(): # Is called before threads start, no locking print("Initializing mailcow netfilter chain") - # IPv4 - if not iptc.Chain(iptc.Table(iptc.Table.FILTER), "MAILCOW") in iptc.Table(iptc.Table.FILTER).chains: - iptc.Table(iptc.Table.FILTER).create_chain("MAILCOW") - for c in ['FORWARD', 'INPUT']: - chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), c) - rule = iptc.Rule() - rule.src = '0.0.0.0/0' - rule.dst = '0.0.0.0/0' - target = iptc.Target(rule, "MAILCOW") - rule.target = target - if rule not in chain.rules: - chain.insert_rule(rule) - # IPv6 - if not iptc.Chain(iptc.Table6(iptc.Table6.FILTER), "MAILCOW") in iptc.Table6(iptc.Table6.FILTER).chains: - iptc.Table6(iptc.Table6.FILTER).create_chain("MAILCOW") - for c in ['FORWARD', 'INPUT']: - chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), c) - rule = iptc.Rule6() - rule.src = '::/0' - rule.dst = '::/0' - target = iptc.Target(rule, "MAILCOW") - rule.target = target - if rule not in chain.rules: - chain.insert_rule(rule) + if backend == 'iptables': + # IPv4 + if not iptc.Chain(iptc.Table(iptc.Table.FILTER), "MAILCOW") in iptc.Table(iptc.Table.FILTER).chains: + iptc.Table(iptc.Table.FILTER).create_chain("MAILCOW") + for c in ['FORWARD', 'INPUT']: + chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), c) + rule = iptc.Rule() + rule.src = '0.0.0.0/0' + rule.dst = '0.0.0.0/0' + target = iptc.Target(rule, "MAILCOW") + rule.target = target + if rule not in chain.rules: + chain.insert_rule(rule) + # IPv6 + if not iptc.Chain(iptc.Table6(iptc.Table6.FILTER), "MAILCOW") in iptc.Table6(iptc.Table6.FILTER).chains: + iptc.Table6(iptc.Table6.FILTER).create_chain("MAILCOW") + for c in ['FORWARD', 'INPUT']: + chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), c) + rule = iptc.Rule6() + rule.src = '::/0' + rule.dst = '::/0' + target = iptc.Target(rule, "MAILCOW") + rule.target = target + if rule not in chain.rules: + chain.insert_rule(rule) + else: + for family in ["ip", "ip6"]: + insert_mailcow_chains(family) + if __name__ == '__main__': + if backend == 'nftables': + search_current_chains() # In case a previous session was killed without cleanup clear() # Reinit MAILCOW chain