added nftables support to server.py

This commit is contained in:
amorfo77 2023-02-10 18:19:18 +01:00
parent 83a5389242
commit 1d5b5dbd86
1 changed files with 728 additions and 154 deletions

View File

@ -7,6 +7,7 @@ import time
import atexit import atexit
import signal import signal
import ipaddress import ipaddress
import nftables
from collections import Counter from collections import Counter
from random import randint from random import randint
from threading import Thread from threading import Thread
@ -43,6 +44,10 @@ quit_now = False
exit_code = 0 exit_code = 0
lock = Lock() lock = Lock()
backend = sys.argv[1]
nft = None
nft_chain_names = {}
def log(priority, message): def log(priority, message):
tolog = {} tolog = {}
tolog['time'] = int(round(time.time())) tolog['time'] = int(round(time.time()))
@ -60,6 +65,17 @@ def logCrit(message):
def logInfo(message): def logInfo(message):
log('info', message) log('info', message)
#nftables
if backend == 'nftables':
logInfo('Using Nftables backend')
nft = nftables.Nftables()
nft.set_json_output(True)
nft.set_handle_output(True)
nft_chain_names = {'ip': {'filter': {'input': '', 'forward': ''}, 'nat': {'postrouting': ''} },
'ip6': {'filter': {'input': '', 'forward': ''}, 'nat': {'postrouting': ''} } }
else:
logInfo('Using Iptables backend')
def refreshF2boptions(): def refreshF2boptions():
global f2boptions global f2boptions
global quit_now global quit_now
@ -115,33 +131,472 @@ def refreshF2bregex():
if r.exists('F2B_LOG'): if r.exists('F2B_LOG'):
r.rename('F2B_LOG', 'NETFILTER_LOG') r.rename('F2B_LOG', 'NETFILTER_LOG')
# Nftables functions
def nft_exec_dict(query: dict):
global nft
if not query: return False
rc, output, error = nft.json_cmd(query)
if rc != 0:
#logCrit(f"Nftables Error: {error}")
return False
# Prevent returning False or empty string on commands that do not produce output
if rc == 0 and len(output) == 0:
return True
return output
def get_base_dict():
return {'nftables': [{ 'metainfo': { 'json_schema_version': 1} } ] }
def search_current_chains():
global nft_chain_names
nft_chain_priority = {'ip': {'filter': {'input': 1, 'forward': 1}, 'nat': {'postrouting': 111} },
'ip6': {'filter': {'input': 1, 'forward': 1}, 'nat': {'postrouting': 111} } }
# Command: 'nft list chains'
_list_opts = dict(chains='null')
_list = dict(list=_list_opts)
command = get_base_dict()
command['nftables'].append(_list)
kernel_ruleset = nft_exec_dict(command)
if kernel_ruleset:
for object in kernel_ruleset['nftables']:
chain = object.get("chain")
if not chain:
continue
_family = chain['family']
_table = chain['table']
hook = chain.get("hook")
if not hook or hook not in nft_chain_names[_family][_table]:
continue
_hook = chain['hook']
priority = chain.get("prio")
if priority is None:
continue
if priority < nft_chain_priority[_family][_table][_hook]:
# at this point, we know the chain has:
# hook and priority set
# and it has the lowest priority
nft_chain_priority[_family][_table][_hook] = priority
nft_chain_names[_family][_table][_hook] = chain['name']
def search_for_chain(kernel_ruleset: dict, chain_name: str):
found = False
for object in kernel_ruleset["nftables"]:
chain = object.get("chain")
if not chain:
continue
ch_name = chain.get("name")
if ch_name == chain_name:
found = True
break
return found
def get_chain_dict(_family: str, _name: str):
# nft (add | create) chain [<family>] <table> <name>
_chain_opts = dict(family = _family,
table = 'filter',
name = _name )
_chain = dict(chain = _chain_opts)
_add = dict(add = _chain)
final_chain = get_base_dict()
final_chain["nftables"].append(_add)
return final_chain
def get_mailcow_jump_rule_dict(_family: str, _chain: str):
_jump_rule = get_base_dict()
_expr_opt=[]
_expr_counter = dict(family = _family, table = 'filter', packets = 0, bytes = 0)
_counter_dict = dict(counter = _expr_counter)
_expr_opt.append(_counter_dict)
_expr_jump = dict(target = 'MAILCOW')
_jump_opts = dict(jump = _expr_jump)
_expr_opt.append(_jump_opts)
_rule_params = dict(family = _family,
table = 'filter',
chain = _chain,
expr = _expr_opt,
comment = "mailcow"
)
_opts_rule = dict(rule = _rule_params)
_add_rule = dict(insert = _opts_rule)
_jump_rule["nftables"].append(_add_rule)
return _jump_rule
def insert_mailcow_chains(_family: str):
nft_input_chain = nft_chain_names[_family]['filter']['input']
nft_forward_chain = nft_chain_names[_family]['filter']['forward']
# Command: 'nft list table <family> filter'
_table_opts = dict(family=_family, name='filter')
_table = dict(table=_table_opts)
_list = dict(list=_table)
command = get_base_dict()
command['nftables'].append(_list)
kernel_ruleset = nft_exec_dict(command)
if kernel_ruleset:
# MAILCOW chain
if not search_for_chain(kernel_ruleset, "MAILCOW"):
cadena = get_chain_dict(_family, "MAILCOW")
if(nft_exec_dict(cadena)):
logInfo(f"MAILCOW {_family} chain created successfully.")
input_jump_found, forward_jump_found = False, False
for object in kernel_ruleset["nftables"]:
if not object.get("rule"):
continue
rule = object["rule"]
if rule["chain"] == nft_input_chain:
if rule.get("comment") and rule["comment"] == "mailcow":
input_jump_found = True
if rule["chain"] == nft_forward_chain:
if rule.get("comment") and rule["comment"] == "mailcow":
forward_jump_found = True
if not input_jump_found and nft_input_chain:
command = get_mailcow_jump_rule_dict(_family, nft_input_chain)
nft_exec_dict(command)
if not forward_jump_found and nft_forward_chain:
command = get_mailcow_jump_rule_dict(_family, nft_forward_chain)
nft_exec_dict(command)
def delete_nat_rule(_family:str, _chain: str, _handle:str):
delete_command = get_base_dict()
_rule_opts = dict(family = _family,
table = 'nat',
chain = _chain,
handle = _handle
)
_rule = dict(rule = _rule_opts)
_delete = dict(delete = _rule)
delete_command["nftables"].append(_delete)
return nft_exec_dict(delete_command)
def snat_rule(_family: str, snat_target: str):
chain_name = nft_chain_names[_family]['nat']['postrouting']
# no postrouting chain, may occur if docker has ipv6 disabled.
if not chain_name: return
# Command: nft list chain <family> nat <chain_name>
_chain_opts = dict(family=_family, table='nat', name=chain_name)
_chain = dict(chain=_chain_opts)
_list = dict(list=_chain)
command = get_base_dict()
command['nftables'].append(_list)
kernel_ruleset = nft_exec_dict(command)
if not kernel_ruleset:
return
rule_position = 0
rule_handle = None
rule_found = False
for object in kernel_ruleset["nftables"]:
if not object.get("rule"):
continue
rule = object["rule"]
if not rule.get("comment") or not rule["comment"] == "mailcow":
rule_position +=1
continue
else:
rule_found = True
rule_handle = rule["handle"]
break
if _family == "ip":
source_address = os.getenv('IPV4_NETWORK', '172.22.1') + '.0/24'
else:
source_address = os.getenv('IPV6_NETWORK', 'fd4d:6169:6c63:6f77::/64')
tmp_addr = re.split(r'/', source_address)
dest_ip = tmp_addr[0]
dest_len = int(tmp_addr[1])
if rule_found:
saddr_ip = rule["expr"][0]["match"]["right"]["prefix"]["addr"]
saddr_len = int(rule["expr"][0]["match"]["right"]["prefix"]["len"])
daddr_ip = rule["expr"][1]["match"]["right"]["prefix"]["addr"]
daddr_len = int(rule["expr"][1]["match"]["right"]["prefix"]["len"])
match = all((
saddr_ip == dest_ip,
saddr_len == dest_len,
daddr_ip == dest_ip,
daddr_len == dest_len
))
try:
if rule_position == 0:
if not match:
# Position 0 , it is a mailcow rule , but it does not have the same parameters
if delete_nat_rule(_family, chain_name, rule_handle):
logInfo(f'Remove rule for source network {saddr_ip}/{saddr_len} to SNAT target {snat_target} from POSTROUTING chain with handle {rule_handle}')
else:
# Position > 0 and is mailcow rule
if delete_nat_rule(_family, chain_name, rule_handle):
logInfo(f'Remove rule for source network {saddr_ip}/{saddr_len} to SNAT target {snat_target} from POSTROUTING chain with handle {rule_handle}')
except:
logCrit(f"Error running SNAT on {_family}, retrying..." )
else:
# rule not found
json_command = get_base_dict()
try:
payload_fields = dict(protocol = _family, field = "saddr")
payload_dict = dict(payload = payload_fields)
payload_fields2 = dict(protocol = _family, field = "daddr")
payload_dict2 = dict(payload = payload_fields2)
prefix_fields=dict(addr = dest_ip, len = int(dest_len))
prefix_dict=dict(prefix = prefix_fields)
snat_addr = dict(addr = snat_target)
snat_dict = dict(snat = snat_addr)
expr_counter = dict(family = _family, table = "nat", packets = 0, bytes = 0)
counter_dict = dict(counter = expr_counter)
match_fields1 = dict(op = "==", left = payload_dict, right = prefix_dict)
match_dict1 = dict(match = match_fields1)
match_fields2 = dict(op = "!=", left = payload_dict2, right = prefix_dict )
match_dict2 = dict(match = match_fields2)
expr_list = [
match_dict1,
match_dict2,
counter_dict,
snat_dict
]
rule_fields = dict(family = _family,
table = "nat",
chain = chain_name,
comment = "mailcow",
expr = expr_list
)
rule_dict = dict(rule = rule_fields)
insert_dict = dict(insert = rule_dict)
json_command["nftables"].append(insert_dict)
if(nft_exec_dict(json_command)):
logInfo(f"Added {_family} POSTROUTING rule for source network {dest_ip} to {snat_target}")
except:
logCrit(f"Error running SNAT on {_family}, retrying...")
def get_chain_handle(_family: str, _table: str, chain_name: str):
chain_handle = None
# Command: 'nft list chains {family}'
_chain_opts = dict(family=_family)
_chain = dict(chains=_chain_opts)
_list = dict(list=_chain)
command = get_base_dict()
command['nftables'].append(_list)
kernel_ruleset = nft_exec_dict(command)
if kernel_ruleset:
for object in kernel_ruleset["nftables"]:
if not object.get("chain"):
continue
chain = object["chain"]
if chain["family"] == _family and chain["table"] == _table and chain["name"] == chain_name:
chain_handle = chain["handle"]
break
return chain_handle
def get_rules_handle(_family: str, _table: str, chain_name: str):
rule_handle = []
# Command: 'nft list chain {family} {table} {chain_name}'
_chain_opts = dict(family=_family, table=_table, name=chain_name)
_chain = dict(chain=_chain_opts)
_list = dict(list=_chain)
command = get_base_dict()
command['nftables'].append(_list)
kernel_ruleset = nft_exec_dict(command)
if kernel_ruleset:
for object in kernel_ruleset["nftables"]:
if not object.get("rule"):
continue
rule = object["rule"]
if rule["family"] == _family and rule["table"] == _table and rule["chain"] == chain_name:
if rule.get("comment") and rule["comment"] == "mailcow":
rule_handle.append(rule["handle"])
return rule_handle
def get_ban_ip_dict(ipaddr: str, _family: str):
json_command = get_base_dict()
expr_opt = []
if re.search(r'/', ipaddr):
divided = re.split(r'/', ipaddr)
prefix_dict=dict(addr = divided[0],
len = int(divided[1]) )
right_dict = dict(prefix = prefix_dict)
else:
right_dict = ipaddr
payload_dict = dict(protocol = _family, field="saddr" )
left_dict = dict(payload = payload_dict)
match_dict = dict(op = "==", left = left_dict, right = right_dict )
match_base = dict(match = match_dict)
expr_opt.append(match_base)
expr_counter = dict(family = _family, table = "filter", packets = 0, bytes = 0)
counter_dict = dict(counter = expr_counter)
expr_opt.append(counter_dict)
drop_dict = dict(drop = "null")
expr_opt.append(drop_dict)
rule_dict = dict(family = _family, table = "filter", chain = "MAILCOW", expr = expr_opt)
base_rule = dict(rule = rule_dict)
base_dict = dict(insert = base_rule)
json_command["nftables"].append(base_dict)
return json_command
def get_unban_ip_dict(ipaddr:str, _family: str):
json_command = get_base_dict()
# Command: 'nft list chain {s_family} filter MAILCOW'
_chain_opts = dict(family=_family, table='filter', name='MAILCOW')
_chain = dict(chain=_chain_opts)
_list = dict(list=_chain)
command = get_base_dict()
command['nftables'].append(_list)
kernel_ruleset = nft_exec_dict(command)
rule_handle = None
if kernel_ruleset:
for object in kernel_ruleset["nftables"]:
if not object.get("rule"):
continue
rule = object["rule"]["expr"][0]["match"]
left_opt = rule["left"]["payload"]
if not left_opt["protocol"] == _family:
continue
if not left_opt["field"] =="saddr":
continue
# ip currently banned
rule_right = rule["right"]
if isinstance(rule_right, dict):
current_rule_ip = rule_right["prefix"]["addr"]
current_rule_len = int(rule_right["prefix"]["len"])
else:
current_rule_ip = rule_right
current_rule_len = 32 if _family == 'ip' else 128
# ip to ban
if re.search(r'/', ipaddr):
divided = re.split(r'/', ipaddr)
candidate_ip = divided[0]
candidate_len = int(divided[1])
else:
candidate_ip = ipaddr
candidate_len = 32 if _family == 'ip' else 128
if all((current_rule_ip == candidate_ip,
current_rule_len and candidate_len,
current_rule_len == candidate_len )):
rule_handle = object["rule"]["handle"]
break
if rule_handle is not None:
mailcow_rule = dict(family = _family, table = "filter", chain = "MAILCOW", handle = rule_handle)
del_rule = dict(rule = mailcow_rule)
delete_rule=dict(delete = del_rule)
json_command["nftables"].append(delete_rule)
else:
return False
return json_command
def check_mailcow_chains(family: str, chain: str):
position = 0
rule_found = False
chain_name = nft_chain_names[family]['filter'][chain]
if not chain_name: return None
_chain_opts = dict(family=family, table='filter', name=chain_name)
_chain = dict(chain=_chain_opts)
_list = dict(list=_chain)
command = get_base_dict()
command['nftables'].append(_list)
kernel_ruleset = nft_exec_dict(command)
if kernel_ruleset:
for object in kernel_ruleset["nftables"]:
if not object.get("rule"):
continue
rule = object["rule"]
if rule.get("comment") and rule["comment"] == "mailcow":
rule_found = True
break
position+=1
return position if rule_found else False
# Mailcow
def mailcowChainOrder(): def mailcowChainOrder():
global lock global lock
global quit_now global quit_now
global exit_code global exit_code
while not quit_now: while not quit_now:
time.sleep(10) time.sleep(10)
with lock: with lock:
filter4_table = iptc.Table(iptc.Table.FILTER) if backend == 'iptables':
filter6_table = iptc.Table6(iptc.Table6.FILTER) filter4_table = iptc.Table(iptc.Table.FILTER)
filter4_table.refresh() filter6_table = iptc.Table6(iptc.Table6.FILTER)
filter6_table.refresh() filter4_table.refresh()
for f in [filter4_table, filter6_table]: filter6_table.refresh()
forward_chain = iptc.Chain(f, 'FORWARD') for f in [filter4_table, filter6_table]:
input_chain = iptc.Chain(f, 'INPUT') forward_chain = iptc.Chain(f, 'FORWARD')
for chain in [forward_chain, input_chain]: input_chain = iptc.Chain(f, 'INPUT')
target_found = False for chain in [forward_chain, input_chain]:
for position, item in enumerate(chain.rules): target_found = False
if item.target.name == 'MAILCOW': for position, item in enumerate(chain.rules):
target_found = True if item.target.name == 'MAILCOW':
if position > 2: target_found = True
logCrit('Error in %s chain order: MAILCOW on position %d, restarting container' % (chain.name, position)) if position > 2:
quit_now = True logCrit('Error in %s chain order: MAILCOW on position %d, restarting container' % (chain.name, position))
exit_code = 2 quit_now = True
if not target_found: exit_code = 2
logCrit('Error in %s chain: MAILCOW target not found, restarting container' % (chain.name)) if not target_found:
quit_now = True logCrit('Error in %s chain: MAILCOW target not found, restarting container' % (chain.name))
exit_code = 2 quit_now = True
exit_code = 2
else:
for family in ["ip", "ip6"]:
for chain in ['input', 'forward']:
chain_position = check_mailcow_chains(family, chain)
if chain_position is None: continue
if chain_position is False:
logCrit('Error in %s %s chain: MAILCOW target not found, restarting container' % (family, chain))
quit_now = True
exit_code = 2
if chain_position > 0:
logCrit('Error in %s %s chain order: MAILCOW on position %d, restarting container' % (family, chain, chain_position))
quit_now = True
exit_code = 2
def ban(address): def ban(address):
global lock global lock
@ -190,22 +645,31 @@ def ban(address):
logCrit('Banning %s for %d minutes' % (net, BAN_TIME / 60)) logCrit('Banning %s for %d minutes' % (net, BAN_TIME / 60))
if type(ip) is ipaddress.IPv4Address: if type(ip) is ipaddress.IPv4Address:
with lock: with lock:
chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), 'MAILCOW') if backend == 'iptables':
rule = iptc.Rule() chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), 'MAILCOW')
rule.src = net rule = iptc.Rule()
target = iptc.Target(rule, "REJECT") rule.src = net
rule.target = target target = iptc.Target(rule, "REJECT")
if rule not in chain.rules: rule.target = target
chain.insert_rule(rule) if rule not in chain.rules:
chain.insert_rule(rule)
else:
ban_dict = get_ban_ip_dict(net, "ip")
nft_exec_dict(ban_dict)
else: else:
with lock: with lock:
chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), 'MAILCOW') if backend == 'iptables':
rule = iptc.Rule6() chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), 'MAILCOW')
rule.src = net rule = iptc.Rule6()
target = iptc.Target(rule, "REJECT") rule.src = net
rule.target = target target = iptc.Target(rule, "REJECT")
if rule not in chain.rules: rule.target = target
chain.insert_rule(rule) if rule not in chain.rules:
chain.insert_rule(rule)
else:
ban_dict = get_ban_ip_dict(net, "ip6")
nft_exec_dict(ban_dict)
r.hset('F2B_ACTIVE_BANS', '%s' % net, cur_time + BAN_TIME) r.hset('F2B_ACTIVE_BANS', '%s' % net, cur_time + BAN_TIME)
else: else:
logWarn('%d more attempts in the next %d seconds until %s is banned' % (MAX_ATTEMPTS - bans[net]['attempts'], RETRY_WINDOW, net)) logWarn('%d more attempts in the next %d seconds until %s is banned' % (MAX_ATTEMPTS - bans[net]['attempts'], RETRY_WINDOW, net))
@ -219,22 +683,35 @@ def unban(net):
logInfo('Unbanning %s' % net) logInfo('Unbanning %s' % net)
if type(ipaddress.ip_network(net)) is ipaddress.IPv4Network: if type(ipaddress.ip_network(net)) is ipaddress.IPv4Network:
with lock: with lock:
chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), 'MAILCOW') if backend == 'iptables':
rule = iptc.Rule() chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), 'MAILCOW')
rule.src = net rule = iptc.Rule()
target = iptc.Target(rule, "REJECT") rule.src = net
rule.target = target target = iptc.Target(rule, "REJECT")
if rule in chain.rules: rule.target = target
chain.delete_rule(rule) if rule in chain.rules:
chain.delete_rule(rule)
else:
dict_unban = get_unban_ip_dict(net, "ip")
if dict_unban:
if(nft_exec_dict(dict_unban)):
logInfo(f"Unbanned ip: {net}")
else: else:
with lock: with lock:
chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), 'MAILCOW') if backend == 'iptables':
rule = iptc.Rule6() chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), 'MAILCOW')
rule.src = net rule = iptc.Rule6()
target = iptc.Target(rule, "REJECT") rule.src = net
rule.target = target target = iptc.Target(rule, "REJECT")
if rule in chain.rules: rule.target = target
chain.delete_rule(rule) if rule in chain.rules:
chain.delete_rule(rule)
else:
dict_unban = get_unban_ip_dict(net, "ip6")
if dict_unban:
if(nft_exec_dict(dict_unban)):
logInfo(f"Unbanned ip6: {net}")
r.hdel('F2B_ACTIVE_BANS', '%s' % net) r.hdel('F2B_ACTIVE_BANS', '%s' % net)
r.hdel('F2B_QUEUE_UNBAN', '%s' % net) r.hdel('F2B_QUEUE_UNBAN', '%s' % net)
if net in bans: if net in bans:
@ -244,34 +721,60 @@ def permBan(net, unban=False):
global lock global lock
if type(ipaddress.ip_network(net, strict=False)) is ipaddress.IPv4Network: if type(ipaddress.ip_network(net, strict=False)) is ipaddress.IPv4Network:
with lock: with lock:
chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), 'MAILCOW') if backend == 'iptables':
rule = iptc.Rule() chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), 'MAILCOW')
rule.src = net rule = iptc.Rule()
target = iptc.Target(rule, "REJECT") rule.src = net
rule.target = target target = iptc.Target(rule, "REJECT")
if rule not in chain.rules and not unban: rule.target = target
logCrit('Add host/network %s to blacklist' % net) if rule not in chain.rules and not unban:
chain.insert_rule(rule) logCrit('Add host/network %s to blacklist' % net)
r.hset('F2B_PERM_BANS', '%s' % net, int(round(time.time()))) chain.insert_rule(rule)
elif rule in chain.rules and unban: r.hset('F2B_PERM_BANS', '%s' % net, int(round(time.time())))
logCrit('Remove host/network %s from blacklist' % net) elif rule in chain.rules and unban:
chain.delete_rule(rule) logCrit('Remove host/network %s from blacklist' % net)
r.hdel('F2B_PERM_BANS', '%s' % net) chain.delete_rule(rule)
r.hdel('F2B_PERM_BANS', '%s' % net)
else:
if not unban:
ban_dict = get_ban_ip_dict(net, "ip")
if(nft_exec_dict(ban_dict)):
logCrit('Add host/network %s to blacklist' % net)
r.hset('F2B_PERM_BANS', '%s' % net, int(round(time.time())))
elif unban:
dict_unban = get_unban_ip_dict(net, "ip")
if dict_unban:
if(nft_exec_dict(dict_unban)):
logCrit('Remove host/network %s from blacklist' % net)
r.hdel('F2B_PERM_BANS', '%s' % net)
else: else:
with lock: with lock:
chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), 'MAILCOW') if backend == 'iptables':
rule = iptc.Rule6() chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), 'MAILCOW')
rule.src = net rule = iptc.Rule6()
target = iptc.Target(rule, "REJECT") rule.src = net
rule.target = target target = iptc.Target(rule, "REJECT")
if rule not in chain.rules and not unban: rule.target = target
logCrit('Add host/network %s to blacklist' % net) if rule not in chain.rules and not unban:
chain.insert_rule(rule) logCrit('Add host/network %s to blacklist' % net)
r.hset('F2B_PERM_BANS', '%s' % net, int(round(time.time()))) chain.insert_rule(rule)
elif rule in chain.rules and unban: r.hset('F2B_PERM_BANS', '%s' % net, int(round(time.time())))
logCrit('Remove host/network %s from blacklist' % net) elif rule in chain.rules and unban:
chain.delete_rule(rule) logCrit('Remove host/network %s from blacklist' % net)
r.hdel('F2B_PERM_BANS', '%s' % net) chain.delete_rule(rule)
r.hdel('F2B_PERM_BANS', '%s' % net)
else:
if not unban:
ban_dict = get_ban_ip_dict(net, "ip6")
if(nft_exec_dict(ban_dict)):
logCrit('Add host/network %s to blacklist' % net)
r.hset('F2B_PERM_BANS', '%s' % net, int(round(time.time())))
elif unban:
dict_unban = get_unban_ip_dict(net, "ip6")
if dict_unban:
if(nft_exec_dict(dict_unban)):
logCrit('Remove host/network %s from blacklist' % net)
r.hdel('F2B_PERM_BANS', '%s' % net)
def quit(signum, frame): def quit(signum, frame):
global quit_now global quit_now
@ -283,26 +786,78 @@ def clear():
for net in bans.copy(): for net in bans.copy():
unban(net) unban(net)
with lock: with lock:
filter4_table = iptc.Table(iptc.Table.FILTER) if backend == 'iptables':
filter6_table = iptc.Table6(iptc.Table6.FILTER) filter4_table = iptc.Table(iptc.Table.FILTER)
for filter_table in [filter4_table, filter6_table]: filter6_table = iptc.Table6(iptc.Table6.FILTER)
filter_table.autocommit = False for filter_table in [filter4_table, filter6_table]:
forward_chain = iptc.Chain(filter_table, "FORWARD") filter_table.autocommit = False
input_chain = iptc.Chain(filter_table, "INPUT") forward_chain = iptc.Chain(filter_table, "FORWARD")
mailcow_chain = iptc.Chain(filter_table, "MAILCOW") input_chain = iptc.Chain(filter_table, "INPUT")
if mailcow_chain in filter_table.chains: mailcow_chain = iptc.Chain(filter_table, "MAILCOW")
for rule in mailcow_chain.rules: if mailcow_chain in filter_table.chains:
mailcow_chain.delete_rule(rule) for rule in mailcow_chain.rules:
for rule in forward_chain.rules: mailcow_chain.delete_rule(rule)
if rule.target.name == 'MAILCOW': for rule in forward_chain.rules:
forward_chain.delete_rule(rule) if rule.target.name == 'MAILCOW':
for rule in input_chain.rules: forward_chain.delete_rule(rule)
if rule.target.name == 'MAILCOW': for rule in input_chain.rules:
input_chain.delete_rule(rule) if rule.target.name == 'MAILCOW':
filter_table.delete_chain("MAILCOW") input_chain.delete_rule(rule)
filter_table.commit() filter_table.delete_chain("MAILCOW")
filter_table.refresh() filter_table.commit()
filter_table.autocommit = True filter_table.refresh()
filter_table.autocommit = True
else:
for _family in ["ip", "ip6"]:
is_empty_dict = True
json_command = get_base_dict()
chain_handle = get_chain_handle(_family, "filter", "MAILCOW")
# if no handle, the chain doesn't exists
if chain_handle is not None:
is_empty_dict = False
# flush chain MAILCOW
mailcow_chain = dict(family=_family, table="filter", name="MAILCOW")
mc_chain_base = dict(chain=mailcow_chain)
flush_chain = dict(flush=mc_chain_base)
json_command["nftables"].append(flush_chain)
# remove rule in forward chain
# remove rule in input chain
chains_family = [nft_chain_names[_family]['filter']['input'],
nft_chain_names[_family]['filter']['forward'] ]
for chain_base in chains_family:
if not chain_base: continue
rules_handle = get_rules_handle(_family, "filter", chain_base)
if rules_handle is not None:
for r_handle in rules_handle:
is_empty_dict = False
mailcow_rule = dict(family=_family,
table="filter",
chain=chain_base,
handle=r_handle
)
del_rule = dict(rule=mailcow_rule)
delete_rules=dict(delete=del_rule)
json_command["nftables"].append(delete_rules)
# remove chain MAILCOW
# after delete all rules referencing this chain
if chain_handle is not None:
mc_chain_handle = dict(family=_family,
table="filter",
name="MAILCOW",
handle=chain_handle
)
del_chain=dict(chain=mc_chain_handle)
delete_chain = dict(delete=del_chain)
json_command["nftables"].append(delete_chain)
if is_empty_dict == False:
if(nft_exec_dict(json_command)):
logInfo(f"Clear completed: {_family}")
r.delete('F2B_ACTIVE_BANS') r.delete('F2B_ACTIVE_BANS')
r.delete('F2B_PERM_BANS') r.delete('F2B_PERM_BANS')
pubsub.unsubscribe() pubsub.unsubscribe()
@ -354,28 +909,31 @@ def snat4(snat_target):
time.sleep(10) time.sleep(10)
with lock: with lock:
try: try:
table = iptc.Table('nat') if backend == 'iptables':
table.refresh() table = iptc.Table('nat')
chain = iptc.Chain(table, 'POSTROUTING') table.refresh()
table.autocommit = False chain = iptc.Chain(table, 'POSTROUTING')
new_rule = get_snat4_rule() table.autocommit = False
for position, rule in enumerate(chain.rules): new_rule = get_snat4_rule()
match = all(( for position, rule in enumerate(chain.rules):
new_rule.get_src() == rule.get_src(), match = all((
new_rule.get_dst() == rule.get_dst(), new_rule.get_src() == rule.get_src(),
new_rule.target.parameters == rule.target.parameters, new_rule.get_dst() == rule.get_dst(),
new_rule.target.name == rule.target.name new_rule.target.parameters == rule.target.parameters,
)) new_rule.target.name == rule.target.name
if position == 0: ))
if not match: if position == 0:
logInfo(f'Added POSTROUTING rule for source network {new_rule.src} to SNAT target {snat_target}') if not match:
chain.insert_rule(new_rule) logInfo(f'Added POSTROUTING rule for source network {new_rule.src} to SNAT target {snat_target}')
else: chain.insert_rule(new_rule)
if match: else:
logInfo(f'Remove rule for source network {new_rule.src} to SNAT target {snat_target} from POSTROUTING chain at position {position}') if match:
chain.delete_rule(rule) logInfo(f'Remove rule for source network {new_rule.src} to SNAT target {snat_target} from POSTROUTING chain at position {position}')
table.commit() chain.delete_rule(rule)
table.autocommit = True table.commit()
table.autocommit = True
else:
snat_rule("ip", snat_target)
except: except:
print('Error running SNAT4, retrying...') print('Error running SNAT4, retrying...')
@ -395,21 +953,31 @@ def snat6(snat_target):
time.sleep(10) time.sleep(10)
with lock: with lock:
try: try:
table = iptc.Table6('nat') if backend == 'iptables':
table.refresh() table = iptc.Table6('nat')
chain = iptc.Chain(table, 'POSTROUTING') table.refresh()
table.autocommit = False chain = iptc.Chain(table, 'POSTROUTING')
if get_snat6_rule() not in chain.rules: table.autocommit = False
logInfo('Added POSTROUTING rule for source network %s to SNAT target %s' % (get_snat6_rule().src, snat_target)) new_rule = get_snat6_rule()
chain.insert_rule(get_snat6_rule()) for position, rule in enumerate(chain.rules):
match = all((
new_rule.get_src() == rule.get_src(),
new_rule.get_dst() == rule.get_dst(),
new_rule.target.parameters == rule.target.parameters,
new_rule.target.name == rule.target.name
))
if position == 0:
if not match:
logInfo(f'Added POSTROUTING rule for source network {new_rule.src} to SNAT target {snat_target}')
chain.insert_rule(new_rule)
else:
if match:
logInfo(f'Remove rule for source network {new_rule.src} to SNAT target {snat_target} from POSTROUTING chain at position {position}')
chain.delete_rule(rule)
table.commit() table.commit()
table.autocommit = True
else: else:
for position, item in enumerate(chain.rules): snat_rule("ip6", snat_target)
if item == get_snat6_rule():
if position != 0:
chain.delete_rule(get_snat6_rule())
table.commit()
table.autocommit = True
except: except:
print('Error running SNAT6, retrying...') print('Error running SNAT6, retrying...')
@ -435,7 +1003,6 @@ def isIpNetwork(address):
return False return False
return True return True
def genNetworkList(list): def genNetworkList(list):
resolver = dns.resolver.Resolver() resolver = dns.resolver.Resolver()
hostnames = [] hostnames = []
@ -504,33 +1071,40 @@ def blacklistUpdate():
def initChain(): def initChain():
# Is called before threads start, no locking # Is called before threads start, no locking
print("Initializing mailcow netfilter chain") print("Initializing mailcow netfilter chain")
# IPv4 if backend == 'iptables':
if not iptc.Chain(iptc.Table(iptc.Table.FILTER), "MAILCOW") in iptc.Table(iptc.Table.FILTER).chains: # IPv4
iptc.Table(iptc.Table.FILTER).create_chain("MAILCOW") if not iptc.Chain(iptc.Table(iptc.Table.FILTER), "MAILCOW") in iptc.Table(iptc.Table.FILTER).chains:
for c in ['FORWARD', 'INPUT']: iptc.Table(iptc.Table.FILTER).create_chain("MAILCOW")
chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), c) for c in ['FORWARD', 'INPUT']:
rule = iptc.Rule() chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), c)
rule.src = '0.0.0.0/0' rule = iptc.Rule()
rule.dst = '0.0.0.0/0' rule.src = '0.0.0.0/0'
target = iptc.Target(rule, "MAILCOW") rule.dst = '0.0.0.0/0'
rule.target = target target = iptc.Target(rule, "MAILCOW")
if rule not in chain.rules: rule.target = target
chain.insert_rule(rule) if rule not in chain.rules:
# IPv6 chain.insert_rule(rule)
if not iptc.Chain(iptc.Table6(iptc.Table6.FILTER), "MAILCOW") in iptc.Table6(iptc.Table6.FILTER).chains: # IPv6
iptc.Table6(iptc.Table6.FILTER).create_chain("MAILCOW") if not iptc.Chain(iptc.Table6(iptc.Table6.FILTER), "MAILCOW") in iptc.Table6(iptc.Table6.FILTER).chains:
for c in ['FORWARD', 'INPUT']: iptc.Table6(iptc.Table6.FILTER).create_chain("MAILCOW")
chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), c) for c in ['FORWARD', 'INPUT']:
rule = iptc.Rule6() chain = iptc.Chain(iptc.Table6(iptc.Table6.FILTER), c)
rule.src = '::/0' rule = iptc.Rule6()
rule.dst = '::/0' rule.src = '::/0'
target = iptc.Target(rule, "MAILCOW") rule.dst = '::/0'
rule.target = target target = iptc.Target(rule, "MAILCOW")
if rule not in chain.rules: rule.target = target
chain.insert_rule(rule) if rule not in chain.rules:
chain.insert_rule(rule)
else:
for family in ["ip", "ip6"]:
insert_mailcow_chains(family)
if __name__ == '__main__': if __name__ == '__main__':
if backend == 'nftables':
search_current_chains()
# In case a previous session was killed without cleanup # In case a previous session was killed without cleanup
clear() clear()
# Reinit MAILCOW chain # Reinit MAILCOW chain