675 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Lua
		
	
	
	
	
	
			
		
		
	
	
			675 lines
		
	
	
		
			20 KiB
		
	
	
	
		
			Lua
		
	
	
	
	
	
--[[
Copyright (c) 2011-2017, Vsevolod Stakhov <vsevolod@highsecure.ru>
Copyright (c) 2016-2017, Andrew Lewis <nerf@judo.za.org>

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
]]--
 | 
						|
 | 
						|
-- When the global `confighelp` is truthy the rest of this file is skipped:
-- nothing below gets defined or registered.
if confighelp then return end
 | 
						|
 | 
						|
-- A plugin that implements ratelimits using redis

-- Shared empty table, used as a nil-safe fallback when indexing optional values
local E = {}
-- Module name: used for the configuration lookup and as the debug log module
local N = 'ratelimit'
-- Redis connection parameters, assigned while the configuration is processed
local redis_params

-- Built-in defaults; they are merged with the module options when present
local settings = {
  -- Senders that are considered as bounce
  bounce_senders = {
    'postmaster',
    'mailer-daemon',
    '',
    'null',
    'fetchmail-daemon',
    'mdaemon',
  },
  -- Do not check ratelimits for these recipients
  whitelisted_rcpts = {
    'postmaster',
    'mailer-daemon',
  },
  -- Prepended to every bucket hash to form the redis key
  prefix = 'RL',
  -- Factors passed to the update script for non-spam / spam messages
  ham_factor_rate = 1.01,
  spam_factor_rate = 0.99,
  ham_factor_burst = 1.02,
  spam_factor_burst = 0.98,
  -- Bounds passed to the update script for the dynamic multipliers
  max_rate_mult = 5,
  max_bucket_mult = 10,
  -- Value handed to redis EXPIRE for each bucket
  expire = 60 * 60 * 24 * 2, -- 2 days by default
  -- Parsed limits, keyed by limit name
  limits = {},
  -- When false, rspamc/controller tasks are not ratelimited
  allow_local = false,
}
 | 
						|
 | 
						|
-- Checks bucket, updating it if needed
-- KEYS[1] - prefix to update, e.g. RL_<triplet>_<seconds>
-- KEYS[2] - current time in milliseconds
-- KEYS[3] - bucket leak rate (messages per millisecond)
-- KEYS[4] - bucket burst
-- KEYS[5] - expire for a bucket
-- returns a table {limited, burst, dyn_rate, dyn_burst} where `limited` is 1
-- if message should be ratelimited and 0 if not
-- Redis keys used:
--   l - last hit
--   b - current burst
--   dr - current dynamic rate multiplier (*10000)
--   db - current dynamic burst multiplier (*10000)
local bucket_check_script = [[
  local last = redis.call('HGET', KEYS[1], 'l')
  local now = tonumber(KEYS[2])
  if not last then
    -- New bucket
    redis.call('HSET', KEYS[1], 'l', KEYS[2])
    redis.call('HSET', KEYS[1], 'b', '0')
    redis.call('HSET', KEYS[1], 'dr', '10000')
    redis.call('HSET', KEYS[1], 'db', '10000')
    redis.call('EXPIRE', KEYS[1], KEYS[5])
    return {0, 0, 1, 1}
  end

  last = tonumber(last)
  local burst = tonumber(redis.call('HGET', KEYS[1], 'b'))
  -- Read both multipliers up front so that the reply always carries the
  -- stored values (not a placeholder 0 when no leak was performed)
  local dynr = tonumber(redis.call('HGET', KEYS[1], 'dr')) / 10000.0
  local dynb = tonumber(redis.call('HGET', KEYS[1], 'db')) / 10000.0

  -- Perform leak
  if burst > 0 then
    if last < now then
      local rate = tonumber(KEYS[3])
      rate = rate * dynr
      local leaked = ((now - last) * rate)
      -- Never drain more than the bucket holds
      if leaked > burst then leaked = burst end
      burst = burst - leaked
      redis.call('HINCRBYFLOAT', KEYS[1], 'b', -(leaked))
      -- The interval up to `now` has been accounted for: remember it, so a
      -- subsequent check without an update does not leak the same time again
      redis.call('HSET', KEYS[1], 'l', KEYS[2])
    end
  else
    burst = 0
    redis.call('HSET', KEYS[1], 'b', '0')
  end

  -- The dynamic multiplier scales the bucket capacity: larger values (ham)
  -- allow more messages, smaller values (spam) allow fewer
  if (burst + 1) > tonumber(KEYS[4]) * dynb then
    return {1, tostring(burst), tostring(dynr), tostring(dynb)}
  end

  return {0, tostring(burst), tostring(dynr), tostring(dynb)}
]]
local bucket_check_id
 | 
						|
 | 
						|
 | 
						|
-- Updates a bucket
-- KEYS[1] - prefix to update, e.g. RL_<triplet>_<seconds>
-- KEYS[2] - current time in milliseconds
-- KEYS[3] - dynamic rate multiplier
-- KEYS[4] - dynamic burst multiplier
-- KEYS[5] - max dyn rate (min: 1/x)
-- KEYS[6] - max burst rate (min: 1/x)
-- KEYS[7] - expire for a bucket
-- returns a table {burst before the increment, dyn_rate, dyn_burst}
-- Redis keys used:
--   l - last hit
--   b - current burst
--   dr - current dynamic rate multiplier
--   db - current dynamic burst multiplier
local bucket_update_script = [[
  local last = redis.call('HGET', KEYS[1], 'l')
  if not last then
    -- New bucket
    redis.call('HSET', KEYS[1], 'l', KEYS[2])
    redis.call('HSET', KEYS[1], 'b', '1')
    redis.call('HSET', KEYS[1], 'dr', '10000')
    redis.call('HSET', KEYS[1], 'db', '10000')
    redis.call('EXPIRE', KEYS[1], KEYS[7])
    return {1, 1, 1}
  end

  -- Scales a multiplier by `factor` and clamps it to [1/limit, limit].
  -- Clamping (instead of refusing to touch a value that has reached a bound)
  -- lets a saturated multiplier move back once the opposite factor is applied.
  -- A limit that is not above 1 leaves the multiplier untouched.
  local function scale(cur, factor, limit)
    if limit <= 1 then return cur end
    local upd = cur * factor
    if upd > limit then
      upd = limit
    elseif upd < 1.0 / limit then
      upd = 1.0 / limit
    end
    return upd
  end

  local burst = tonumber(redis.call('HGET', KEYS[1], 'b'))
  local db = tonumber(redis.call('HGET', KEYS[1], 'db')) / 10000
  local dr = tonumber(redis.call('HGET', KEYS[1], 'dr')) / 10000

  local new_dr = scale(dr, tonumber(KEYS[3]), tonumber(KEYS[5]))
  if new_dr ~= dr then
    dr = new_dr
    redis.call('HSET', KEYS[1], 'dr', tostring(math.floor(dr * 10000)))
  end

  local new_db = scale(db, tonumber(KEYS[4]), tonumber(KEYS[6]))
  if new_db ~= db then
    db = new_db
    redis.call('HSET', KEYS[1], 'db', tostring(math.floor(db * 10000)))
  end

  redis.call('HINCRBYFLOAT', KEYS[1], 'b', 1)
  redis.call('HSET', KEYS[1], 'l', KEYS[2])
  redis.call('EXPIRE', KEYS[1], KEYS[7])

  return {tostring(burst), tostring(dr), tostring(db)}
]]
local bucket_update_id
 | 
						|
 | 
						|
-- message_func(task, limit_type, prefix, bucket)
-- Default builder of the text passed along with the soft reject action.
-- Only the limit name is used; the other arguments are accepted and ignored.
local message_func = function(_, limit_type, _, _)
  local template = 'Ratelimit "%s" exceeded'
  return template:format(limit_type)
end
 | 
						|
 | 
						|
local rspamd_logger = require "rspamd_logger"
 | 
						|
local rspamd_util = require "rspamd_util"
 | 
						|
local rspamd_lua_utils = require "lua_util"
 | 
						|
local lua_redis = require "lua_redis"
 | 
						|
local fun = require "fun"
 | 
						|
local lua_maps = require "lua_maps"
 | 
						|
local lua_util = require "lua_util"
 | 
						|
local rspamd_hash = require "rspamd_cryptobox_hash"
 | 
						|
 | 
						|
 | 
						|
--- Registers both bucket scripts via lua_redis and stores the returned ids
-- in `bucket_check_id` / `bucket_update_id`.
-- @param cfg unused
-- @param ev_base unused
local function load_scripts(cfg, ev_base)
  local register = lua_redis.add_redis_script

  bucket_check_id = register(bucket_check_script, redis_params)
  bucket_update_id = register(bucket_update_script, redis_params)
end
 | 
						|
 | 
						|
-- LPeg grammar for textual limits; built on the first parse and then reused
local limit_parser

--- Parses a textual limit of the form `<number>[k|m|g] / <number>[s|m|h|d]`,
-- e.g. "100 / 1m".
-- @param lim string to parse
-- @param no_error when truthy, a parse failure is not logged
-- @return the time part (suffix applied: s=1, m=60, h=3600, d=86400) followed
-- by the count part (suffix applied: k=1e3, m=1e6, g=1e9); a single nil when
-- the string does not match or the time part is zero
local function parse_string_limit(lim, no_error)
  local lpeg = require "lpeg"

  if not limit_parser then
    -- Suffix -> multiplier lookups, used as LPeg query captures
    local time_units = { s = 1, m = 60, h = 3600, d = 86400 }
    local count_units = { k = 1000, m = 1000000, g = 1000000000 }
    local function multiply(acc, val) return acc * val end

    local sign = lpeg.S("+-")
    local digits = lpeg.R("09") ^ 1
    local blanks = lpeg.S(" ") ^ 0

    limit_parser = {}
    local grammar = limit_parser

    grammar.integer = (sign ^ -1) * digits
    grammar.fractional = lpeg.P(".") * digits
    -- Either an integer with an optional fraction, or a signed bare fraction
    grammar.number = (grammar.integer * (grammar.fractional ^ -1)) +
        (sign * grammar.fractional)
    -- Fold `1 * number * optional unit multiplier` into a single value
    grammar.time = lpeg.Cf(
        lpeg.Cc(1) *
        (grammar.number / tonumber) *
        ((lpeg.S("smhd") / time_units) ^ -1),
        multiply)
    grammar.suffixed_number = lpeg.Cf(
        lpeg.Cc(1) *
        (grammar.number / tonumber) *
        ((lpeg.S("kmg") / count_units) ^ -1),
        multiply)
    grammar.limit = lpeg.Ct(
        grammar.suffixed_number *
        blanks * lpeg.S("/") * blanks *
        grammar.time)
  end

  local parsed = lpeg.match(limit_parser.limit, lim)
  local valid = parsed and parsed[1] and parsed[2] and parsed[2] ~= 0

  if valid then
    return parsed[2], parsed[1]
  end

  if not no_error then
    rspamd_logger.errx(rspamd_config, 'bad limit: %s', lim)
  end

  return nil
end
 | 
						|
 | 
						|
--- Converts one configured limit into a list of buckets.
-- Accepted forms of `data`:
--  * old limit in format [burst, rate] (two numeric values)
--  * a string understood by parse_string_limit, e.g. "100 / 1m"
--  * a list of the above, which is parsed recursively and flattened
-- @param name limit name, used in log messages
-- @param data limit definition
-- @return array of `{burst = <number>, rate = <number>}` tables; empty when
-- nothing valid could be parsed
local function parse_limit(name, data)
  local buckets = {}

  if type(data) == 'table' then
    -- Convert once, so that numeric strings end up stored as numbers and are
    -- not dropped by the type filter at the end
    local burst, rate = tonumber(data[1]), tonumber(data[2])

    if #data == 2 and burst and rate then
      -- Old style ratelimit
      rspamd_logger.warnx(rspamd_config, 'old style ratelimit for %s', name)
      if burst > 0 and rate > 0 then
        table.insert(buckets, {
          burst = burst,
          rate = rate
        })
      elseif burst ~= 0 then
        rspamd_logger.warnx(rspamd_config, 'invalid numbers for %s', name)
      else
        rspamd_logger.infox(rspamd_config, 'disable limit %s, burst is zero', name)
      end
    else
      -- Recursively map parse_limit and flatten the list; the limit name
      -- stays the first argument for every nested element
      fun.each(function(l)
        -- Flatten list
        for _,b in ipairs(l) do table.insert(buckets, b) end
      end, fun.map(function(d) return parse_limit(name, d) end, data))
    end
  elseif type(data) == 'string' then
    local rep_rate, burst = parse_string_limit(data)

    if rep_rate and burst then
      table.insert(buckets, {
        burst = burst,
        -- "<burst> / <period>" leaks `burst` messages per `period` seconds
        rate = burst / rep_rate
      })
    end
  end

  -- Filter valid
  return fun.totable(fun.filter(function(val)
    return type(val.burst) == 'number' and type(val.rate) == 'number'
  end, buckets))
end
 | 
						|
 | 
						|
--- Tells whether `from` equals one of the entries of settings.bounce_senders
-- @param from value to look for
-- @return result of fun.any over the configured bounce senders
local function check_bounce(from)
  local function is_same(sender)
    return sender == from
  end

  return fun.any(is_same, settings.bounce_senders)
end
 | 
						|
 | 
						|
-- Extractors for the parts of a limit name (see gen_rate_key). Each entry has
-- a `get_value(task)` function returning the value that goes into the redis
-- key, or nil when the keyword does not apply to the task.
local keywords = {
  -- Textual form of task:get_ip(), when it is valid
  ip = {
    get_value = function(task)
      local addr = task:get_ip()
      if not (addr and addr:is_valid()) then return nil end
      return tostring(addr)
    end,
  },
  -- Same as `ip`, but only for addresses that are not local
  rip = {
    get_value = function(task)
      local addr = task:get_ip()
      if not (addr and addr:is_valid()) then return nil end
      if addr:is_local() then return nil end
      return tostring(addr)
    end,
  },
  -- Lowercased address of the first element of task:get_from(0)
  from = {
    get_value = function(task)
      local sender = (task:get_from(0) or E)[1]
      if sender and sender.addr then
        return string.lower(sender['addr'])
      end
      return nil
    end,
  },
  -- '_' when there is no sender user part or it is a known bounce sender
  bounce = {
    get_value = function(task)
      local sender = (task:get_from(0) or E)[1]
      if not (sender and sender.user) then
        return '_'
      end
      if check_bounce(sender['user']) then
        return '_'
      end
      return nil
    end,
  },
  -- Value of the 'asn' mempool variable
  asn = {
    get_value = function(task)
      local asn = task:get_mempool():get_variable('asn')
      return asn or nil
    end,
  },
  -- Value of task:get_user()
  user = {
    get_value = function(task)
      local auser = task:get_user()
      return auser or nil
    end,
  },
  -- Whatever task:get_principal_recipient() returns
  to = {
    get_value = function(task)
      return task:get_principal_recipient()
    end,
  },
}
 | 
						|
 | 
						|
--- Builds the key of a bucket for a task.
-- The limit name is split on '_' and every part is resolved through
-- `keywords`; the key is `round(100000 / burst)` followed by the resolved
-- values, joined with ':'.
-- @param task task passed to the keyword extractors
-- @param rtype limit name, e.g. 'to_ip'
-- @param bucket bucket table (only `burst` is read)
-- @return key string, or nil when a part is unknown or yields no value
local function gen_rate_key(task, rtype, bucket)
  local parts = {tostring(lua_util.round(100000.0 / bucket.burst))}
  local needs_user = false

  for _, kw in ipairs(lua_util.str_split(rtype, '_')) do
    local entry = keywords[kw]
    local value

    if entry and type(entry['get_value']) == 'function' then
      value = entry['get_value'](task)
    end

    -- Unknown keyword or nothing extracted: the limit is not applicable
    if not value then return nil end

    if kw == 'user' then
      needs_user = true
    end

    if type(value) ~= 'string' then
      value = tostring(value)
    end
    parts[#parts + 1] = value
  end

  if needs_user and not task:get_user() then
    return nil
  end

  return table.concat(parts, ":")
end
 | 
						|
 | 
						|
--- Creates the descriptor stored for each bucket that has to be checked.
-- The redis hash name is settings.prefix followed by at most 24 characters
-- (never more than the key length) of the base32 digest of `redis_key`.
-- Dynamic factors missing from `bucket` are filled in from the settings;
-- note that this modifies `bucket` in place.
-- @param redis_key key string
-- @param name limit name
-- @param bucket bucket table
-- @return table with `bucket`, `name` and `hash` fields
local function make_prefix(redis_key, name, bucket)
  local hash_len = math.min(24, #redis_key)
  local digest = rspamd_hash.create(redis_key):base32()
  local hash = settings.prefix .. string.sub(digest, 1, hash_len)

  -- Fill defaults
  local factors = {
    'spam_factor_rate',
    'ham_factor_rate',
    'spam_factor_burst',
    'ham_factor_burst',
  }
  for _, factor in ipairs(factors) do
    if not bucket[factor] then
      bucket[factor] = settings[factor]
    end
  end

  return {
    bucket = bucket,
    name = name,
    hash = hash
  }
end
 | 
						|
 | 
						|
--- Adds a descriptor to `prefixes` for every bucket of a limit that applies
-- to the task.
-- @param task task being processed
-- @param k limit name
-- @param v array of buckets of that limit
-- @param prefixes output table: key string -> descriptor from make_prefix
-- @return number of descriptors stored
local function limit_to_prefixes(task, k, v, prefixes)
  local added = 0

  for _, bucket in ipairs(v) do
    local redis_key = gen_rate_key(task, k, bucket)

    if redis_key then
      prefixes[redis_key] = make_prefix(redis_key, k, bucket)
      added = added + 1
    end
  end

  return added
end
 | 
						|
 | 
						|
--- Prefilter callback: checks all applicable buckets for a task.
-- Skips whitelisted IPs, recipients and users, collects the buckets from the
-- configured limits and the custom keyword handlers, stores them in the task
-- cache under 'ratelimit_prefixes' and runs the check script for each one.
-- When a bucket is exceeded, either only `settings.symbol` is inserted, or
-- the task gets a 'soft reject' pre-result (plus `settings.info_symbol`).
-- @param task task being processed
local function ratelimit_cb(task)
  if not settings.allow_local and
          rspamd_lua_utils.is_rspamc_or_controller(task) then return end

  -- Get initial task data
  local ip = task:get_from_ip()
  if ip and ip:is_valid() and settings.whitelisted_ip then
    if settings.whitelisted_ip:get_key(ip) then
      -- Do not check whitelisted ip
      rspamd_logger.infox(task, 'skip ratelimit for whitelisted IP')
      return
    end
  end
  -- Parse all rcpts
  local rcpts = task:get_recipients()
  local rcpts_user = {}
  if rcpts then
    -- Both the user part and the full address of every recipient are checked
    fun.each(function(r)
      fun.each(function(field) table.insert(rcpts_user, r[field]) end, {'user', 'addr'})
    end, rcpts)

    if fun.any(function(r) return settings.whitelisted_rcpts:get_key(r) end, rcpts_user) then
      rspamd_logger.infox(task, 'skip ratelimit for whitelisted recipient')
      return
    end
  end
  -- Get user (authuser)
  if settings.whitelisted_user then
    local auser = task:get_user()
    -- There is nothing to look up for unauthenticated tasks
    if auser and settings.whitelisted_user:get_key(auser) then
      rspamd_logger.infox(task, 'skip ratelimit for whitelisted user')
      return
    end
  end
  -- Now create all ratelimit prefixes
  local prefixes = {}
  local nprefixes = 0

  for k,v in pairs(settings.limits) do
    nprefixes = nprefixes + limit_to_prefixes(task, k, v, prefixes)
  end

  for k, hdl in pairs(settings.custom_keywords or E) do
    local ret, redis_key, bd = pcall(hdl, task)

    if ret then
      if redis_key and bd then
        local bucket = parse_limit(k, bd)
        if bucket[1] then
          prefixes[redis_key] = make_prefix(redis_key, k, bucket[1])
          -- Count only the prefixes that have really been stored
          nprefixes = nprefixes + 1
        end
      else
        -- A nil key cannot index `prefixes`, and without a limit definition
        -- there is no bucket to check
        rspamd_logger.debugm(N, task, 'custom keyword %s returned no key or limit', k)
      end
    else
      rspamd_logger.errx(task, 'cannot call handler for %s: %s',
          k, redis_key)
    end
  end

  -- Creates the redis callback for one bucket check
  local function gen_check_cb(prefix, bucket, lim_name)
    return function(err, data)
      if err then
        rspamd_logger.errx(task, 'cannot check limit %s: %s %s', prefix, err, data)
      elseif type(data) == 'table' and data[1] and data[1] == 1 then
        -- set symbol only and do NOT soft reject
        if settings.symbol then
          task:insert_result(settings.symbol, 0.0, lim_name .. "(" .. prefix .. ")")
          rspamd_logger.infox(task,
              'set_symbol_only: ratelimit "%s(%s)" exceeded, (%s / %s): %s (%s:%s dyn)',
              lim_name, prefix,
              bucket.burst, bucket.rate,
              data[2], data[3], data[4])
          return
        -- set INFO symbol and soft reject
        elseif settings.info_symbol then
          task:insert_result(settings.info_symbol, 1.0,
              lim_name .. "(" .. prefix .. ")")
        end
        rspamd_logger.infox(task,
            'ratelimit "%s(%s)" exceeded, (%s / %s): %s (%s:%s dyn)',
            lim_name, prefix,
            bucket.burst, bucket.rate,
            data[2], data[3], data[4])
        task:set_pre_result('soft reject',
                message_func(task, lim_name, prefix, bucket))
      end
    end
  end

  -- Don't do anything if pre-result has been already set
  if task:has_pre_result() then return end

  if nprefixes > 0 then
    -- Save prefixes to the cache to allow update
    task:cache_set('ratelimit_prefixes', prefixes)
    local now = rspamd_util.get_time()
    now = lua_util.round(now * 1000.0) -- Get milliseconds
    -- Now call check script for all defined prefixes

    for pr,value in pairs(prefixes) do
      local bucket = value.bucket
      local rate = (bucket.rate) / 1000.0 -- Leak rate in messages/ms
      rspamd_logger.debugm(N, task, "check limit %s:%s -> %s (%s/%s)",
          value.name, pr, value.hash, bucket.burst, bucket.rate)
      lua_redis.exec_redis_script(bucket_check_id,
              {key = value.hash, task = task, is_write = true},
              gen_check_cb(pr, bucket, value.name),
              {value.hash, tostring(now), tostring(rate), tostring(bucket.burst),
                  tostring(settings.expire)})
    end
  end
end
 | 
						|
 | 
						|
--- Idempotent callback: updates every bucket that was checked for the task.
-- Reads the descriptors stored under 'ratelimit_prefixes' in the task cache
-- and runs the update script for each of them. Nothing is updated when a
-- pre-result has been set. The ham factors of a bucket are applied when the
-- metric action is 'no action', the spam factors otherwise.
-- @param task task being processed
local function ratelimit_update_cb(task)
  local prefixes = task:cache_get('ratelimit_prefixes')

  if prefixes then
    if task:has_pre_result() then
      -- Already rate limited/greylisted, do nothing
      rspamd_logger.debugm(N, task, 'pre-action has been set, do not update')
      return
    end

    local is_spam = not (task:get_metric_action() == 'no action')

    -- All buckets of a task are updated with the same timestamp
    local now = rspamd_util.get_time()
    now = lua_util.round(now * 1000.0) -- Get milliseconds

    -- Update each bucket
    for k, v in pairs(prefixes) do
      local bucket = v.bucket
      local function update_bucket_cb(err, data)
        if err then
          rspamd_logger.errx(task, 'cannot update rate bucket %s: %s',
                  k, err)
        else
          rspamd_logger.debugm(N, task,
              "updated limit %s:%s -> %s (%s/%s), burst: %s, dyn_rate: %s, dyn_burst: %s",
              v.name, k, v.hash,
              bucket.burst, bucket.rate,
              data[1], data[2], data[3])
        end
      end

      local mult_burst = bucket.ham_factor_burst or 1.0
      -- The rate multiplier has its own factor, distinct from the burst one
      local mult_rate = bucket.ham_factor_rate or 1.0

      if is_spam then
        mult_burst = bucket.spam_factor_burst or 1.0
        mult_rate = bucket.spam_factor_rate or 1.0
      end

      lua_redis.exec_redis_script(bucket_update_id,
              {key = v.hash, task = task, is_write = true},
              update_bucket_cb,
              {v.hash, tostring(now), tostring(mult_rate), tostring(mult_burst),
               tostring(settings.max_rate_mult), tostring(settings.max_bucket_mult),
               tostring(settings.expire)})
    end
  end
end
 | 
						|
 | 
						|
-- Module configuration: merges the options of the `ratelimit` section into
-- the defaults, parses the limits, sets up the whitelist maps, loads the
-- optional custom keywords and message function, and registers the check and
-- update symbols when redis is configured.
local opts = rspamd_config:get_all_opt(N)
if opts then

  settings = lua_util.override_defaults(settings, opts)

  if opts['limit'] then
    rspamd_logger.errx(rspamd_config, 'Legacy ratelimit config format no longer supported')
  end

  if opts['rates'] and type(opts['rates']) == 'table' then
    -- new way of setting limits
    fun.each(function(t, lim)
      local buckets = parse_limit(t, lim)

      if buckets and #buckets > 0 then
        settings.limits[t] = buckets
      end
    end, opts['rates'])
  end

  -- Names of the limits that ended up with at least one bucket
  local enabled_limits = fun.totable(fun.map(function(t)
    return t
  end, settings.limits))
  rspamd_logger.infox(rspamd_config,
          'enabled rate buckets: [%1]', table.concat(enabled_limits, ','))

  -- Ret, ret, ret: stupid legacy stuff:
  -- If we have a string with commas then load it as a static map
  -- otherwise, apply normal logic of Rspamd maps

  local wrcpts = opts['whitelisted_rcpts']
  if type(wrcpts) == 'string' then
    if string.find(wrcpts, ',') then
      settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(
        lua_util.rspamd_str_split(wrcpts, ','), 'set', 'Ratelimit whitelisted rcpts')
    else
      settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set',
        'Ratelimit whitelisted rcpts')
    end
  elseif type(opts['whitelisted_rcpts']) == 'table' then
    settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(wrcpts, 'set',
      'Ratelimit whitelisted rcpts')
  else
    -- Stupid default...
    settings.whitelisted_rcpts = lua_maps.rspamd_map_add_from_ucl(
        settings.whitelisted_rcpts, 'set', 'Ratelimit whitelisted rcpts')
  end

  if opts['whitelisted_ip'] then
    settings.whitelisted_ip = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_ip', 'radix',
      'Ratelimit whitelist ip map')
  end

  if opts['whitelisted_user'] then
    settings.whitelisted_user = lua_maps.rspamd_map_add('ratelimit', 'whitelisted_user', 'set',
      'Ratelimit whitelist user map')
  end

  settings.custom_keywords = {}
  if opts['custom_keywords'] then
    -- Load and run the file in two steps: loadfile returns nil plus a message
    -- on failure, and passing that nil to pcall would replace the real cause
    -- with "attempt to call a nil value"
    local chunk, load_err = loadfile(opts['custom_keywords'])

    if not chunk then
      rspamd_logger.errx(rspamd_config, 'cannot execute %s: %s',
          opts['custom_keywords'], load_err)
    else
      local ret, res_or_err = pcall(chunk)

      if ret then
        -- The file yields either a table of named handlers or one function
        if type(res_or_err) == 'table' then
          for k,hdl in pairs(res_or_err) do
            settings['custom_keywords'][k] = hdl
          end
        elseif type(res_or_err) == 'function' then
          settings['custom_keywords']['custom'] = res_or_err
        end
      else
        rspamd_logger.errx(rspamd_config, 'cannot execute %s: %s',
            opts['custom_keywords'], res_or_err)
        settings['custom_keywords'] = {}
      end
    end
  end

  if opts['message_func'] then
    -- `load` accepts source strings only in Lua 5.2+ and LuaJIT; plain
    -- Lua 5.1 needs `loadstring`
    local compile = loadstring or load
    message_func = assert(compile(opts['message_func']))()
  end

  redis_params = lua_redis.parse_redis_server('ratelimit')

  if not redis_params then
    rspamd_logger.infox(rspamd_config, 'no servers are specified, disabling module')
    lua_util.disable_module(N, "redis")
  else
    local s = {
      type = 'prefilter,nostat',
      name = 'RATELIMIT_CHECK',
      priority = 7,
      callback = ratelimit_cb,
      flags = 'empty',
    }

    -- The check is registered under the configured symbol name, if any
    if settings.symbol then
      s.name = settings.symbol
    elseif settings.info_symbol then
      s.name = settings.info_symbol
    end

    rspamd_config:register_symbol(s)
    rspamd_config:register_symbol {
      type = 'idempotent',
      name = 'RATELIMIT_UPDATE',
      callback = ratelimit_update_cb,
    }
  end
end
 | 
						|
 | 
						|
-- Register the redis scripts from the on-load hook; the third hook argument
-- is not needed here
rspamd_config:add_on_load(function(cfg, ev_base, _)
  load_scripts(cfg, ev_base)
end)
 |