Commit 61baa344 authored by Kaarle Ritvanen's avatar Kaarle Ritvanen

Limit: set address mask length

parent a211a70e
......@@ -22,6 +22,7 @@ local util = require('awall.util')
local contains = util.contains
local extend = util.extend
local filter = util.filter
local join = util.join
local listpairs = util.listpairs
local maplist = util.maplist
local setdefault = util.setdefault
......@@ -33,15 +34,6 @@ local startswith = require('stringy').startswith
local RECENT_MAX_COUNT = 20
local function join(a, b)
local comps = {}
local function add(s) if s and s > '' then table.insert(comps, s) end end
add(a)
add(b)
if comps[1] then return table.concat(comps, '-') end
end
M.ConfigObject = M.class()
function M.ConfigObject:init(context, location)
......@@ -72,7 +64,7 @@ function M.ConfigObject:create(cls, params, label, index)
end
if type(params) ~= 'table' then params = {params} end
params.label = join(self.label, label)
params.label = join(self.label, '-', label)
local obj = cls.morph(params, self.context, self.location)
if key then self.extraobjs[key] = obj end
......@@ -84,7 +76,7 @@ function M.ConfigObject:uniqueid(key)
if self.uniqueids[key] then return self.uniqueids[key] end
local lastid = setdefault(self.context, 'lastid', {})
local res = join(key, self.label)
local res = join(key, '-', self.label)
lastid[res] = setdefault(lastid, res, -1) + 1
res = res..'-'..lastid[res]
......@@ -587,6 +579,20 @@ function M.Limit:init(...)
end
setdefault(self, 'interval', 1)
if type(setdefault(self, 'mask', {})) == 'number' then
self.mask = {src=self.mask}
end
for family, len in pairs{inet=32, inet6=128} do
setdefault(self.mask, family, util.copy(self.mask))
for attr, default in pairs{src=len, dest=0} do
local mask = setdefault(self.mask[family], attr, default)
if mask > 0 then
self.mask[family].mode =
self.mask[family].mode and true or {attr, mask}
end
end
end
end
function M.Limit:rate() return math.ceil(self.count / self.interval) end
......@@ -602,22 +608,84 @@ function M.Limit:recentofrags(name)
if count > RECENT_MAX_COUNT then return end
local rec = '-m recent --name '..name
return {
{opts=rec..' --update --hitcount '..count..' --seconds '..interval}
}, {{opts=rec..' --set'}}
local uofs = {}
local sofs = {}
for _, family in ipairs{'inet', 'inet6'} do
if type(self.mask[family].mode) ~= 'table' then return end
local mask = ''
local attr, len = unpack(self.mask[family].mode)
if family == 'inet' then
local octet
for i = 0, 3 do
if len <= i * 8 then octet = 0
elseif len > i * 8 + 7 then octet = 255
else octet = 256 - 2^(8 - len % 8) end
mask = util.join(mask, '.', octet)
end
elseif family == 'inet6' then
while len > 0 do
if #mask % 5 == 4 then mask = mask..':' end
mask = mask..('%x'):format(16 - 2^math.max(0, 4 - len))
len = len - 4
end
while #mask % 5 < 4 do mask = mask..'0' end
if #mask < 39 then mask = mask..'::' end
end
local rec = {
{
family=family,
opts='-m recent --name '..name..' --r'..
({src='source', dest='dest'})[attr]..' --mask '..mask
}
}
extend(
uofs,
combinations(
rec,
{{opts='--update --hitcount '..count..' --seconds '..interval}}
)
)
extend(sofs, combinations(rec, {{opts='--set'}}))
end
return uofs, sofs
end
function M.Limit:limitofrags(name)
local rate = self:rate()
return {
{
opts='-m hashlimit --hashlimit-upto '..rate..
'/second --hashlimit-burst '..rate..
' --hashlimit-mode srcip --hashlimit-name '..
(name or self:uniqueid())
}
}
local ofrags = {}
for _, family in ipairs{'inet', 'inet6'} do
local keys = {}
local maskopts = ''
for attr, opt in pairs{src='src', dest='dst'} do
local mask = self.mask[family][attr]
if mask > 0 then
table.insert(keys, opt..'ip')
maskopts = maskopts..' --hashlimit-'..opt..'mask '..mask
end
end
table.insert(
ofrags,
{
family=family,
opts=keys[1] and
'-m hashlimit --hashlimit-upto '..rate..
'/second --hashlimit-burst '..rate..' --hashlimit-mode '..
table.concat(keys, ',')..maskopts..' --hashlimit-name '..
(name or self:uniqueid()) or
'-m limit --limit '..rate..'/second'
}
)
end
return ofrags
end
......
......@@ -102,6 +102,19 @@ function M.compare(a, b)
return true
end
function M.join(a, sep, b)
local comps = {}
local function add(s)
if not s then return end
s = tostring(s)
if s > '' then table.insert(comps, s) end
end
add(a)
add(b)
if comps[1] then return table.concat(comps, sep) end
end
function M.printtabulars(tables)
local colwidth = {}
for i, tbl in ipairs(tables) do
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment