diff --git a/spec/types_spec.lua b/spec/types_spec.lua index 9d87380f..0adca7c1 100644 --- a/spec/types_spec.lua +++ b/spec/types_spec.lua @@ -8,41 +8,41 @@ local types = require "pallene.types" describe("Pallene types", function() it("pretty-prints types", function() - assert.same("{ integer }", types.tostring(types.T.Array(types.T.Integer()))) + assert.same("{ integer }", types.tostring(types.T.Array(types.T.Integer))) assert.same("{ x: float, y: float }", types.tostring( - types.T.Table({x = types.T.Float(), y = types.T.Float()}))) + types.T.Table({x = types.T.Float, y = types.T.Float}))) end) it("is_gc works", function() - assert.falsy(types.is_gc(types.T.Integer())) - assert.truthy(types.is_gc(types.T.String())) - assert.truthy(types.is_gc(types.T.Array(types.T.Integer()))) - assert.truthy(types.is_gc(types.T.Table({x = types.T.Float()}))) + assert.falsy(types.is_gc(types.T.Integer)) + assert.truthy(types.is_gc(types.T.String)) + assert.truthy(types.is_gc(types.T.Array(types.T.Integer))) + assert.truthy(types.is_gc(types.T.Table({x = types.T.Float}))) assert.truthy(types.is_gc(types.T.Function({}, {}))) end) describe("equality", function() it("works for primitive types", function() - assert.truthy(types.equals(types.T.Integer(), types.T.Integer())) - assert.falsy(types.equals(types.T.Integer(), types.T.String())) + assert.truthy(types.equals(types.T.Integer, types.T.Integer)) + assert.falsy(types.equals(types.T.Integer, types.T.String)) end) it("is true for two identical tables", function() local t1 = types.T.Table({ - y = types.T.Integer(), x = types.T.Integer()}) + y = types.T.Integer, x = types.T.Integer}) local t2 = types.T.Table({ - x = types.T.Integer(), y = types.T.Integer()}) + x = types.T.Integer, y = types.T.Integer}) assert.truthy(types.equals(t1, t2)) assert.truthy(types.equals(t2, t1)) end) it("is false for tables with different number of fields", function() - local t1 = types.T.Table({x = types.T.Integer()}) - local t2 = types.T.Table({x = types.T.Integer(), - y = types.T.Integer()}) - local t3 = types.T.Table({x = types.T.Integer(), - y = types.T.Integer(), z = types.T.Integer()}) + local t1 = types.T.Table({x = types.T.Integer}) + local t2 = types.T.Table({x = types.T.Integer, + y = types.T.Integer}) + local t3 = types.T.Table({x = types.T.Integer, + y = types.T.Integer, z = types.T.Integer}) assert.falsy(types.equals(t1, t2)) assert.falsy(types.equals(t2, t1)) assert.falsy(types.equals(t2, t3)) @@ -52,39 +52,39 @@ describe("Pallene types", function() end) it("is false for tables with different field names", function() - local t1 = types.T.Table({x = types.T.Integer()}) - local t2 = types.T.Table({y = types.T.Integer()}) + local t1 = types.T.Table({x = types.T.Integer}) + local t2 = types.T.Table({y = types.T.Integer}) assert.falsy(types.equals(t1, t2)) assert.falsy(types.equals(t2, t1)) end) it("is false for tables with different field types", function() - local t1 = types.T.Table({x = types.T.Integer()}) - local t2 = types.T.Table({x = types.T.Float()}) + local t1 = types.T.Table({x = types.T.Integer}) + local t2 = types.T.Table({x = types.T.Float}) assert.falsy(types.equals(t1, t2)) assert.falsy(types.equals(t2, t1)) end) it("is true for identical functions", function() - local f1 = types.T.Function({types.T.String(), types.T.Integer()}, {types.T.Boolean()}) - local f2 = types.T.Function({types.T.String(), types.T.Integer()}, {types.T.Boolean()}) + local f1 = types.T.Function({types.T.String, types.T.Integer}, {types.T.Boolean}) + local f2 = types.T.Function({types.T.String, types.T.Integer}, {types.T.Boolean}) assert.truthy(types.equals(f1, f2)) end) it("is false for functions with different input types", function() - local f1 = types.T.Function({types.T.String(), types.T.Boolean()}, {types.T.Boolean()}) - local f2 = types.T.Function({types.T.Integer(), types.T.Integer()}, {types.T.Boolean()}) + local f1 = types.T.Function({types.T.String, types.T.Boolean}, {types.T.Boolean}) + local f2 = types.T.Function({types.T.Integer, types.T.Integer}, {types.T.Boolean}) assert.falsy(types.equals(f1, f2)) end) it("is false for functions with different output types", function() - local f1 = types.T.Function({types.T.String(), types.T.Integer()}, {types.T.Boolean()}) - local f2 = types.T.Function({types.T.String(), types.T.Integer()}, {types.T.Integer()}) + local f1 = types.T.Function({types.T.String, types.T.Integer}, {types.T.Boolean}) + local f2 = types.T.Function({types.T.String, types.T.Integer}, {types.T.Integer}) assert.falsy(types.equals(f1, f2)) end) it("is false for functions with different input arity", function() - local s = types.T.String() + local s = types.T.String local f1 = types.T.Function({}, {s}) local f2 = types.T.Function({s}, {s}) local f3 = types.T.Function({s, s}, {s}) @@ -97,7 +97,7 @@ describe("Pallene types", function() end) it("is false for functions with different output arity", function() - local s = types.T.String() + local s = types.T.String local f1 = types.T.Function({s}, {}) local f2 = types.T.Function({s}, {s}) local f3 = types.T.Function({s}, {s, s}) @@ -123,42 +123,42 @@ describe("Pallene types", function() describe("consistency", function() it("allows 'any' on either side", function() - assert.truthy(types.consistent(types.T.Any(), types.T.Any())) - assert.truthy(types.consistent(types.T.Any(), types.T.Integer())) - assert.truthy(types.consistent(types.T.Integer(), types.T.Any())) + assert.truthy(types.consistent(types.T.Any, types.T.Any)) + assert.truthy(types.consistent(types.T.Any, types.T.Integer)) + assert.truthy(types.consistent(types.T.Integer, types.T.Any)) end) it("allows types with same tag", function() assert.truthy(types.consistent( - types.T.Integer(), - types.T.Integer() + types.T.Integer, + types.T.Integer )) assert.truthy(types.consistent( - types.T.Array(types.T.Integer()), - types.T.Array(types.T.Integer()) + types.T.Array(types.T.Integer), + types.T.Array(types.T.Integer) )) assert.truthy(types.consistent( - types.T.Array(types.T.Integer()), - types.T.Array(types.T.String()) + types.T.Array(types.T.Integer), + types.T.Array(types.T.String) )) assert.truthy(types.consistent( - types.T.Function({types.T.Integer()}, {types.T.Integer()}), - types.T.Function({types.T.String(), types.T.String()}, {}) + types.T.Function({types.T.Integer}, {types.T.Integer}), + types.T.Function({types.T.String, types.T.String}, {}) )) end) it("forbids different tags", function() assert.falsy(types.consistent( - types.T.Integer(), - types.T.String() + types.T.Integer, + types.T.String )) assert.falsy(types.consistent( - types.T.Array(types.T.Integer()), - types.T.Function({types.T.Integer()},{types.T.Integer()}) + types.T.Array(types.T.Integer), + types.T.Function({types.T.Integer},{types.T.Integer}) )) end) end) diff --git a/src/pallene/builtins.lua b/src/pallene/builtins.lua index a433026b..ba44279b 100644 --- a/src/pallene/builtins.lua +++ b/src/pallene/builtins.lua @@ -10,38 +10,38 @@ local builtins = {} -- TODO: It will be easier to read this is we could write down the types using the normal grammar -local ipairs_itertype = T.Function({T.Any(), T.Any()}, {T.Any(), T.Any()}) +local ipairs_itertype = T.Function({T.Any, T.Any}, {T.Any, T.Any}) builtins.functions = { - type = T.Function({ T.Any() }, { T.String() }), - tostring = T.Function({ T.Any() }, { T.String() }), - ipairs = T.Function({T.Array(T.Any())}, {ipairs_itertype, T.Any(), T.Any()}) + type = T.Function({ T.Any }, { T.String }), + tostring = T.Function({ T.Any }, { T.String }), + ipairs = T.Function({T.Array(T.Any)}, {ipairs_itertype, T.Any, T.Any}) } builtins.modules = { io = { - write = T.Function({ T.String() }, {}), + write = T.Function({ T.String }, {}), }, math = { - abs = T.Function({ T.Float() }, { T.Float() }), - ceil = T.Function({ T.Float() }, { T.Integer() }), - floor = T.Function({ T.Float() }, { T.Integer() }), - fmod = T.Function({ T.Float(), T.Float() }, { T.Float() }), - exp = T.Function({ T.Float() }, { T.Float() }), - ln = T.Function({ T.Float() }, { T.Float() }), - log = T.Function({ T.Float(), T.Float() }, { T.Float() }), - modf = T.Function({ T.Float() }, { T.Integer(), T.Float() }), - pow = T.Function({ T.Float(), T.Float() }, { T.Float() }), - sqrt = T.Function({ T.Float() }, { T.Float() }), + abs = T.Function({ T.Float }, { T.Float }), + ceil = T.Function({ T.Float }, { T.Integer }), + floor = T.Function({ T.Float }, { T.Integer }), + fmod = T.Function({ T.Float, T.Float }, { T.Float }), + exp = T.Function({ T.Float }, { T.Float }), + ln = T.Function({ T.Float }, { T.Float }), + log = T.Function({ T.Float, T.Float }, { T.Float }), + modf = T.Function({ T.Float }, { T.Integer, T.Float }), + pow = T.Function({ T.Float, T.Float }, { T.Float }), + sqrt = T.Function({ T.Float }, { T.Float }), -- constant numbers - huge = T.Float(), - mininteger = T.Integer(), - maxinteger = T.Integer(), - pi = T.Float(), + huge = T.Float, + mininteger = T.Integer, + maxinteger = T.Integer, + pi = T.Float, }, string = { - char = T.Function({ T.Integer() }, { T.String() }), - sub = T.Function({ T.String(), T.Integer(), T.Integer() }, { T.String() }), + char = T.Function({ T.Integer }, { T.String }), + sub = T.Function({ T.String, T.Integer, T.Integer }, { T.String }), }, } diff --git a/src/pallene/coder.lua b/src/pallene/coder.lua index f0bb5751..72cb9df7 100644 --- a/src/pallene/coder.lua +++ b/src/pallene/coder.lua @@ -357,7 +357,7 @@ function Coder:c_value(value) return C.float(value.value) elseif tag == "ir.Value.String" then local str = value.value - return lua_value(types.T.String(), self:string_upvalue_slot(str)) + return lua_value(types.T.String, self:string_upvalue_slot(str)) elseif tag == "ir.Value.LocalVar" then return self:c_var(value.id) elseif tag == "ir.Value.Upvalue" then @@ -963,10 +963,12 @@ end -- the stack top before a GC point so the GC can look at the right set of variables. We also need -- to do it before function calls because the stack-gowing logic relies on having the right "top". -function Coder:update_stack_top(func, cmd) +function Coder:update_stack_top(cmd_position) + local gc_info = self.gc[self.current_func] + local live_vars = gc_info.live_gc_vars[cmd_position.block_index][cmd_position.cmd_index] local offset = 0 - for _, v_id in ipairs(self.gc[func].live_gc_vars[cmd]) do - local slot = self.gc[func].slot_of_variable[v_id] + for _, v_id in ipairs(live_vars) do + local slot = gc_info.slot_of_variable[v_id] offset = math.max(offset, slot + 1) end return util.render("L->top.p = base + $offset;", { offset = C.integer(offset) }) @@ -986,15 +988,15 @@ end local gen_cmd = {} -gen_cmd["Move"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dst) - local src = self:c_value(cmd.src) +gen_cmd["Move"] = function(self, args) + local dst = self:c_var(args.cmd.dst) + local src = self:c_value(args.cmd.src) return (util.render([[ $dst = $src; ]], { dst = dst, src = src })) end -gen_cmd["Unop"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dst) - local x = self:c_value(cmd.src) +gen_cmd["Unop"] = function(self, args) + local dst = self:c_var(args.cmd.dst) + local x = self:c_value(args.cmd.src) -- For when we can directly translate to a C operator: local function unop(op) @@ -1012,8 +1014,8 @@ gen_cmd["Unop"] = function(self, cmd, _func) ${check_no_metatable} $dst = luaH_getn($x); ]], { - check_no_metatable = check_no_metatable(self, x, cmd.loc), - line = C.integer(cmd.loc.line), + check_no_metatable = check_no_metatable(self, x, args.cmd.loc), + line = C.integer(args.cmd.loc.line), dst = dst, x = x })) @@ -1024,7 +1026,7 @@ gen_cmd["Unop"] = function(self, cmd, _func) dst = dst, x = x })) end - local op = cmd.op + local op = args.cmd.op if op == "ArrLen" then return arr_len() elseif op == "StrLen" then return str_len() elseif op == "IntNeg" then return int_neg() @@ -1036,10 +1038,10 @@ gen_cmd["Unop"] = function(self, cmd, _func) end end -gen_cmd["Binop"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dst) - local x = self:c_value(cmd.src1) - local y = self:c_value(cmd.src2) +gen_cmd["Binop"] = function(self, args) + local dst = self:c_var(args.cmd.dst) + local x = self:c_value(args.cmd.src1) + local y = self:c_value(args.cmd.src2) -- For when we can be directly translate to a C operator: local function binop(op) @@ -1061,7 +1063,7 @@ gen_cmd["Binop"] = function(self, cmd, _func) -- For integer division and modulus: local function int_division(fname) - local line = cmd.loc.line + local line = args.cmd.loc.line return (util.render([[ $dst = $fname(L, $x, $y, PALLENE_SOURCE_FILE, $line); ]], { fname = fname, dst = dst, @@ -1103,7 +1105,7 @@ gen_cmd["Binop"] = function(self, cmd, _func) dst = dst, x = x, y = y, op = op })) end - local op = cmd.op + local op = args.cmd.op if op == "IntAdd" then return int_binop("+") elseif op == "IntSub" then return int_binop("-") elseif op == "IntMul" then return int_binop("*") @@ -1159,11 +1161,11 @@ gen_cmd["Binop"] = function(self, cmd, _func) end end -gen_cmd["Concat"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dst) +gen_cmd["Concat"] = function(self, args) + local dst = self:c_var(args.cmd.dst) local init_input_array = {} - for ix, srcv in ipairs(cmd.srcs) do + for ix, srcv in ipairs(args.cmd.srcs) do local src = self:c_value(srcv) table.insert(init_input_array, util.render([[ ss[$i] = $src; ]], { @@ -1180,60 +1182,60 @@ gen_cmd["Concat"] = function(self, cmd, _func) } ]], { dst = dst, - N = C.integer(#cmd.srcs), + N = C.integer(#args.cmd.srcs), init_input_array = table.concat(init_input_array, "\n"), })) end -gen_cmd["ToFloat"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dst) - local v = self:c_value(cmd.src) +gen_cmd["ToFloat"] = function(self, args) + local dst = self:c_var(args.cmd.dst) + local v = self:c_value(args.cmd.src) return util.render([[ $dst = (lua_Number) $v; ]], { dst = dst, v = v }) end -gen_cmd["ToDyn"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dst) - local src = self:c_value(cmd.src) - local src_typ = cmd.src_typ +gen_cmd["ToDyn"] = function(self, args) + local dst = self:c_var(args.cmd.dst) + local src = self:c_value(args.cmd.src) + local src_typ = args.cmd.src_typ return (set_stack_slot(src_typ, "&"..dst, src)) end -gen_cmd["FromDyn"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dst) - local src = self:c_value(cmd.src) - local dst_typ = cmd.dst_typ +gen_cmd["FromDyn"] = function(self, args) + local dst = self:c_var(args.cmd.dst) + local src = self:c_value(args.cmd.src) + local dst_typ = args.cmd.dst_typ return self:get_stack_slot(dst_typ, dst, "&"..src, - cmd.loc, "downcasted value") + args.cmd.loc, "downcasted value") end -gen_cmd["IsTruthy"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dst) - local src = self:c_value(cmd.src) +gen_cmd["IsTruthy"] = function(self, args) + local dst = self:c_var(args.cmd.dst) + local src = self:c_value(args.cmd.src) return (util.render([[ $dst = pallene_is_truthy(&$src); ]], { dst = dst, src = src })) end -gen_cmd["IsNil"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dst) - local src = self:c_value(cmd.src) +gen_cmd["IsNil"] = function(self, args) + local dst = self:c_var(args.cmd.dst) + local src = self:c_value(args.cmd.src) return (util.render([[ $dst = ttisnil(&$src); ]], { dst = dst, src = src })) end -gen_cmd["NewArr"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dst) - local n = self:c_value(cmd.src_size) +gen_cmd["NewArr"] = function(self, args) + local dst = self:c_var(args.cmd.dst) + local n = self:c_value(args.cmd.src_size) return (util.render([[ $dst = pallene_createtable(L, $n, 0); ]], { dst = dst, n = n, })) end -gen_cmd["GetArr"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dst) - local arr = self:c_value(cmd.src_arr) - local i = self:c_value(cmd.src_i) - local dst_typ = cmd.dst_typ - local line = C.integer(cmd.loc.line) +gen_cmd["GetArr"] = function(self, args) + local dst = self:c_var(args.cmd.dst) + local arr = self:c_value(args.cmd.src_arr) + local i = self:c_value(args.cmd.src_i) + local dst_typ = args.cmd.dst_typ + local line = C.integer(args.cmd.loc.line) return (util.render([[ { @@ -1246,16 +1248,16 @@ gen_cmd["GetArr"] = function(self, cmd, _func) i = i, line = line, get_slot = self:get_luatable_slot(dst_typ, dst, "slot", arr, - cmd.loc, "array element"), + args.cmd.loc, "array element"), })) end -gen_cmd["SetArr"] = function(self, cmd, _func) - local arr = self:c_value(cmd.src_arr) - local i = self:c_value(cmd.src_i) - local v = self:c_value(cmd.src_v) - local src_typ = cmd.src_typ - local line = C.integer(cmd.loc.line) +gen_cmd["SetArr"] = function(self, args) + local arr = self:c_value(args.cmd.src_arr) + local i = self:c_value(args.cmd.src_i) + local v = self:c_value(args.cmd.src_v) + local src_typ = args.cmd.src_typ + local line = C.integer(args.cmd.loc.line) return (util.render([[ { pallene_renormalize_array(L, $arr, $i, PALLENE_SOURCE_FILE, $line); @@ -1271,23 +1273,23 @@ gen_cmd["SetArr"] = function(self, cmd, _func) })) end -gen_cmd["NewTable"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dst) - local n = self:c_value(cmd.src_size) +gen_cmd["NewTable"] = function(self, args) + local dst = self:c_var(args.cmd.dst) + local n = self:c_value(args.cmd.src_size) return (util.render([[ $dst = pallene_createtable(L, 0, $n); ]], { dst = dst, n = n, })) end -gen_cmd["GetTable"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dst) - local tab = self:c_value(cmd.src_tab) - local key = self:c_value(cmd.src_k) - local dst_typ = cmd.dst_typ +gen_cmd["GetTable"] = function(self, args) + local dst = self:c_var(args.cmd.dst) + local tab = self:c_value(args.cmd.src_tab) + local key = self:c_value(args.cmd.src_k) + local dst_typ = args.cmd.dst_typ - assert(cmd.src_k._tag == "ir.Value.String") - local field_name = cmd.src_k.value + assert(args.cmd.src_k._tag == "ir.Value.String") + local field_name = args.cmd.src_k.value return util.render([[ { @@ -1299,18 +1301,18 @@ gen_cmd["GetTable"] = function(self, cmd, _func) field_len = tostring(#field_name), tab = tab, key = key, - get_slot = self:get_luatable_slot(dst_typ, dst, "slot", tab, cmd.loc, "table field"), + get_slot = self:get_luatable_slot(dst_typ, dst, "slot", tab, args.cmd.loc, "table field"), }) end -gen_cmd["SetTable"] = function(self, cmd, _func) - local tab = self:c_value(cmd.src_tab) - local key = self:c_value(cmd.src_k) - local val = self:c_value(cmd.src_v) - local src_typ = cmd.src_typ +gen_cmd["SetTable"] = function(self, args) + local tab = self:c_value(args.cmd.src_tab) + local key = self:c_value(args.cmd.src_k) + local val = self:c_value(args.cmd.src_v) + local src_typ = args.cmd.src_typ - assert(cmd.src_k._tag == "ir.Value.String") - local field_name = cmd.src_k.value + assert(args.cmd.src_k._tag == "ir.Value.String") + local field_name = args.cmd.src_k.value return util.render([[ { @@ -1330,7 +1332,7 @@ gen_cmd["SetTable"] = function(self, cmd, _func) tab = tab, key = key, val = val, - init_keyv = set_stack_slot(types.T.String(), "&keyv", key), + init_keyv = set_stack_slot(types.T.String, "&keyv", key), init_valv = set_stack_slot(src_typ, "&valv", val), -- Here we use set_stack_slot slot on a heap object, because -- we call the barrier by hand outside the if statement. @@ -1339,22 +1341,22 @@ gen_cmd["SetTable"] = function(self, cmd, _func) }) end -gen_cmd["NewRecord"] = function(self, cmd, _func) - local rc = self.record_coders[cmd.rec_typ] - local rec = self:c_var(cmd.dst) +gen_cmd["NewRecord"] = function(self, args) + local rc = self.record_coders[args.cmd.rec_typ] + local rec = self:c_var(args.cmd.dst) return (util.render([[$rec = $constructor(L, K);]] , { rec = rec, constructor = rc:constructor_name(), })) end -gen_cmd["GetField"] = function(self, cmd, _func) - local rec_typ = cmd.rec_typ +gen_cmd["GetField"] = function(self, args) + local rec_typ = args.cmd.rec_typ local rc = self.record_coders[rec_typ] - local dst = self:c_var(cmd.dst) - local rec = self:c_value(cmd.src_rec) - local field_name = cmd.field_name + local dst = self:c_var(args.cmd.dst) + local rec = self:c_value(args.cmd.src_rec) + local field_name = args.cmd.field_name local f_typ = rec_typ.field_types[field_name] if types.is_gc(f_typ) then @@ -1370,13 +1372,13 @@ gen_cmd["GetField"] = function(self, cmd, _func) end end -gen_cmd["SetField"] = function(self, cmd, _func) - local rec_typ = cmd.rec_typ +gen_cmd["SetField"] = function(self, args) + local rec_typ = args.cmd.rec_typ local rc = self.record_coders[rec_typ] - local rec = self:c_value(cmd.src_rec) - local v = self:c_value(cmd.src_v) - local field_name = cmd.field_name + local rec = self:c_value(args.cmd.src_rec) + local v = self:c_value(args.cmd.src_v) + local field_name = args.cmd.field_name local f_typ = rec_typ.field_types[field_name] if types.is_gc(f_typ) then @@ -1388,8 +1390,8 @@ gen_cmd["SetField"] = function(self, cmd, _func) end end -gen_cmd["NewClosure"] = function (self, cmd, _func) - local func = self.module.functions[cmd.f_id] +gen_cmd["NewClosure"] = function (self, args) + local func = self.module.functions[args.cmd.f_id] -- The number of upvalues must fit inside a byte (the nupvalues in the ClosureHeader). -- However, we must check this limit ourselves, because luaF_newCclosure doesn't. If we have too @@ -1406,19 +1408,19 @@ gen_cmd["NewClosure"] = function (self, cmd, _func) } ]], { num_upvalues = C.integer(num_upvalues), - dst = self:c_var(cmd.dst), - lua_entry_point = self:lua_entry_point_name(cmd.f_id), + dst = self:c_var(args.cmd.dst), + lua_entry_point = self:lua_entry_point_name(args.cmd.f_id), }) end -gen_cmd["InitUpvalues"] = function(self, cmd, _func) - local func = self.module.functions[cmd.f_id] +gen_cmd["InitUpvalues"] = function(self, args) + local func = self.module.functions[args.cmd.f_id] - assert(cmd.src_f._tag == "ir.Value.LocalVar") - local cclosure = string.format("clCvalue(&%s)", self:c_var(cmd.src_f.id)) + assert(args.cmd.src_f._tag == "ir.Value.LocalVar") + local cclosure = string.format("clCvalue(&%s)", self:c_var(args.cmd.src_f.id)) local capture_upvalues = {} - for i, val in ipairs(cmd.srcs) do + for i, val in ipairs(args.cmd.srcs) do local typ = func.captured_vars[i].typ local c_val = self:c_value(val) local upvalue_dst = string.format("&(ccl->upvalue[%s])", C.integer(i)) @@ -1442,57 +1444,58 @@ gen_cmd["InitUpvalues"] = function(self, cmd, _func) }) end -gen_cmd["CallStatic"] = function(self, cmd, func) +gen_cmd["CallStatic"] = function(self, args) local dsts = {} - for i, dst in ipairs(cmd.dsts) do + for i, dst in ipairs(args.cmd.dsts) do dsts[i] = dst and self:c_var(dst) end local xs = {} - for _, x in ipairs(cmd.srcs) do + for _, x in ipairs(args.cmd.srcs) do table.insert(xs, self:c_value(x)) end local parts = {} - local f_val = cmd.src_f + local f_val = args.cmd.src_f local f_id, cclosure if f_val._tag == "ir.Value.Upvalue" then - f_id = assert(func.f_id_of_upvalue[f_val.id]) + f_id = assert(args.func.f_id_of_upvalue[f_val.id]) cclosure = string.format("clCvalue(&%s)", self:c_value(f_val)) elseif f_val._tag == "ir.Value.LocalVar" then - f_id = assert(func.f_id_of_local[f_val.id]) + f_id = assert(args.func.f_id_of_local[f_val.id]) cclosure = string.format("clCvalue(&%s)", self:c_value(f_val)) else tagged_union.error(f_val._tag) end - table.insert(parts, self:update_stack_top(func, cmd)) + table.insert(parts, self:update_stack_top(args.position)) table.insert(parts, string.format("PALLENE_SETLINE(%d);\n", - func.loc and func.loc.line or 0)) + args.func.loc and args.func.loc.line or 0)) table.insert(parts, self:call_pallene_function(dsts, f_id, cclosure, xs, nil)) table.insert(parts, self:restorestack()) return table.concat(parts, "\n") end -gen_cmd["CallDyn"] = function(self, cmd, func) - local f_typ = cmd.f_typ +gen_cmd["CallDyn"] = function(self, args) + local f_typ = args.cmd.f_typ local dsts = {} - for i, dst in ipairs(cmd.dsts) do + for i, dst in ipairs(args.cmd.dsts) do dsts[i] = dst and self:c_var(dst) end local push_arguments = {} - table.insert(push_arguments, self:push_to_stack(f_typ, self:c_value(cmd.src_f))) + table.insert(push_arguments, self:push_to_stack(f_typ, self:c_value(args.cmd.src_f))) for i = 1, #f_typ.arg_types do local typ = f_typ.arg_types[i] - table.insert(push_arguments, self:push_to_stack(typ, self:c_value(cmd.srcs[i]))) + table.insert(push_arguments, self:push_to_stack(typ, self:c_value(args.cmd.srcs[i]))) end local pop_results = {} for i = #f_typ.ret_types, 1, -1 do local typ = f_typ.ret_types[i] - local get_slot = self:get_stack_slot(typ, dsts[i], "slot", cmd.loc, "return value #%d", i) + local get_slot = + self:get_stack_slot(typ, dsts[i], "slot", args.cmd.loc, "return value #%d", i) table.insert(pop_results, util.render([[ { L->top.p--; @@ -1505,7 +1508,7 @@ gen_cmd["CallDyn"] = function(self, cmd, func) end local setline = util.render([[ PALLENE_SETLINE($line); ]], { - line = C.integer(func.loc and func.loc.line or 0) + line = C.integer(args.func.loc and args.func.loc.line or 0) }) return util.render([[ @@ -1516,7 +1519,7 @@ gen_cmd["CallDyn"] = function(self, cmd, func) ${pop_results} ${restore_stack} ]], { - update_stack_top = self:update_stack_top(func, cmd), + update_stack_top = self:update_stack_top(args.position), push_arguments = table.concat(push_arguments, "\n"), setline = setline, pop_results = table.concat(pop_results, "\n"), @@ -1526,43 +1529,43 @@ gen_cmd["CallDyn"] = function(self, cmd, func) }) end -gen_cmd["BuiltinIoWrite"] = function(self, cmd, _func) - local v = self:c_value(cmd.srcs[1]) +gen_cmd["BuiltinIoWrite"] = function(self, args) + local v = self:c_value(args.cmd.srcs[1]) return util.render([[ pallene_io_write(L, $v); ]], { v = v }) end -gen_cmd["BuiltinMathAbs"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dsts[1]) - local v = self:c_value(cmd.srcs[1]) +gen_cmd["BuiltinMathAbs"] = function(self, args) + local dst = self:c_var(args.cmd.dsts[1]) + local v = self:c_value(args.cmd.srcs[1]) return util.render([[ $dst = l_mathop(fabs)($v); ]], { dst = dst, v = v }) end -gen_cmd["BuiltinMathCeil"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dsts[1]) - local v = self:c_value(cmd.srcs[1]) - local line = cmd.loc.line +gen_cmd["BuiltinMathCeil"] = function(self, args) + local dst = self:c_var(args.cmd.dsts[1]) + local v = self:c_value(args.cmd.srcs[1]) + local line = args.cmd.loc.line return util.render([[ $dst = pallene_math_ceil(L, PALLENE_SOURCE_FILE, $line, $v); ]], { dst = dst, v = v, line = C.integer(line) }) end -gen_cmd["BuiltinMathFloor"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dsts[1]) - local v = self:c_value(cmd.srcs[1]) - local line = cmd.loc.line +gen_cmd["BuiltinMathFloor"] = function(self, args) + local dst = self:c_var(args.cmd.dsts[1]) + local v = self:c_value(args.cmd.srcs[1]) + local line = args.cmd.loc.line return util.render([[ $dst = pallene_math_floor(L, PALLENE_SOURCE_FILE, $line, $v); ]], { dst = dst, v = v, line = C.integer(line) }) end -gen_cmd["BuiltinMathFmod"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dsts[1]) - local x = self:c_value(cmd.srcs[1]) - local y = self:c_value(cmd.srcs[2]) +gen_cmd["BuiltinMathFmod"] = function(self, args) + local dst = self:c_var(args.cmd.dsts[1]) + local x = self:c_value(args.cmd.srcs[1]) + local y = self:c_value(args.cmd.srcs[2]) return util.render([[ $dst = l_mathop(fmod)($x, $y); ]], { dst = dst, x = x, y = y }) end -gen_cmd["BuiltinMathExp"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dsts[1]) - local v = self:c_value(cmd.srcs[1]) +gen_cmd["BuiltinMathExp"] = function(self, args) + local dst = self:c_var(args.cmd.dsts[1]) + local v = self:c_value(args.cmd.srcs[1]) return util.render([[ $dst = l_mathop(exp)($v); ]], { dst = dst, v = v }) end @@ -1572,69 +1575,69 @@ end -- But for --emit-lua, we must do something to make the code work in pure Lua. -- For now, the easiest thing to do is inject math.ln = math.log at the top. -- A smarter routine would replace math.ln with math.log. -gen_cmd["BuiltinMathLn"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dsts[1]) - local v = self:c_value(cmd.srcs[1]) +gen_cmd["BuiltinMathLn"] = function(self, args) + local dst = self:c_var(args.cmd.dsts[1]) + local v = self:c_value(args.cmd.srcs[1]) return util.render([[ $dst = l_mathop(log)($v); ]], { dst = dst, v = v }) end -gen_cmd["BuiltinMathLog"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dsts[1]) - local v = self:c_value(cmd.srcs[1]) - local b = self:c_value(cmd.srcs[2]) +gen_cmd["BuiltinMathLog"] = function(self, args) + local dst = self:c_var(args.cmd.dsts[1]) + local v = self:c_value(args.cmd.srcs[1]) + local b = self:c_value(args.cmd.srcs[2]) return util.render([[ $dst = pallene_math_log($v, $b); ]], { dst = dst, v = v, b = b }) end -gen_cmd["BuiltinMathModf"] = function(self, cmd, _func) - local dst1 = self:c_var(cmd.dsts[1]) - local dst2 = self:c_var(cmd.dsts[2]) - local v = self:c_value(cmd.srcs[1]) - local line = cmd.loc.line +gen_cmd["BuiltinMathModf"] = function(self, args) + local dst1 = self:c_var(args.cmd.dsts[1]) + local dst2 = self:c_var(args.cmd.dsts[2]) + local v = self:c_value(args.cmd.srcs[1]) + local line = args.cmd.loc.line return util.render([[ $dst1 = pallene_math_modf(L, PALLENE_SOURCE_FILE, $line, $v, &$dst2); ]], { dst1 = dst1, dst2 = dst2, v = v, line = C.integer(line) }) end -gen_cmd["BuiltinMathPow"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dsts[1]) - local x = self:c_value(cmd.srcs[1]) - local y = self:c_value(cmd.srcs[2]) +gen_cmd["BuiltinMathPow"] = function(self, args) + local dst = self:c_var(args.cmd.dsts[1]) + local x = self:c_value(args.cmd.srcs[1]) + local y = self:c_value(args.cmd.srcs[2]) return util.render([[ $dst = l_mathop(pow)($x, $y); ]], { dst = dst, x = x, y = y }) end -gen_cmd["BuiltinMathSqrt"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dsts[1]) - local v = self:c_value(cmd.srcs[1]) +gen_cmd["BuiltinMathSqrt"] = function(self, args) + local dst = self:c_var(args.cmd.dsts[1]) + local v = self:c_value(args.cmd.srcs[1]) return util.render([[ $dst = l_mathop(sqrt)($v); ]], { dst = dst, v = v }) end -gen_cmd["BuiltinStringChar"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dsts[1]) - local v = self:c_value(cmd.srcs[1]) - local line = cmd.loc.line +gen_cmd["BuiltinStringChar"] = function(self, args) + local dst = self:c_var(args.cmd.dsts[1]) + local v = self:c_value(args.cmd.srcs[1]) + local line = args.cmd.loc.line return util.render([[ $dst = pallene_string_char(L, PALLENE_SOURCE_FILE, $line, $v); ]], { dst = dst, v = v, line = C.integer(line) }) end -gen_cmd["BuiltinStringSub"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dsts[1]) - local str = self:c_value(cmd.srcs[1]) - local i = self:c_value(cmd.srcs[2]) - local j = self:c_value(cmd.srcs[3]) +gen_cmd["BuiltinStringSub"] = function(self, args) + local dst = self:c_var(args.cmd.dsts[1]) + local str = self:c_value(args.cmd.srcs[1]) + local i = self:c_value(args.cmd.srcs[2]) + local j = self:c_value(args.cmd.srcs[3]) return util.render([[ $dst = pallene_string_sub(L, $str, $i, $j); ]], { dst = dst, str = str, i = i, j = j }) end -gen_cmd["BuiltinType"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dsts[1]) - local v = self:c_value(cmd.srcs[1]) +gen_cmd["BuiltinType"] = function(self, args) + local dst = self:c_var(args.cmd.dsts[1]) + local v = self:c_value(args.cmd.srcs[1]) return util.render([[ $dst = pallene_type_builtin(L, $v); ]], { dst = dst, v = v }) end -gen_cmd["BuiltinTostring"] = function(self, cmd, _func) - local dst = self:c_var(cmd.dsts[1]) - local v = self:c_value(cmd.srcs[1]) - local line = cmd.loc.line +gen_cmd["BuiltinTostring"] = function(self, args) + local dst = self:c_var(args.cmd.dsts[1]) + local v = self:c_value(args.cmd.srcs[1]) + local line = args.cmd.loc.line return util.render([[ $dst = pallene_tostring(L, PALLENE_SOURCE_FILE, $line, $v); ]], { dst = dst, line = C.integer(line), v = v }) end @@ -1643,8 +1646,8 @@ end -- Control flow -- -gen_cmd["ForPrep"] = function(self, cmd, func) - local typ = func.vars[cmd.dst_i].typ +gen_cmd["ForPrep"] = function(self, args) + local typ = args.func.vars[args.cmd.dst_i].typ local macro if typ._tag == "types.T.Integer" then macro = "PALLENE_INT_FOR_PREP" @@ -1658,18 +1661,18 @@ gen_cmd["ForPrep"] = function(self, cmd, func) ${macro}($i, $cond, $iter, $count, $start, $limit, $step) ]], { macro = macro, - i = self:c_var(cmd.dst_i), - cond = self:c_var(cmd.dst_cond), - iter = self:c_var(cmd.dst_iter), - count = self:c_var(cmd.dst_count), - start = self:c_value(cmd.src_start), - limit = self:c_value(cmd.src_limit), - step = self:c_value(cmd.src_step), + i = self:c_var(args.cmd.dst_i), + cond = self:c_var(args.cmd.dst_cond), + iter = self:c_var(args.cmd.dst_iter), + count = self:c_var(args.cmd.dst_count), + start = self:c_value(args.cmd.src_start), + limit = self:c_value(args.cmd.src_limit), + step = self:c_value(args.cmd.src_step), })) end -gen_cmd["ForStep"] = function(self, cmd, func) - local typ = func.vars[cmd.dst_i].typ +gen_cmd["ForStep"] = function(self, args) + local typ = args.func.vars[args.cmd.dst_i].typ local macro if typ._tag == "types.T.Integer" then @@ -1684,65 +1687,68 @@ gen_cmd["ForStep"] = function(self, cmd, func) ${macro}($i, $cond, $iter, $count, $start, $limit, $step) ]], { macro = macro, - i = self:c_var(cmd.dst_i), - cond = self:c_var(cmd.dst_cond), - iter = self:c_var(cmd.dst_iter), - count = self:c_var(cmd.dst_count), - start = self:c_value(cmd.src_start), - limit = self:c_value(cmd.src_limit), - step = self:c_value(cmd.src_step), + i = self:c_var(args.cmd.dst_i), + cond = self:c_var(args.cmd.dst_cond), + iter = self:c_var(args.cmd.dst_iter), + count = self:c_var(args.cmd.dst_count), + start = self:c_value(args.cmd.src_start), + limit = self:c_value(args.cmd.src_limit), + step = self:c_value(args.cmd.src_step), })) end -gen_cmd["Jmp"] = function(self, cmd, _func, block_id) - if cmd.target ~= block_id + 1 then - return "goto " .. self:c_label(cmd.target) .. ";" - else - return "" -- fallthrough - end +gen_cmd["Jmp"] = function(self, args) + return "goto " .. self:c_label(args.cmd.target) .. ";" end -gen_cmd["JmpIf"] = function(self, cmd, _func, block_id) - local template - if cmd.target_false == block_id + 1 then - template = "if ($v) goto $t;" - elseif cmd.target_true == block_id + 1 then - template = "if (!$v) goto $f;" - else - template = "if ($v) goto $t; else goto $f;" - end - return util.render(template, { - v = self:c_value(cmd.src_cond), - t = self:c_label(cmd.target_true), - f = self:c_label(cmd.target_false), +gen_cmd["JmpIf"] = function(self, args) + return util.render("if($v) {goto $t;} else {goto $f;}", { + v = self:c_value(args.cmd.src_cond), + t = self:c_label(args.cmd.target_true), + f = self:c_label(args.cmd.target_false), }) end -gen_cmd["CheckGC"] = function(self, cmd, func) +gen_cmd["CheckGC"] = function(self, args) return util.render([[ luaC_condGC(L, ${update_stack_top}, (void)0); ]], { - update_stack_top = self:update_stack_top(func, cmd) }) + update_stack_top = self:update_stack_top(args.position) }) end function Coder:generate_blocks(func) local out = {} - for block_id, block in ipairs(func.blocks) do - table.insert(out, self:c_label(block_id)..":\n") - for _,cmd in ipairs(block.cmds) do - table.insert(out, self:generate_cmd(func, block_id, cmd)) - table.insert(out, "\n") + for block_i,block in ipairs(func.blocks) do + table.insert(out, util.render("$label:\n", { + label = self:c_label(block_i), + })) + for cmd_i,cmd in ipairs(block.cmds) do + if cmd._tag ~= "ir.Cmd.Jmp" or cmd.target ~= block_i + 1 then + local gen_args = { + cmd = cmd, + func = func, + position = { + block_index = block_i, + cmd_index = cmd_i, + }, + } + local cmd_str = self:generate_cmd(gen_args) .. "\n" + table.insert(out, cmd_str) + end end end return table.concat(out) end -function Coder:generate_cmd(func, block_id, cmd) +function Coder:generate_cmd(gen_args) + local cmd = gen_args.cmd + local func = gen_args.func assert(tagged_union.typename(cmd._tag) == "ir.Cmd") local name = tagged_union.consname(cmd._tag) local f = assert(gen_cmd[name], "impossible") - local out = f(self, cmd, func, block_id) + local out = f(self, gen_args) + local slot_of_variable = self.gc[func].slot_of_variable for _, v_id in ipairs(ir.get_dsts(cmd)) do - local n = self.gc[func].slot_of_variable[v_id] + local n = slot_of_variable[v_id] if n then local typ = func.vars[v_id].typ local slot = util.render([[s2v(base + $n)]], { n = C.integer(n) }) diff --git a/src/pallene/gc.lua b/src/pallene/gc.lua index 8d11d042..35b6a60d 100644 --- a/src/pallene/gc.lua +++ b/src/pallene/gc.lua @@ -5,6 +5,7 @@ local ir = require "pallene.ir" local types = require "pallene.types" +local tagged_union = require "pallene.tagged_union" -- GARBAGE COLLECTION -- ================== @@ -18,8 +19,8 @@ local types = require "pallene.types" -- have already been saved by the caller. -- -- As an optimization, we don't save values to the Lua stack if the associated variable dies before --- it reaches a potential garbage collection site. The current analysis is pretty simple, and there --- are many ways to make it more precise. So we don't forget, I'm listing some of the ideas here... +-- it reaches a potential garbage collection site. The current implementation uses flow analysis to +-- find live variables. So we don't forget, I'm listing here some ideas to improve the analysis ... -- But it should be said that we don't know if implementing them would be worth the trouble. -- -- 1) Insert fewer checkGC calls in our functions, or move the checkGC calls to places with fewer @@ -28,77 +29,244 @@ local types = require "pallene.types" -- 2) Identify functions that don't call the GC (directly or indirectly) and don't treat calls to -- them as potential GC sites. (Function inlining might mitigate this for small functions) -- --- 3) Use a flow-based liveliness analysis to precisely identify the commands that a variable --- appears live at, instead of approximating with first definition and last use. --- --- 4) Use SSA form or some form of reaching definitions analysis so that we we only need to mirror +-- 3) Use SSA form or some form of reaching definitions analysis so that we we only need to mirror -- the writes that reach a GC site, instead of always mirroring all writes to a variable if one -- of them reaches a GC site. local gc = {} -function gc.compute_stack_slots(func) +local function FlowState() + return { + input = {}, -- set of var_id, live variables at block start + output = {}, -- set of var_id, live variables at block end + kill = {}, -- set of var_id, variables that are killed inside block + gen = {}, -- set of var_id, variables that become live inside block + } +end - local flat_cmds = ir.flatten_cmd(func.blocks) +local function cmd_uses_gc(tag) + assert(tagged_union.typename(tag) == "ir.Cmd") + return tag == "ir.Cmd.CallStatic" or + tag == "ir.Cmd.CallDyn" or + tag == "ir.Cmd.CheckGC" +end - -- 1) Compute approximated live intervals for GC variables defined by the function. Function - -- parameters are only counted if they are redefined, since their original value was already - -- saved by the caller. Also note that we only care about variables, not about upvalues. - -- The latter are already exposed to the GC via the function closures. +local function copy_set(S) + local new_set = {} + for v,_ in pairs(S) do + new_set[v] = true + end + return new_set +end - local defined_variables = {} -- { var_id }, sorted by first definition - local last_use = {} -- { var_id => integer } - local first_definition = {} -- { var_id => integer } +local function flow_analysis(block_list, state_list) + local function apply_gen_kill_sets(flow_state) + local input = flow_state.input + local output = flow_state.output + local gen = flow_state.gen + local kill = flow_state.kill + local in_changed = false - for i, cmd in ipairs(flat_cmds) do - for _, val in ipairs(ir.get_srcs(cmd)) do - if val._tag == "ir.Value.LocalVar" then - local v_id = val.id - last_use[v_id] = i + for v, _ in pairs(output) do + local val = true + if kill[v] then + val = nil end + local previous_val = input[v] + local new_val = previous_val or val + input[v] = new_val + in_changed = in_changed or (previous_val ~= new_val) end - for _, v_id in ipairs(ir.get_dsts(cmd)) do - local typ = func.vars[v_id].typ - if types.is_gc(typ) and not first_definition[v_id] then - first_definition[v_id] = i - table.insert(defined_variables, v_id) + + for v, g in pairs(gen) do + assert(not (g and kill[v]), "gen and kill can't both be true") + local previous_val = input[v] + local new_val = true + input[v] = new_val + in_changed = in_changed or (previous_val ~= new_val) + end + + for v, _ in pairs(input) do + if not output[v] and not gen[v] then + input[v] = nil + in_changed = true end end + + return in_changed end - -- 2) Find which variables are live at each GC spot in the program. - - local live_gc_vars = {} -- { cmd => {var_id}? } - for i, cmd in ipairs(flat_cmds) do - local tag = cmd._tag - if - tag == "ir.Cmd.CallStatic" or - tag == "ir.Cmd.CallDyn" or - tag == "ir.Cmd.CheckGC" - then - live_gc_vars[cmd] = {} - for _, v_id in ipairs(defined_variables) do - local a = first_definition[v_id] - local b = last_use[v_id] - if a and b and a < i and i <= b then - table.insert(live_gc_vars[cmd], v_id) - end + local function merge_live(input, output) + for v, _ in pairs(input) do + output[v] = true + end + end + + local function clear_set(S) + for v,_ in pairs(S) do + S[v] = nil + end + end + + local succ_list = ir.get_successor_list(block_list) + local pred_list = ir.get_predecessor_list(block_list) + local block_order = ir.get_predecessor_depth_search_topological_sort(pred_list) + + local dirty_flag = {} -- { block_id -> bool? } keeps track of modified blocks + for i = 1, #block_list do + dirty_flag[i] = true + end + + local function update_block(block_i) + local block_succs = succ_list[block_i] + local block_preds = pred_list[block_i] + local state = state_list[block_i] + + -- last block's output is supposed to be fixed + if block_i ~= #block_list then + clear_set(state.output) + for _,succ in ipairs(block_succs) do + local succ_in = state_list[succ].input + merge_live(succ_in, state.output) + end + end + + local in_changed = apply_gen_kill_sets(state) + if in_changed then + for _, pred in ipairs(block_preds) do + dirty_flag[pred] = true end end end - local variable_is_live_at_gc = {} -- { var_id => boolean } - for v_id = 1, #func.vars do - variable_is_live_at_gc[v_id] = false + repeat + local found_dirty_block = false + for _,block_i in ipairs(block_order) do + if dirty_flag[block_i] then + found_dirty_block = true + -- CAREFUL: we have to clean the dirty flag BEFORE updating the block or else we + -- will do the wrong thing for auto-referencing blocks + dirty_flag[block_i] = false + update_block(block_i) + end + end + until not found_dirty_block +end + +local function mark_gen_kill(cmd, gen_set, kill_set) + assert(tagged_union.typename(cmd._tag) == "ir.Cmd") + for _, dst in ipairs(ir.get_dsts(cmd)) do + gen_set[dst] = nil + kill_set[dst] = true end - for _, v_ids in pairs(live_gc_vars) do - for _, v_id in ipairs(v_ids) do - variable_is_live_at_gc[v_id] = true + + for _, src in ipairs(ir.get_srcs(cmd)) do + if src._tag == "ir.Value.LocalVar" then + gen_set[src.id] = true + kill_set[src.id] = nil + end + end +end + +local function make_gen_kill_sets(block, flow_state) + for i = #block.cmds, 1, -1 do + local cmd = block.cmds[i] + mark_gen_kill(cmd, flow_state.gen, flow_state.kill) + end +end + +-- Returns information that is used for allocating variables into the Lua stack. +-- The returned data is: +-- * live_gc_vars: +-- for each command, has a list of GC'd variables that are alive during that command. +-- * live_at_same_time: +-- for each GC'd variable, indicates what other GC'd variables are alive at the same time, +-- that is, if both are alive during the same command for some command in the function. +-- * max_frame_size: +-- what's the maximum number of slots of the Lua stack used for storing GC'd variables +-- during the function. +function gc.compute_stack_slots(func) + + local state_list = {} -- { FlowState } + + -- initialize states + for block_i, block in ipairs(func.blocks) do + local fst = FlowState() + make_gen_kill_sets(block, fst) + state_list[block_i] = fst + end + + -- set returned variables to "live" on exit block + if #func.blocks > 0 then + local exit_output = state_list[#func.blocks].output + for _, var in ipairs(func.ret_vars) do + exit_output[var] = true + end + end + + -- 1) Find live variables at the end of each basic block + flow_analysis(func.blocks, state_list) + + -- 2) Find which GC'd variables are live at each GC spot in the program and + -- which GC'd variables are live at the same time + local live_gc_vars = {} -- { block_id => { cmd_id => {var_id}? } } + local live_at_same_time = {} -- { var_id => { var_id => bool? }? } + + -- initialize live_gc_vars + for _, block in ipairs(func.blocks) do + local live_on_cmds = {} + for cmd_i = 1, #block.cmds do + live_on_cmds[cmd_i] = false + end + table.insert(live_gc_vars, live_on_cmds) + end + + for block_i, block in ipairs(func.blocks) do + local lives_block = copy_set(state_list[block_i].output) + -- filter out non-GC'd variables from set + for var_i, _ in pairs(lives_block) do + local var = func.vars[var_i] + if not types.is_gc(var.typ) then + lives_block[var_i] = nil + end + end + for cmd_i = #block.cmds, 1, -1 do + local cmd = block.cmds[cmd_i] + assert(tagged_union.typename(cmd._tag) == "ir.Cmd") + for _, dst in ipairs(ir.get_dsts(cmd)) do + lives_block[dst] = nil + end + for _, src in ipairs(ir.get_srcs(cmd)) do + if src._tag == "ir.Value.LocalVar" then + local typ = func.vars[src.id].typ + if types.is_gc(typ) then + lives_block[src.id] = true + end + end + end + + if cmd_uses_gc(cmd._tag) + then + local lives_cmd = {} + for var,_ in pairs(lives_block) do + table.insert(lives_cmd, var) + end + live_gc_vars[block_i][cmd_i] = lives_cmd + for var1,_ in pairs(lives_block) do + for var2,_ in pairs(lives_block) do + if not live_at_same_time[var1] then + live_at_same_time[var1] = {} + end + live_at_same_time[var1][var2] = true + end + end + end end end -- 3) Allocate variables to Lua stack slots, ensuring that variables with overlapping lifetimes - -- different stack slots. IMPORTANT: stack slots are 0-based. The C we generate prefers that. + -- get different stack slots. IMPORTANT: stack slots are 0-based. The C we generate prefers + -- that. local max_frame_size = 0 local slot_of_variable = {} -- { var_id => integer? } @@ -107,21 +275,22 @@ function gc.compute_stack_slots(func) slot_of_variable[v_id] = false end - local n = 0 - local stack = { } -- { var_id } - for _, v_id in ipairs(defined_variables) do - if variable_is_live_at_gc[v_id] then - local def = first_definition[v_id] - while n > 0 and last_use[stack[n]] <= def do - stack[n] = nil - n = n - 1 + for v1, set in pairs(live_at_same_time) do + local taken_slots = {} -- { stack_slot => bool? } + for v2,_ in pairs(set) do + local v2_slot = slot_of_variable[v2] + if v2_slot then + taken_slots[v2_slot] = true + end + end + for slot = 0, #func.vars do + if not taken_slots[slot] then + slot_of_variable[v1] = slot + max_frame_size = math.max(max_frame_size, slot + 1) + break end - - n = n + 1 - slot_of_variable[v_id] = n-1 - stack[n] = v_id - max_frame_size = math.max(max_frame_size, n) end + assert(slot_of_variable[v1], "should always find a slot") end return { @@ -131,4 +300,5 @@ function gc.compute_stack_slots(func) } end + return gc diff --git a/src/pallene/tagged_union.lua b/src/pallene/tagged_union.lua index e241dccd..3730d470 100644 --- a/src/pallene/tagged_union.lua +++ b/src/pallene/tagged_union.lua @@ -55,7 +55,7 @@ local function make_tag(mod_name, type_name, cons_name) end -- Create a tagged union constructor --- @param module Module table where the type is being defined +-- @param mod_table Module table where the type is being defined -- @param mod_name Name of the module -- @param type_name Name of the type -- @param constructors Name of the constructor => fields of the record @@ -63,18 +63,24 @@ local function define_union(mod_table, mod_name, type_name, constructors) mod_table[type_name] = {} for cons_name, fields in pairs(constructors) do local tag = make_tag(mod_name, type_name, cons_name) - local function cons(...) - local args = table.pack(...) - if args.n ~= #fields then - error(string.format( - "wrong number of arguments for %s. Expected %d but received %d.", - cons_name, #fields, args.n)) - end - local node = { _tag = tag } - for i, field in ipairs(fields) do - node[field] = args[i] + + local cons + if #fields == 0 then + cons = { _tag = tag } + else + cons = function(...) + local args = table.pack(...) + if args.n ~= #fields then + error(string.format( + "wrong number of arguments for %s. Expected %d but received %d.", + cons_name, #fields, args.n)) + end + local node = { _tag = tag } + for i, field in ipairs(fields) do + node[field] = args[i] + end + return node end - return node end mod_table[type_name][cons_name] = cons end diff --git a/src/pallene/to_ir.lua b/src/pallene/to_ir.lua index f2fcb236..9c268694 100644 --- a/src/pallene/to_ir.lua +++ b/src/pallene/to_ir.lua @@ -350,7 +350,7 @@ function ToIR:convert_toplevel(prog_ast) bb:append_cmd(ir.Cmd.NewTable(self.func.loc, self.module.loc_id_of_exports, ir.Value.Integer(n_exports))) - bb:append_cmd(ir.Cmd.CheckGC()) + bb:append_cmd(ir.Cmd.CheckGC) -- export the functions for _, f_id in ipairs(self.module.exported_functions) do @@ -455,8 +455,8 @@ function ToIR:convert_stat(bb, stat) local count = ir.add_local(self.func, false, v_type) local iter = ir.add_local(self.func, false, v_type) - local cond_enter = ir.add_local(self.func, false, types.T.Boolean()) - local cond_loop = ir.add_local(self.func, false, types.T.Boolean()) + local cond_enter = ir.add_local(self.func, false, types.T.Boolean) + local cond_loop = ir.add_local(self.func, false, types.T.Boolean) local init_for = ir.Cmd.ForPrep( stat.loc, v, cond_enter, iter, count, @@ -526,25 +526,25 @@ function ToIR:convert_stat(bb, stat) -- the table passed as argument to `ipairs` local arr = ipairs_args[1] - assert(types.equals(arr._type, types.T.Array(types.T.Any()))) + assert(types.equals(arr._type, types.T.Array(types.T.Any))) local v_arr = ir.add_local(self.func, "$xs", arr._type) self:exp_to_assignment(bb, v_arr, arr) -- local i_num: integer = 1 - local v_inum = ir.add_local(self.func, "$"..decls[1].name.."_num", types.T.Integer()) + local v_inum = ir.add_local(self.func, "$"..decls[1].name.."_num", types.T.Integer) local start = ir.Value.Integer(1) bb:append_cmd(ir.Cmd.Move(stat.loc, v_inum, start)) local loop_begin = bb:finish_block() -- x_dyn = xs[i_num] - local v_x_dyn = ir.add_local(self.func, "$"..decls[2].name.."_dyn", types.T.Any()) + local v_x_dyn = ir.add_local(self.func, "$"..decls[2].name.."_dyn", types.T.Any) local src_arr = ir.Value.LocalVar(v_arr) local src_i = ir.Value.LocalVar(v_inum) - bb:append_cmd(ir.Cmd.GetArr(stat.loc, types.T.Any(), v_x_dyn, src_arr, src_i)) + bb:append_cmd(ir.Cmd.GetArr(stat.loc, types.T.Any, v_x_dyn, src_arr, src_i)) -- if x_dyn == nil then break end - local cond_checknil = ir.add_local(self.func, false, types.T.Boolean()) + local cond_checknil = ir.add_local(self.func, false, types.T.Boolean) bb:append_cmd(ir.Cmd.IsNil(stat.loc, cond_checknil, ir.Value.LocalVar(v_x_dyn))) step_test_jmpIf= bb:append_cmd( ir.Cmd.JmpIf(stat.loc, ir.Value.LocalVar(cond_checknil), nil, nil)) @@ -556,7 +556,7 @@ function ToIR:convert_stat(bb, stat) if decls[1]._type._tag == "types.T.Integer" then bb:append_cmd(ir.Cmd.Move(stat.loc, v_i, ir.Value.LocalVar(v_inum))) else - bb:append_cmd(ir.Cmd.ToDyn(stat.loc, types.T.Integer(), v_i, ir.Value.LocalVar(v_inum))) + bb:append_cmd(ir.Cmd.ToDyn(stat.loc, types.T.Integer, v_i, ir.Value.LocalVar(v_inum))) end -- local x = x_dyn as T2 @@ -604,7 +604,7 @@ function ToIR:convert_stat(bb, stat) local v_lhs_dyn = {} for _, decl in ipairs(decls) do - local v = ir.add_local(self.func, "$" .. decl.name .. "_dyn", types.T.Any()) + local v = ir.add_local(self.func, "$" .. decl.name .. "_dyn", types.T.Any) table.insert(v_lhs_dyn, v) end @@ -618,7 +618,7 @@ function ToIR:convert_stat(bb, stat) bb:append_cmd(ir.Cmd.CallDyn(exps[1].loc, itertype, v_lhs_dyn, ir.Value.LocalVar(v_iter), args)) -- if i == nil then break end - local cond_checknil = ir.add_local(self.func, false, types.T.Boolean()) + local cond_checknil = ir.add_local(self.func, false, types.T.Boolean) bb:append_cmd(ir.Cmd.IsNil(stat.loc, cond_checknil, ir.Value.LocalVar(v_lhs_dyn[1]))) step_test_jmpIf = bb:append_cmd( ir.Cmd.JmpIf(stat.loc, ir.Value.LocalVar(cond_checknil), nil, nil)) @@ -984,7 +984,7 @@ end function ToIR:exp_to_value(bb, exp, is_recursive) local tag = exp._tag if tag == "ast.Exp.Nil" then - return ir.Value.Nil() + return ir.Value.Nil elseif tag == "ast.Exp.Bool" then return ir.Value.Bool(exp.value) @@ -1071,7 +1071,7 @@ function ToIR:exp_to_assignment(bb, dst, exp) if typ._tag == "types.T.Array" then local n = ir.Value.Integer(#exp.fields) bb:append_cmd(ir.Cmd.NewArr(loc, dst, n)) - bb:append_cmd(ir.Cmd.CheckGC()) + bb:append_cmd(ir.Cmd.CheckGC) for i, field in ipairs(exp.fields) do assert(field._tag == "ast.Field.List") local av = ir.Value.LocalVar(dst) @@ -1084,7 +1084,7 @@ function ToIR:exp_to_assignment(bb, dst, exp) elseif typ._tag == "types.T.Table" then local n = ir.Value.Integer(#exp.fields) bb:append_cmd(ir.Cmd.NewTable(loc, dst, n)) - bb:append_cmd(ir.Cmd.CheckGC()) + bb:append_cmd(ir.Cmd.CheckGC) for _, field in ipairs(exp.fields) do assert(field._tag == "ast.Field.Rec") local tv = ir.Value.LocalVar(dst) @@ -1102,7 +1102,7 @@ function ToIR:exp_to_assignment(bb, dst, exp) end bb:append_cmd(ir.Cmd.NewRecord(loc, typ, dst)) - bb:append_cmd(ir.Cmd.CheckGC()) + bb:append_cmd(ir.Cmd.CheckGC) for _, field_name in ipairs(typ.field_names) do local f_exp = assert(field_exps[field_name]) local dv = ir.Value.LocalVar(dst) @@ -1124,7 +1124,7 @@ function ToIR:exp_to_assignment(bb, dst, exp) assert(typ.is_upvalue_box) bb:append_cmd(ir.Cmd.NewRecord(loc, typ, dst)) - bb:append_cmd(ir.Cmd.CheckGC()) + bb:append_cmd(ir.Cmd.CheckGC) elseif tag == "ast.Exp.Lambda" then local f_id = self:register_lambda(exp, "$lambda") @@ -1221,11 +1221,11 @@ function ToIR:exp_to_assignment(bb, dst, exp) elseif bname == "string.char" then assert(#xs == 1) bb:append_cmd(ir.Cmd.BuiltinStringChar(loc, dsts, xs)) - bb:append_cmd(ir.Cmd.CheckGC()) + bb:append_cmd(ir.Cmd.CheckGC) elseif bname == "string.sub" then assert(#xs == 3) bb:append_cmd(ir.Cmd.BuiltinStringSub(loc, dsts, xs)) - bb:append_cmd(ir.Cmd.CheckGC()) + bb:append_cmd(ir.Cmd.CheckGC) elseif bname == "type" then assert(#xs == 1) bb:append_cmd(ir.Cmd.BuiltinType(loc, dsts, xs)) @@ -1308,7 +1308,7 @@ function ToIR:exp_to_assignment(bb, dst, exp) xs[i] = self:exp_to_value(bb, x_exp) end bb:append_cmd(ir.Cmd.Concat(loc, dst, xs)) - bb:append_cmd(ir.Cmd.CheckGC()) + bb:append_cmd(ir.Cmd.CheckGC) elseif tag == "ast.Exp.Binop" then local op = exp.op @@ -1391,7 +1391,7 @@ function ToIR:value_is_truthy(bb, exp, val) if typ._tag == "types.T.Boolean" then return val elseif typ._tag == "types.T.Any" then - local b = ir.add_local(self.func, false, types.T.Boolean()) + local b = ir.add_local(self.func, false, types.T.Boolean) bb:append_cmd(ir.Cmd.IsTruthy(exp.loc, b, val)) return ir.Value.LocalVar(b) elseif tagged_union.tag_is_type(typ) then diff --git a/src/pallene/typechecker.lua b/src/pallene/typechecker.lua index 954e1122..8912ad40 100644 --- a/src/pallene/typechecker.lua +++ b/src/pallene/typechecker.lua @@ -168,7 +168,7 @@ end function Typechecker:from_ast_type(ast_typ) local tag = ast_typ._tag if tag == "ast.Type.Nil" then - return types.T.Nil() + return types.T.Nil elseif tag == "ast.Type.Name" then local name = ast_typ.name @@ -229,11 +229,11 @@ function Typechecker:check_program(prog_ast) local module_name = prog_ast.module_name -- 1) Add primitive types to the symbol table - self:add_type_symbol("any", types.T.Any()) - self:add_type_symbol("boolean", types.T.Boolean()) - self:add_type_symbol("float", types.T.Float()) - self:add_type_symbol("integer", types.T.Integer()) - self:add_type_symbol("string", types.T.String()) + self:add_type_symbol("any", types.T.Any) + self:add_type_symbol("boolean", types.T.Boolean) + self:add_type_symbol("float", types.T.Float) + self:add_type_symbol("integer", types.T.Integer) + self:add_type_symbol("string", types.T.String) -- 2) Add builtins to symbol table. -- The order does not matter because they are distinct. @@ -247,7 +247,7 @@ function Typechecker:check_program(prog_ast) local id = mod_name .. "." .. fun_name symbols[fun_name] = typechecker.Symbol.Value(typ, typechecker.Def.Builtin(id)) end - local typ = (mod_name == "string") and types.T.String() or false + local typ = (mod_name == "string") and types.T.String or false self:add_module_symbol(mod_name, typ, symbols) end @@ -423,10 +423,10 @@ function Typechecker:check_stat(stat, is_toplevel) local decl_types = {} for _ = 1, #stat.decls do - table.insert(decl_types, types.T.Any()) + table.insert(decl_types, types.T.Any) end - local itertype = types.T.Function({ types.T.Any(), types.T.Any() }, decl_types) + local itertype = types.T.Function({ types.T.Any, types.T.Any }, decl_types) rhs[1] = self:check_exp_synthesize(rhs[1]) local iteratorfn = rhs[1] @@ -712,7 +712,7 @@ function Typechecker:check_var(var) "expected array but found %s in indexed expression", types.tostring(arr_type)) end - var.k = self:check_exp_verify(var.k, types.T.Integer(), "array index") + var.k = self:check_exp_verify(var.k, types.T.Integer, "array index") var._type = arr_type.elem else @@ -793,19 +793,19 @@ function Typechecker:check_exp_synthesize(exp) local tag = exp._tag if tag == "ast.Exp.Nil" then - exp._type = types.T.Nil() + exp._type = types.T.Nil elseif tag == "ast.Exp.Bool" then - exp._type = types.T.Boolean() + exp._type = types.T.Boolean elseif tag == "ast.Exp.Integer" then - exp._type = types.T.Integer() + exp._type = types.T.Integer elseif tag == "ast.Exp.Float" then - exp._type = types.T.Float() + exp._type = types.T.Float elseif tag == "ast.Exp.String" then - exp._type = types.T.String() + exp._type = types.T.String elseif tag == "ast.Exp.InitList" then type_error(exp.loc, "missing type hint for initializer") @@ -827,7 +827,7 @@ function Typechecker:check_exp_synthesize(exp) "trying to take the length of a %s instead of an array or string", types.tostring(t)) end - exp._type = types.T.Integer() + exp._type = types.T.Integer elseif op == "-" then if t._tag ~= "types.T.Integer" and t._tag ~= "types.T.Float" then type_error(exp.loc, @@ -841,10 +841,10 @@ function Typechecker:check_exp_synthesize(exp) "trying to bitwise negate a %s instead of an integer", types.tostring(t)) end - exp._type = types.T.Integer() + exp._type = types.T.Integer elseif op == "not" then check_type_is_condition(exp.exp, "'not' operator") - exp._type = types.T.Boolean() + exp._type = types.T.Boolean else tagged_union.error(op) end @@ -868,7 +868,7 @@ function Typechecker:check_exp_synthesize(exp) "cannot compare %s and %s using %s", types.tostring(t1), types.tostring(t2), op) end - exp._type = types.T.Boolean() + exp._type = types.T.Boolean elseif op == "<" or op == ">" or op == "<=" or op == ">=" then if (t1._tag == "types.T.Integer" and t2._tag == "types.T.Integer") or @@ -886,7 +886,7 @@ function Typechecker:check_exp_synthesize(exp) "cannot compare %s and %s using %s", types.tostring(t1), types.tostring(t2), op) end - exp._type = types.T.Boolean() + exp._type = types.T.Boolean elseif op == "+" or op == "-" or op == "*" or op == "%" or op == "//" then if not is_numeric_type(t1) then @@ -903,11 +903,11 @@ function Typechecker:check_exp_synthesize(exp) if t1._tag == "types.T.Integer" and t2._tag == "types.T.Integer" then - exp._type = types.T.Integer() + exp._type = types.T.Integer else exp.lhs = self:coerce_numeric_exp_to_float(exp.lhs) exp.rhs = self:coerce_numeric_exp_to_float(exp.rhs) - exp._type = types.T.Float() + exp._type = types.T.Float end elseif op == "/" or op == "^" then @@ -924,7 +924,7 @@ function Typechecker:check_exp_synthesize(exp) exp.lhs = self:coerce_numeric_exp_to_float(exp.lhs) exp.rhs = self:coerce_numeric_exp_to_float(exp.rhs) - exp._type = types.T.Float() + exp._type = types.T.Float elseif op == ".." then -- The arguments to '..' must be a strings. We do not allow "any" because Pallene does @@ -935,7 +935,7 @@ function Typechecker:check_exp_synthesize(exp) if t2._tag ~= "types.T.String" then type_error(exp.loc, "cannot concatenate with %s value", types.tostring(t2)) end - exp._type = types.T.String() + exp._type = types.T.String elseif op == "and" or op == "or" then check_type_is_condition(exp.lhs, "first operand of '%s'", op) @@ -953,7 +953,7 @@ function Typechecker:check_exp_synthesize(exp) "right-hand side of bitwise expression is a %s instead of an integer", types.tostring(t2)) end - exp._type = types.T.Integer() + exp._type = types.T.Integer else tagged_union.error(op) @@ -987,7 +987,7 @@ function Typechecker:check_exp_synthesize(exp) elseif tag == "ast.Exp.ToFloat" then assert(exp.exp._type._tag == "types.T.Integer") - exp._type = types.T.Float() + exp._type = types.T.Float else tagged_union.error(tag) diff --git a/src/pallene/uninitialized.lua b/src/pallene/uninitialized.lua index d6e42f77..7ede0165 100644 --- a/src/pallene/uninitialized.lua +++ b/src/pallene/uninitialized.lua @@ -7,15 +7,28 @@ -- initialized and when control flows to the end of a non-void function without returning. Make sure -- that you call ir.clean first, so that it does the right thing in the presence of `while true` -- loops. --- --- `uninit` is the set of variables that are potentially uninitialized. --- `kill` is the set of variables that are initialized at a given block. local ir = require "pallene.ir" local tagged_union = require "pallene.tagged_union" local uninitialized = {} +local function FlowState() + return { + input = {}, -- set of var_id, uninitialized variables at block start + output = {}, -- set of var_id, uninitialized variables at block end + kill = {}, -- set of var_id, variables that are initialized inside block + } +end + +local function copy_set(S) + local new_set = {} + for v,_ in pairs(S) do + new_set[v] = true + end + return new_set +end + local function fill_set(cmd, set, val) assert(tagged_union.typename(cmd._tag) == "ir.Cmd") for _, src in ipairs(ir.get_srcs(cmd)) do @@ -37,41 +50,85 @@ local function fill_set(cmd, set, val) end end -local function flow_analysis(block_list, uninit_sets, kill_sets) - local function merge_uninit(A, B, kill) - local changed = false - for v, _ in pairs(B) do +local function flow_analysis(block_list, state_list) + local function apply_kill_set(flow_state) + local input = flow_state.input + local output = flow_state.output + local kill = flow_state.kill + local out_changed = false + for v, _ in pairs(input) do if not kill[v] then - if not A[v] then - changed = true + if not output[v] then + out_changed = true end - A[v] = true + output[v] = true + end + end + + for v, _ in pairs(output) do + if not input[v] then + output[v] = nil + out_changed = true end end - return changed + return out_changed + end + + local function merge_uninit(input, output) + for v, _ in pairs(output) do + input[v] = true + end + end + + local function clear_set(S) + for v,_ in pairs(S) do + S[v] = nil + end end local succ_list = ir.get_successor_list(block_list) + local pred_list = ir.get_predecessor_list(block_list) local block_order = ir.get_successor_depth_search_topological_sort(succ_list) - local function block_analysis(block_i) + local dirty_flag = {} -- { block_id -> bool? } keeps track of modified blocks + for i = 1, #block_list do + dirty_flag[i] = true + end + + local function update_block(block_i) local block_succs = succ_list[block_i] - local uninit = uninit_sets[block_i] - local kill = kill_sets[block_i] - local changed = false - for _,succ in ipairs(block_succs) do - local c = merge_uninit(uninit_sets[succ], uninit, kill) - changed = c or changed + local block_preds = pred_list[block_i] + local state = state_list[block_i] + + -- first block's input is supposed to be fixed + if block_i ~= 1 then + clear_set(state.input) + for _,pred in ipairs(block_preds) do + local pred_out = state_list[pred].output + merge_uninit(state.input, pred_out) + end + end + + local out_changed = apply_kill_set(state) + if out_changed then + for _, succ in ipairs(block_succs) do + dirty_flag[succ] = true + end end - return changed end repeat - local changed = false + local found_dirty_block = false for _,block_i in ipairs(block_order) do - changed = block_analysis(block_i) or changed + if dirty_flag[block_i] then + found_dirty_block = true + -- CAREFUL: we have to clean the dirty flag BEFORE updating the block or else we + -- will do the wrong thing for auto-referencing blocks + dirty_flag[block_i] = false + update_block(block_i) + end end - until not changed + until not found_dirty_block end local function gen_kill_set(block) @@ -91,36 +148,31 @@ function uninitialized.verify_variables(module) local nvars = #func.vars local nargs = #func.typ.arg_types - -- initialize sets - - -- variables that are initialized inside a given block - local kill_sets = {} -- { block_id -> {var_id -> bool?} } - - -- variables that are uninitialized when entering a given block - local uninit_sets = {} -- { block_id -> {var_id -> bool?} } - for _,b in ipairs(func.blocks) do - local kill = gen_kill_set(b) - table.insert(kill_sets, kill) - table.insert(uninit_sets, {}) + local state_list = {} -- { FlowState } + -- initialize states + for block_i,block in ipairs(func.blocks) do + local fst = FlowState() + fst.kill = gen_kill_set(block) + state_list[block_i] = fst end - local entry_uninit = uninit_sets[1] + local entry_input = state_list[1].input for v_i = nargs+1, nvars do - entry_uninit[v_i] = true + entry_input[v_i] = true end -- solve flow equations - flow_analysis(func.blocks, uninit_sets, kill_sets) + flow_analysis(func.blocks, state_list) -- check for errors local reported_variables = {} -- (only one error message per variable) for block_i, block in ipairs(func.blocks) do - local input_uninit = uninit_sets[block_i] + local uninit = copy_set(state_list[block_i].input) for _, cmd in ipairs(block.cmds) do local loc = cmd.loc - fill_set(cmd, input_uninit, nil) + fill_set(cmd, uninit, nil) for _, src in ipairs(ir.get_srcs(cmd)) do local v = src.id - if src._tag == "ir.Value.LocalVar" and input_uninit[v] then + if src._tag == "ir.Value.LocalVar" and uninit[v] then if not reported_variables[v] then reported_variables[v] = true local name = assert(func.vars[v].name) @@ -132,7 +184,7 @@ function uninitialized.verify_variables(module) end end - local exit_uninit = uninit_sets[#func.blocks] + local exit_uninit = state_list[#func.blocks].output if #func.ret_vars > 0 then local ret1 = func.ret_vars[1] if exit_uninit[ret1] then