diff --git a/awall/init.lua b/awall/init.lua index 09c35c166029f1a1f465999b76b021106e68d07b..1bd3b8dc060bd164faa5e5003ed2979c52976a77 100644 --- a/awall/init.lua +++ b/awall/init.lua @@ -36,9 +36,9 @@ end local function readconfig() - config = {} - awall.model.reset() - awall.iptables.reset() + local config = {} + local iptables = awall.iptables.new() + local context = {input=config, iptables=iptables} for i, dir in ipairs(confdirs) do local fnames = {} @@ -93,7 +93,7 @@ local function readconfig() function insertrule(trule) - local t = awall.iptables.config[trule.family][trule.table][trule.chain] + local t = iptables.config[trule.family][trule.table][trule.chain] if trule.position == 'prepend' then table.insert(t, 1, trule.opts) else @@ -106,7 +106,8 @@ local function readconfig() for i, mod in ipairs(modules) do for path, cls in pairs(mod.classmap) do if config[path] then - awall.util.map(config[path], cls.morph) + awall.util.map(config[path], + function(obj) return cls.morph(obj, context) end) table.insert(locations, config[path]) end end @@ -120,16 +121,20 @@ local function readconfig() for i, trule in ipairs(rule:trules()) do insertrule(trule) end end end + + context.ipset = awall.ipset.new(config.ipset) + + return context end function dump() - readconfig() - awall.ipset.dump(ipsfile) - awall.iptables.dump(iptdir) + local context = readconfig() + context.ipset:dump(ipsfile) + context.iptables:dump(iptdir) end function test() - readconfig() - awall.ipset.create() - awall.iptables.test() + local context = readconfig() + context.ipset:create() + context.iptables:test() end diff --git a/awall/ipset.lua b/awall/ipset.lua index baff4042e313248815b690ea045686af88c9fa1a..26253c87c2db98f839007ab3916dacd7803bce94 100644 --- a/awall/ipset.lua +++ b/awall/ipset.lua @@ -7,11 +7,18 @@ Licensed under the terms of GPL2 module(..., package.seeall) -local function commands() - local config = awall.config +local IPSet = {} + +function new(config) + local res = {config=config} + setmetatable(res, {__index=IPSet}) + return res +end + +function IPSet:commands() local res = {} - if config.ipset then - for name, params in pairs(config.ipset) do + if self.config then + for name, params in pairs(self.config) do if not params.type then error('Type not defined for set '..name) end local line = 'create '..name..' '..params.type if params.family then line = line..' family '..params.family end @@ -21,8 +28,8 @@ local function commands() return res end -function create() - for i, line in ipairs(commands()) do +function IPSet:create() + for i, line in ipairs(self:commands()) do local pid, stdin = lpc.run('ipset', '-!', 'restore') stdin:write(line) stdin:close() @@ -32,8 +39,8 @@ function create() end end -function dump(ipsfile) +function IPSet:dump(ipsfile) local file = io.output(ipsfile) - for i, line in ipairs(commands()) do file:write(line) end + for i, line in ipairs(self:commands()) do file:write(line) end file:close() end diff --git a/awall/iptables.lua b/awall/iptables.lua index b8b6b13b3fdf74a66a9e050285947c5bf10aa800..67ed7a160fbaefb9f77c79f678515c6b6ba5dbc2 100644 --- a/awall/iptables.lua +++ b/awall/iptables.lua @@ -18,20 +18,25 @@ local families = {inet={cmd='iptables-restore', file='rules-save'}, local builtin = {'INPUT', 'FORWARD', 'OUTPUT', 'PREROUTING', 'POSTROUTING'} -function reset() - config = {} +local IPTables = {} + +function new() + local config = {} setmetatable(config, {__index=function(t, k) t[k] = {} setmetatable(t[k], getmetatable(t)) return t[k] end}) + + local res = {config=config} + setmetatable(res, {__index=IPTables}) + return res end -reset() -local function dumpfile(family, iptfile) +function IPTables:dumpfile(family, iptfile) iptfile:write('# '..families[family].file..' generated by awall\n') - for tbl, chains in pairs(config[family]) do + for tbl, chains in pairs(self.config[family]) do iptfile:write('*'..tbl..'\n') for chain, rules in pairs(chains) do iptfile:write(':'..chain..' '..(contains(builtin, chain) and @@ -46,17 +51,17 @@ local function dumpfile(family, iptfile) end end -function test() - for family, tbls in pairs(config) do +function IPTables:test() + for family, tbls in pairs(self.config) do local pid, stdin = lpc.run(families[family].cmd, '-t') - dumpfile(family, stdin) + self:dumpfile(family, stdin) stdin:close() assert(lpc.wait(pid) == 0) end end -function dump(dir) - for family, tbls in pairs(config) do - dumpfile(family, io.output(dir..'/'..families[family].file)) +function IPTables:dump(dir) + for family, tbls in pairs(self.config) do + self:dumpfile(family, io.output(dir..'/'..families[family].file)) end end diff --git a/awall/model.lua b/awall/model.lua index 403b617f239ecba51d2171a34511099d72860aca..28e34dba7551aea223215761f3e12534d37336a4 100644 --- a/awall/model.lua +++ b/awall/model.lua @@ -28,8 +28,14 @@ function class(base) return inst end - function cls:morph() + function cls:morph(context) setmetatable(self, mt) + + if context then + self.context = context + self.root = context.input + end + self:init() end @@ -78,13 +84,11 @@ Rule = class(Object) function Rule:init() - local config = awall.config - for i, prop in ipairs({'in', 'out'}) do self[prop] = self[prop] and util.maplist(self[prop], function(z) return z == '_fw' and fwzone or - config.zone[z] or + self.root.zone[z] or error('Invalid zone: '..z) end) or self:defaultzones() end @@ -93,7 +97,7 @@ function Rule:init() if type(self.service) == 'string' then self.label = self.service end self.service = util.maplist(self.service, function(s) - return config.service[s] or error('Invalid service: '..s) + return self.root.service[s] or error('Invalid service: '..s) end) end end @@ -280,7 +284,7 @@ function Rule:trules() for i, ipset in util.listpairs(self.ipset) do if not ipset.name then error('Set name not defined') end - local setdef = awall.config.ipset and awall.config.ipset[ipset.name] + local setdef = self.root.ipset and self.root.ipset[ipset.name] if not setdef then error('Invalid set name') end if not ipset.args then @@ -352,18 +356,16 @@ end function Rule:extraoptfrags() return {} end -local lastid = {} function Rule:newchain(base) + if not self.context.lastid then self.context.lastid = {} end + local lastid = self.context.lastid + if self.label then base = base..'-'..self.label end if not lastid[base] then lastid[base] = -1 end lastid[base] = lastid[base] + 1 return base..'-'..lastid[base] end -function reset() - lastid = {} -end - classmap = {zone=Zone}