added nftables support to server.py

This commit is contained in:
amorfo77 2023-02-10 18:19:18 +01:00
parent 83a5389242
commit 1d5b5dbd86
1 changed files with 728 additions and 154 deletions

View File

@ -7,6 +7,7 @@ import time
import atexit import atexit
import signal import signal
import ipaddress import ipaddress
import nftables
from collections import Counter from collections import Counter
from random import randint from random import randint
from threading import Thread from threading import Thread
@ -43,6 +44,10 @@ quit_now = False
exit_code = 0 exit_code = 0
lock = Lock() lock = Lock()
backend = sys.argv[1]
nft = None
nft_chain_names = {}
def log(priority, message): def log(priority, message):
tolog = {} tolog = {}
tolog['time'] = int(round(time.time())) tolog['time'] = int(round(time.time()))
@ -60,6 +65,17 @@ def logCrit(message):
def logInfo(message): def logInfo(message):
log('info', message) log('info', message)
#nftables
if backend == 'nftables':
logInfo('Using Nftables backend')
nft = nftables.Nftables()
nft.set_json_output(True)
nft.set_handle_output(True)
nft_chain_names = {'ip': {'filter': {'input': '', 'forward': ''}, 'nat': {'postrouting': ''} },
'ip6': {'filter': {'input': '', 'forward': ''}, 'nat': {'postrouting': ''} } }
else:
logInfo('Using Iptables backend')
def refreshF2boptions(): def refreshF2boptions():
global f2boptions global f2boptions
global quit_now global quit_now
@ -115,13 +131,437 @@ def refreshF2bregex():
if r.exists('F2B_LOG'): if r.exists('F2B_LOG'):
r.rename('F2B_LOG', 'NETFILTER_LOG') r.rename('F2B_LOG', 'NETFILTER_LOG')
# Nftables functions
def nft_exec_dict(query: dict):
global nft
if not query: return False
rc, output, error = nft.json_cmd(query)
if rc != 0:
#logCrit(f"Nftables Error: {error}")
return False
# Prevent returning False or empty string on commands that do not produce output
if rc == 0 and len(output) == 0:
return True
return output
def get_base_dict():
return {'nftables': [{ 'metainfo': { 'json_schema_version': 1} } ] }
def search_current_chains():
global nft_chain_names
nft_chain_priority = {'ip': {'filter': {'input': 1, 'forward': 1}, 'nat': {'postrouting': 111} },
'ip6': {'filter': {'input': 1, 'forward': 1}, 'nat': {'postrouting': 111} } }
# Command: 'nft list chains'
_list_opts = dict(chains='null')
_list = dict(list=_list_opts)
command = get_base_dict()
command['nftables'].append(_list)
kernel_ruleset = nft_exec_dict(command)
if kernel_ruleset:
for object in kernel_ruleset['nftables']:
chain = object.get("chain")
if not chain:
continue
_family = chain['family']
_table = chain['table']
hook = chain.get("hook")
if not hook or hook not in nft_chain_names[_family][_table]:
continue
_hook = chain['hook']
priority = chain.get("prio")
if priority is None:
continue
if priority < nft_chain_priority[_family][_table][_hook]:
# at this point, we know the chain has:
# hook and priority set
# and it has the lowest priority
nft_chain_priority[_family][_table][_hook] = priority
nft_chain_names[_family][_table][_hook] = chain['name']
def search_for_chain(kernel_ruleset: dict, chain_name: str):
found = False
for object in kernel_ruleset["nftables"]:
chain = object.get("chain")
if not chain:
continue
ch_name = chain.get("name")
if ch_name == chain_name:
found = True
break
return found
def get_chain_dict(_family: str, _name: str):
# nft (add | create) chain [<family>] <table> <name>
_chain_opts = dict(family = _family,
table = 'filter',
name = _name )
_chain = dict(chain = _chain_opts)
_add = dict(add = _chain)
final_chain = get_base_dict()
final_chain["nftables"].append(_add)
return final_chain
def get_mailcow_jump_rule_dict(_family: str, _chain: str):
_jump_rule = get_base_dict()
_expr_opt=[]
_expr_counter = dict(family = _family, table = 'filter', packets = 0, bytes = 0)
_counter_dict = dict(counter = _expr_counter)
_expr_opt.append(_counter_dict)
_expr_jump = dict(target = 'MAILCOW')
_jump_opts = dict(jump = _expr_jump)
_expr_opt.append(_jump_opts)
_rule_params = dict(family = _family,
table = 'filter',
chain = _chain,
expr = _expr_opt,
comment = "mailcow"
)
_opts_rule = dict(rule = _rule_params)
_add_rule = dict(insert = _opts_rule)
_jump_rule["nftables"].append(_add_rule)
return _jump_rule
def insert_mailcow_chains(_family: str):
nft_input_chain = nft_chain_names[_family]['filter']['input']
nft_forward_chain = nft_chain_names[_family]['filter']['forward']
# Command: 'nft list table <family> filter'
_table_opts = dict(family=_family, name='filter')
_table = dict(table=_table_opts)
_list = dict(list=_table)
command = get_base_dict()
command['nftables'].append(_list)
kernel_ruleset = nft_exec_dict(command)
if kernel_ruleset:
# MAILCOW chain
if not search_for_chain(kernel_ruleset, "MAILCOW"):
cadena = get_chain_dict(_family, "MAILCOW")
if(nft_exec_dict(cadena)):
logInfo(f"MAILCOW {_family} chain created successfully.")
input_jump_found, forward_jump_found = False, False
for object in kernel_ruleset["nftables"]:
if not object.get("rule"):
continue
rule = object["rule"]
if rule["chain"] == nft_input_chain:
if rule.get("comment") and rule["comment"] == "mailcow":
input_jump_found = True
if rule["chain"] == nft_forward_chain:
if rule.get("comment") and rule["comment"] == "mailcow":
forward_jump_found = True
if not input_jump_found and nft_input_chain:
command = get_mailcow_jump_rule_dict(_family, nft_input_chain)
nft_exec_dict(command)
if not forward_jump_found and nft_forward_chain:
command = get_mailcow_jump_rule_dict(_family, nft_forward_chain)
nft_exec_dict(command)
def delete_nat_rule(_family:str, _chain: str, _handle:str):
delete_command = get_base_dict()
_rule_opts = dict(family = _family,
table = 'nat',
chain = _chain,
handle = _handle
)
_rule = dict(rule = _rule_opts)
_delete = dict(delete = _rule)
delete_command["nftables"].append(_delete)
return nft_exec_dict(delete_command)
def snat_rule(_family: str, snat_target: str):
chain_name = nft_chain_names[_family]['nat']['postrouting']
# no postrouting chain, may occur if docker has ipv6 disabled.
if not chain_name: return
# Command: nft list chain <family> nat <chain_name>
_chain_opts = dict(family=_family, table='nat', name=chain_name)
_chain = dict(chain=_chain_opts)
_list = dict(list=_chain)
command = get_base_dict()
command['nftables'].append(_list)
kernel_ruleset = nft_exec_dict(command)
if not kernel_ruleset:
return
rule_position = 0
rule_handle = None
rule_found = False
for object in kernel_ruleset["nftables"]:
if not object.get("rule"):
continue
rule = object["rule"]
if not rule.get("comment") or not rule["comment"] == "mailcow":
rule_position +=1
continue
else:
rule_found = True
rule_handle = rule["handle"]
break
if _family == "ip":
source_address = os.getenv('IPV4_NETWORK', '172.22.1') + '.0/24'
else:
source_address = os.getenv('IPV6_NETWORK', 'fd4d:6169:6c63:6f77::/64')
tmp_addr = re.split(r'/', source_address)
dest_ip = tmp_addr[0]
dest_len = int(tmp_addr[1])
if rule_found:
saddr_ip = rule["expr"][0]["match"]["right"]["prefix"]["addr"]
saddr_len = int(rule["expr"][0]["match"]["right"]["prefix"]["len"])
daddr_ip = rule["expr"][1]["match"]["right"]["prefix"]["addr"]
daddr_len = int(rule["expr"][1]["match"]["right"]["prefix"]["len"])
match = all((
saddr_ip == dest_ip,
saddr_len == dest_len,
daddr_ip == dest_ip,
daddr_len == dest_len
))
try:
if rule_position == 0:
if not match:
# Position 0 , it is a mailcow rule , but it does not have the same parameters
if delete_nat_rule(_family, chain_name, rule_handle):
logInfo(f'Remove rule for source network {saddr_ip}/{saddr_len} to SNAT target {snat_target} from POSTROUTING chain with handle {rule_handle}')
else:
# Position > 0 and is mailcow rule
if delete_nat_rule(_family, chain_name, rule_handle):
logInfo(f'Remove rule for source network {saddr_ip}/{saddr_len} to SNAT target {snat_target} from POSTROUTING chain with handle {rule_handle}')
except:
logCrit(f"Error running SNAT on {_family}, retrying..." )
else:
# rule not found
json_command = get_base_dict()
try:
payload_fields = dict(protocol = _family, field = "saddr")
payload_dict = dict(payload = payload_fields)
payload_fields2 = dict(protocol = _family, field = "daddr")
payload_dict2 = dict(payload = payload_fields2)
prefix_fields=dict(addr = dest_ip, len = int(dest_len))
prefix_dict=dict(prefix = prefix_fields)
snat_addr = dict(addr = snat_target)
snat_dict = dict(snat = snat_addr)
expr_counter = dict(family = _family, table = "nat", packets = 0, bytes = 0)
counter_dict = dict(counter = expr_counter)
match_fields1 = dict(op = "==", left = payload_dict, right = prefix_dict)
match_dict1 = dict(match = match_fields1)
match_fields2 = dict(op = "!=", left = payload_dict2, right = prefix_dict )
match_dict2 = dict(match = match_fields2)
expr_list = [
match_dict1,
match_dict2,
counter_dict,
snat_dict
]
rule_fields = dict(family = _family,
table = "nat",
chain = chain_name,
comment = "mailcow",
expr = expr_list
)
rule_dict = dict(rule = rule_fields)
insert_dict = dict(insert = rule_dict)
json_command["nftables"].append(insert_dict)
if(nft_exec_dict(json_command)):
logInfo(f"Added {_family} POSTROUTING rule for source network {dest_ip} to {snat_target}")
except:
logCrit(f"Error running SNAT on {_family}, retrying...")
def get_chain_handle(_family: str, _table: str, chain_name: str):
chain_handle = None
# Command: 'nft list chains {family}'
_chain_opts = dict(family=_family)
_chain = dict(chains=_chain_opts)
_list = dict(list=_chain)
command = get_base_dict()
command['nftables'].append(_list)
kernel_ruleset = nft_exec_dict(command)
if kernel_ruleset:
for object in kernel_ruleset["nftables"]:
if not object.get("chain"):
continue
chain = object["chain"]
if chain["family"] == _family and chain["table"] == _table and chain["name"] == chain_name:
chain_handle = chain["handle"]
break
return chain_handle
def get_rules_handle(_family: str, _table: str, chain_name: str):
rule_handle = []
# Command: 'nft list chain {family} {table} {chain_name}'
_chain_opts = dict(family=_family, table=_table, name=chain_name)
_chain = dict(chain=_chain_opts)
_list = dict(list=_chain)
command = get_base_dict()
command['nftables'].append(_list)
kernel_ruleset = nft_exec_dict(command)
if kernel_ruleset:
for object in kernel_ruleset["nftables"]:
if not object.get("rule"):
continue
rule = object["rule"]
if rule["family"] == _family and rule["table"] == _table and rule["chain"] == chain_name:
if rule.get("comment") and rule["comment"] == "mailcow":
rule_handle.append(rule["handle"])
return rule_handle
def get_ban_ip_dict(ipaddr: str, _family: str):
json_command = get_base_dict()
expr_opt = []
if re.search(r'/', ipaddr):
divided = re.split(r'/', ipaddr)
prefix_dict=dict(addr = divided[0],
len = int(divided[1]) )
right_dict = dict(prefix = prefix_dict)
else:
right_dict = ipaddr
payload_dict = dict(protocol = _family, field="saddr" )
left_dict = dict(payload = payload_dict)
match_dict = dict(op = "==", left = left_dict, right = right_dict )
match_base = dict(match = match_dict)
expr_opt.append(match_base)
expr_counter = dict(family = _family, table = "filter", packets = 0, bytes = 0)
counter_dict = dict(counter = expr_counter)
expr_opt.append(counter_dict)
drop_dict = dict(drop = "null")
expr_opt.append(drop_dict)
rule_dict = dict(family = _family, table = "filter", chain = "MAILCOW", expr = expr_opt)
base_rule = dict(rule = rule_dict)
base_dict = dict(insert = base_rule)
json_command["nftables"].append(base_dict)
return json_command
def get_unban_ip_dict(ipaddr:str, _family: str):
json_command = get_base_dict()
# Command: 'nft list chain {s_family} filter MAILCOW'
_chain_opts = dict(family=_family, table='filter', name='MAILCOW')
_chain = dict(chain=_chain_opts)
_list = dict(list=_chain)
command = get_base_dict()
command['nftables'].append(_list)
kernel_ruleset = nft_exec_dict(command)
rule_handle = None
if kernel_ruleset:
for object in kernel_ruleset["nftables"]:
if not object.get("rule"):
continue
rule = object["rule"]["expr"][0]["match"]
left_opt = rule["left"]["payload"]
if not left_opt["protocol"] == _family:
continue
if not left_opt["field"] =="saddr":
continue
# ip currently banned
rule_right = rule["right"]
if isinstance(rule_right, dict):
current_rule_ip = rule_right["prefix"]["addr"]
current_rule_len = int(rule_right["prefix"]["len"])
else:
current_rule_ip = rule_right
current_rule_len = 32 if _family == 'ip' else 128
# ip to ban
if re.search(r'/', ipaddr):
divided = re.split(r'/', ipaddr)
candidate_ip = divided[0]
candidate_len = int(divided[1])
else:
candidate_ip = ipaddr
candidate_len = 32 if _family == 'ip' else 128
if all((current_rule_ip == candidate_ip,
current_rule_len and candidate_len,
current_rule_len == candidate_len )):
rule_handle = object["rule"]["handle"]
break
if rule_handle is not None:
mailcow_rule = dict(family = _family, table = "filter", chain = "MAILCOW", handle = rule_handle)
del_rule = dict(rule = mailcow_rule)
delete_rule=dict(delete = del_rule)
json_command["nftables"].append(delete_rule)
else:
return False
return json_command
def check_mailcow_chains(family: str, chain: str):
position = 0
rule_found = False
chain_name = nft_chain_names[family]['filter'][chain]
if not chain_name: return None
_chain_opts = dict(family=family, table='filter', name=chain_name)
_chain = dict(chain=_chain_opts)
_list = dict(list=_chain)
command = get_base_dict()
command['nftables'].append(_list)
kernel_ruleset = nft_exec_dict(command)
if kernel_ruleset:
for object in kernel_ruleset["nftables"]:
if not object.get("rule"):
continue
rule = object["rule"]
if rule.get("comment") and rule["comment"] == "mailcow":
rule_found = True
break
position+=1
return position if rule_found else False
# Mailcow
def mailcowChainOrder(): def mailcowChainOrder():
global lock global lock
global quit_now global quit_now
global exit_code global exit_code
while not quit_now: while not quit_now:
time.sleep(10) time.sleep(10)
with lock: with lock:
if backend == 'iptables':
filter4_table = iptc.Table(iptc.Table.FILTER) filter4_table = iptc.Table(iptc.Table.FILTER)
filter6_table = iptc.Table6(iptc.Table6.FILTER) filter6_table = iptc.Table6(iptc.Table6.FILTER)
filter4_table.refresh() filter4_table.refresh()
@ -142,6 +582,21 @@ def mailcowChainOrder():
logCrit('Error in %s chain: MAILCOW target not found, restarting container' % (chain.name)) logCrit('Error in %s chain: MAILCOW target not found, restarting container' % (chain.name))
quit_now = True quit_now = True
exit_code = 2 exit_code = 2
else:
for family in ["ip", "ip6"]:
for chain in ['input', 'forward']:
chain_position = check_mailcow_chains(family, chain)
if chain_position is None: continue
if chain_position is False:
logCrit('Error in %s %s chain: MAILCOW target not found, restarting container' % (family, chain))
quit_now = True
exit_code = 2
if chain_position > 0:
logCrit('Error in %s %s chain order: MAILCOW on position %d, restarting container' % (family, chain, chain_position))
quit_now = True
exit_code = 2
def ban(address): def ban(address):
global lock global lock
@ -190,6 +645,7 @@ def ban(address):
logCrit('Banning %s for %d minutes' % (net, BAN_TIME / 60)) logCrit('Banning %s for %d minutes' % (net, BAN_TIME / 60))
if type(ip) is ipaddress.IPv4Address: if type(ip) is ipaddress.IPv4Address:
with lock: with lock:
if backend == 'iptables':
chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), 'MAILCOW') chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), 'MAILCOW')
rule = iptc.Rule() rule = iptc.Rule()
rule.src = net rule.src = net
@ -197,8 +653,12 @@ def ban(address):
rule.target = target rule.target = target
if rule not in chain.rules: if rule not in chain.rules:
chain.insert_rule(rule) chain.insert_rule(rule)
else:
ban_dict = get_ban_ip_dict(net, "ip")
nft_exec_dict(ban_dict)
else: else:
with lock: with lock:
if backend == 'iptables':
chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), 'MAILCOW') chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), 'MAILCOW')
rule = iptc.Rule6() rule = iptc.Rule6()
rule.src = net rule.src = net
@ -206,6 +666,10 @@ def ban(address):
rule.target = target rule.target = target
if rule not in chain.rules: if rule not in chain.rules:
chain.insert_rule(rule) chain.insert_rule(rule)
else:
ban_dict = get_ban_ip_dict(net, "ip6")
nft_exec_dict(ban_dict)
r.hset('F2B_ACTIVE_BANS', '%s' % net, cur_time + BAN_TIME) r.hset('F2B_ACTIVE_BANS', '%s' % net, cur_time + BAN_TIME)
else: else:
logWarn('%d more attempts in the next %d seconds until %s is banned' % (MAX_ATTEMPTS - bans[net]['attempts'], RETRY_WINDOW, net)) logWarn('%d more attempts in the next %d seconds until %s is banned' % (MAX_ATTEMPTS - bans[net]['attempts'], RETRY_WINDOW, net))
@ -219,6 +683,7 @@ def unban(net):
logInfo('Unbanning %s' % net) logInfo('Unbanning %s' % net)
if type(ipaddress.ip_network(net)) is ipaddress.IPv4Network: if type(ipaddress.ip_network(net)) is ipaddress.IPv4Network:
with lock: with lock:
if backend == 'iptables':
chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), 'MAILCOW') chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), 'MAILCOW')
rule = iptc.Rule() rule = iptc.Rule()
rule.src = net rule.src = net
@ -226,8 +691,14 @@ def unban(net):
rule.target = target rule.target = target
if rule in chain.rules: if rule in chain.rules:
chain.delete_rule(rule) chain.delete_rule(rule)
else:
dict_unban = get_unban_ip_dict(net, "ip")
if dict_unban:
if(nft_exec_dict(dict_unban)):
logInfo(f"Unbanned ip: {net}")
else: else:
with lock: with lock:
if backend == 'iptables':
chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), 'MAILCOW') chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), 'MAILCOW')
rule = iptc.Rule6() rule = iptc.Rule6()
rule.src = net rule.src = net
@ -235,6 +706,12 @@ def unban(net):
rule.target = target rule.target = target
if rule in chain.rules: if rule in chain.rules:
chain.delete_rule(rule) chain.delete_rule(rule)
else:
dict_unban = get_unban_ip_dict(net, "ip6")
if dict_unban:
if(nft_exec_dict(dict_unban)):
logInfo(f"Unbanned ip6: {net}")
r.hdel('F2B_ACTIVE_BANS', '%s' % net) r.hdel('F2B_ACTIVE_BANS', '%s' % net)
r.hdel('F2B_QUEUE_UNBAN', '%s' % net) r.hdel('F2B_QUEUE_UNBAN', '%s' % net)
if net in bans: if net in bans:
@ -244,6 +721,7 @@ def permBan(net, unban=False):
global lock global lock
if type(ipaddress.ip_network(net, strict=False)) is ipaddress.IPv4Network: if type(ipaddress.ip_network(net, strict=False)) is ipaddress.IPv4Network:
with lock: with lock:
if backend == 'iptables':
chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), 'MAILCOW') chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), 'MAILCOW')
rule = iptc.Rule() rule = iptc.Rule()
rule.src = net rule.src = net
@ -257,8 +735,21 @@ def permBan(net, unban=False):
logCrit('Remove host/network %s from blacklist' % net) logCrit('Remove host/network %s from blacklist' % net)
chain.delete_rule(rule) chain.delete_rule(rule)
r.hdel('F2B_PERM_BANS', '%s' % net) r.hdel('F2B_PERM_BANS', '%s' % net)
else:
if not unban:
ban_dict = get_ban_ip_dict(net, "ip")
if(nft_exec_dict(ban_dict)):
logCrit('Add host/network %s to blacklist' % net)
r.hset('F2B_PERM_BANS', '%s' % net, int(round(time.time())))
elif unban:
dict_unban = get_unban_ip_dict(net, "ip")
if dict_unban:
if(nft_exec_dict(dict_unban)):
logCrit('Remove host/network %s from blacklist' % net)
r.hdel('F2B_PERM_BANS', '%s' % net)
else: else:
with lock: with lock:
if backend == 'iptables':
chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), 'MAILCOW') chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), 'MAILCOW')
rule = iptc.Rule6() rule = iptc.Rule6()
rule.src = net rule.src = net
@ -272,6 +763,18 @@ def permBan(net, unban=False):
logCrit('Remove host/network %s from blacklist' % net) logCrit('Remove host/network %s from blacklist' % net)
chain.delete_rule(rule) chain.delete_rule(rule)
r.hdel('F2B_PERM_BANS', '%s' % net) r.hdel('F2B_PERM_BANS', '%s' % net)
else:
if not unban:
ban_dict = get_ban_ip_dict(net, "ip6")
if(nft_exec_dict(ban_dict)):
logCrit('Add host/network %s to blacklist' % net)
r.hset('F2B_PERM_BANS', '%s' % net, int(round(time.time())))
elif unban:
dict_unban = get_unban_ip_dict(net, "ip6")
if dict_unban:
if(nft_exec_dict(dict_unban)):
logCrit('Remove host/network %s from blacklist' % net)
r.hdel('F2B_PERM_BANS', '%s' % net)
def quit(signum, frame): def quit(signum, frame):
global quit_now global quit_now
@ -283,6 +786,7 @@ def clear():
for net in bans.copy(): for net in bans.copy():
unban(net) unban(net)
with lock: with lock:
if backend == 'iptables':
filter4_table = iptc.Table(iptc.Table.FILTER) filter4_table = iptc.Table(iptc.Table.FILTER)
filter6_table = iptc.Table6(iptc.Table6.FILTER) filter6_table = iptc.Table6(iptc.Table6.FILTER)
for filter_table in [filter4_table, filter6_table]: for filter_table in [filter4_table, filter6_table]:
@ -303,6 +807,57 @@ def clear():
filter_table.commit() filter_table.commit()
filter_table.refresh() filter_table.refresh()
filter_table.autocommit = True filter_table.autocommit = True
else:
for _family in ["ip", "ip6"]:
is_empty_dict = True
json_command = get_base_dict()
chain_handle = get_chain_handle(_family, "filter", "MAILCOW")
# if no handle, the chain doesn't exists
if chain_handle is not None:
is_empty_dict = False
# flush chain MAILCOW
mailcow_chain = dict(family=_family, table="filter", name="MAILCOW")
mc_chain_base = dict(chain=mailcow_chain)
flush_chain = dict(flush=mc_chain_base)
json_command["nftables"].append(flush_chain)
# remove rule in forward chain
# remove rule in input chain
chains_family = [nft_chain_names[_family]['filter']['input'],
nft_chain_names[_family]['filter']['forward'] ]
for chain_base in chains_family:
if not chain_base: continue
rules_handle = get_rules_handle(_family, "filter", chain_base)
if rules_handle is not None:
for r_handle in rules_handle:
is_empty_dict = False
mailcow_rule = dict(family=_family,
table="filter",
chain=chain_base,
handle=r_handle
)
del_rule = dict(rule=mailcow_rule)
delete_rules=dict(delete=del_rule)
json_command["nftables"].append(delete_rules)
# remove chain MAILCOW
# after delete all rules referencing this chain
if chain_handle is not None:
mc_chain_handle = dict(family=_family,
table="filter",
name="MAILCOW",
handle=chain_handle
)
del_chain=dict(chain=mc_chain_handle)
delete_chain = dict(delete=del_chain)
json_command["nftables"].append(delete_chain)
if is_empty_dict == False:
if(nft_exec_dict(json_command)):
logInfo(f"Clear completed: {_family}")
r.delete('F2B_ACTIVE_BANS') r.delete('F2B_ACTIVE_BANS')
r.delete('F2B_PERM_BANS') r.delete('F2B_PERM_BANS')
pubsub.unsubscribe() pubsub.unsubscribe()
@ -354,6 +909,7 @@ def snat4(snat_target):
time.sleep(10) time.sleep(10)
with lock: with lock:
try: try:
if backend == 'iptables':
table = iptc.Table('nat') table = iptc.Table('nat')
table.refresh() table.refresh()
chain = iptc.Chain(table, 'POSTROUTING') chain = iptc.Chain(table, 'POSTROUTING')
@ -376,6 +932,8 @@ def snat4(snat_target):
chain.delete_rule(rule) chain.delete_rule(rule)
table.commit() table.commit()
table.autocommit = True table.autocommit = True
else:
snat_rule("ip", snat_target)
except: except:
print('Error running SNAT4, retrying...') print('Error running SNAT4, retrying...')
@ -395,21 +953,31 @@ def snat6(snat_target):
time.sleep(10) time.sleep(10)
with lock: with lock:
try: try:
if backend == 'iptables':
table = iptc.Table6('nat') table = iptc.Table6('nat')
table.refresh() table.refresh()
chain = iptc.Chain(table, 'POSTROUTING') chain = iptc.Chain(table, 'POSTROUTING')
table.autocommit = False table.autocommit = False
if get_snat6_rule() not in chain.rules: new_rule = get_snat6_rule()
logInfo('Added POSTROUTING rule for source network %s to SNAT target %s' % (get_snat6_rule().src, snat_target)) for position, rule in enumerate(chain.rules):
chain.insert_rule(get_snat6_rule()) match = all((
table.commit() new_rule.get_src() == rule.get_src(),
new_rule.get_dst() == rule.get_dst(),
new_rule.target.parameters == rule.target.parameters,
new_rule.target.name == rule.target.name
))
if position == 0:
if not match:
logInfo(f'Added POSTROUTING rule for source network {new_rule.src} to SNAT target {snat_target}')
chain.insert_rule(new_rule)
else: else:
for position, item in enumerate(chain.rules): if match:
if item == get_snat6_rule(): logInfo(f'Remove rule for source network {new_rule.src} to SNAT target {snat_target} from POSTROUTING chain at position {position}')
if position != 0: chain.delete_rule(rule)
chain.delete_rule(get_snat6_rule())
table.commit() table.commit()
table.autocommit = True table.autocommit = True
else:
snat_rule("ip6", snat_target)
except: except:
print('Error running SNAT6, retrying...') print('Error running SNAT6, retrying...')
@ -435,7 +1003,6 @@ def isIpNetwork(address):
return False return False
return True return True
def genNetworkList(list): def genNetworkList(list):
resolver = dns.resolver.Resolver() resolver = dns.resolver.Resolver()
hostnames = [] hostnames = []
@ -504,6 +1071,7 @@ def blacklistUpdate():
def initChain(): def initChain():
# Is called before threads start, no locking # Is called before threads start, no locking
print("Initializing mailcow netfilter chain") print("Initializing mailcow netfilter chain")
if backend == 'iptables':
# IPv4 # IPv4
if not iptc.Chain(iptc.Table(iptc.Table.FILTER), "MAILCOW") in iptc.Table(iptc.Table.FILTER).chains: if not iptc.Chain(iptc.Table(iptc.Table.FILTER), "MAILCOW") in iptc.Table(iptc.Table.FILTER).chains:
iptc.Table(iptc.Table.FILTER).create_chain("MAILCOW") iptc.Table(iptc.Table.FILTER).create_chain("MAILCOW")
@ -528,9 +1096,15 @@ def initChain():
rule.target = target rule.target = target
if rule not in chain.rules: if rule not in chain.rules:
chain.insert_rule(rule) chain.insert_rule(rule)
else:
for family in ["ip", "ip6"]:
insert_mailcow_chains(family)
if __name__ == '__main__': if __name__ == '__main__':
if backend == 'nftables':
search_current_chains()
# In case a previous session was killed without cleanup # In case a previous session was killed without cleanup
clear() clear()
# Reinit MAILCOW chain # Reinit MAILCOW chain