From 76a1e1a9eb2d2c965b7333d9e67bd8ceda4dff70 Mon Sep 17 00:00:00 2001 From: Hugo Musso Gualandi Date: Wed, 17 May 2023 23:32:45 -0300 Subject: [PATCH] Nullary tagged union constructors no longer need parens For example, now we can write types.T.Any instead of types.T.Any(). It should have been like this from the beginning, but we couldn't because some of our code was using ir.Cmd objects as table keys. (This was just fixed in a recent PR) --- spec/types_spec.lua | 84 ++++++++++++++++++------------------ src/pallene/builtins.lua | 42 +++++++++--------- src/pallene/coder.lua | 8 ++-- src/pallene/tagged_union.lua | 30 +++++++------ src/pallene/to_ir.lua | 40 ++++++++--------- src/pallene/typechecker.lua | 52 +++++++++++----------- 6 files changed, 131 insertions(+), 125 deletions(-) diff --git a/spec/types_spec.lua b/spec/types_spec.lua index 9d87380f..0adca7c1 100644 --- a/spec/types_spec.lua +++ b/spec/types_spec.lua @@ -8,41 +8,41 @@ local types = require "pallene.types" describe("Pallene types", function() it("pretty-prints types", function() - assert.same("{ integer }", types.tostring(types.T.Array(types.T.Integer()))) + assert.same("{ integer }", types.tostring(types.T.Array(types.T.Integer))) assert.same("{ x: float, y: float }", types.tostring( - types.T.Table({x = types.T.Float(), y = types.T.Float()}))) + types.T.Table({x = types.T.Float, y = types.T.Float}))) end) it("is_gc works", function() - assert.falsy(types.is_gc(types.T.Integer())) - assert.truthy(types.is_gc(types.T.String())) - assert.truthy(types.is_gc(types.T.Array(types.T.Integer()))) - assert.truthy(types.is_gc(types.T.Table({x = types.T.Float()}))) + assert.falsy(types.is_gc(types.T.Integer)) + assert.truthy(types.is_gc(types.T.String)) + assert.truthy(types.is_gc(types.T.Array(types.T.Integer))) + assert.truthy(types.is_gc(types.T.Table({x = types.T.Float}))) assert.truthy(types.is_gc(types.T.Function({}, {}))) end) describe("equality", function() it("works for primitive types", function() - assert.truthy(types.equals(types.T.Integer(), types.T.Integer())) - assert.falsy(types.equals(types.T.Integer(), types.T.String())) + assert.truthy(types.equals(types.T.Integer, types.T.Integer)) + assert.falsy(types.equals(types.T.Integer, types.T.String)) end) it("is true for two identical tables", function() local t1 = types.T.Table({ - y = types.T.Integer(), x = types.T.Integer()}) + y = types.T.Integer, x = types.T.Integer}) local t2 = types.T.Table({ - x = types.T.Integer(), y = types.T.Integer()}) + x = types.T.Integer, y = types.T.Integer}) assert.truthy(types.equals(t1, t2)) assert.truthy(types.equals(t2, t1)) end) it("is false for tables with different number of fields", function() - local t1 = types.T.Table({x = types.T.Integer()}) - local t2 = types.T.Table({x = types.T.Integer(), - y = types.T.Integer()}) - local t3 = types.T.Table({x = types.T.Integer(), - y = types.T.Integer(), z = types.T.Integer()}) + local t1 = types.T.Table({x = types.T.Integer}) + local t2 = types.T.Table({x = types.T.Integer, + y = types.T.Integer}) + local t3 = types.T.Table({x = types.T.Integer, + y = types.T.Integer, z = types.T.Integer}) assert.falsy(types.equals(t1, t2)) assert.falsy(types.equals(t2, t1)) assert.falsy(types.equals(t2, t3)) @@ -52,39 +52,39 @@ describe("Pallene types", function() end) it("is false for tables with different field names", function() - local t1 = types.T.Table({x = types.T.Integer()}) - local t2 = types.T.Table({y = types.T.Integer()}) + local t1 = types.T.Table({x = types.T.Integer}) + local t2 = types.T.Table({y = types.T.Integer}) assert.falsy(types.equals(t1, t2)) assert.falsy(types.equals(t2, t1)) end) it("is false for tables with different field types", function() - local t1 = types.T.Table({x = types.T.Integer()}) - local t2 = types.T.Table({x = types.T.Float()}) + local t1 = types.T.Table({x = types.T.Integer}) + local t2 = types.T.Table({x = types.T.Float}) assert.falsy(types.equals(t1, t2)) assert.falsy(types.equals(t2, t1)) end) it("is true for identical functions", function() - local f1 = types.T.Function({types.T.String(), types.T.Integer()}, {types.T.Boolean()}) - local f2 = types.T.Function({types.T.String(), types.T.Integer()}, {types.T.Boolean()}) + local f1 = types.T.Function({types.T.String, types.T.Integer}, {types.T.Boolean}) + local f2 = types.T.Function({types.T.String, types.T.Integer}, {types.T.Boolean}) assert.truthy(types.equals(f1, f2)) end) it("is false for functions with different input types", function() - local f1 = types.T.Function({types.T.String(), types.T.Boolean()}, {types.T.Boolean()}) - local f2 = types.T.Function({types.T.Integer(), types.T.Integer()}, {types.T.Boolean()}) + local f1 = types.T.Function({types.T.String, types.T.Boolean}, {types.T.Boolean}) + local f2 = types.T.Function({types.T.Integer, types.T.Integer}, {types.T.Boolean}) assert.falsy(types.equals(f1, f2)) end) it("is false for functions with different output types", function() - local f1 = types.T.Function({types.T.String(), types.T.Integer()}, {types.T.Boolean()}) - local f2 = types.T.Function({types.T.String(), types.T.Integer()}, {types.T.Integer()}) + local f1 = types.T.Function({types.T.String, types.T.Integer}, {types.T.Boolean}) + local f2 = types.T.Function({types.T.String, types.T.Integer}, {types.T.Integer}) assert.falsy(types.equals(f1, f2)) end) it("is false for functions with different input arity", function() - local s = types.T.String() + local s = types.T.String local f1 = types.T.Function({}, {s}) local f2 = types.T.Function({s}, {s}) local f3 = types.T.Function({s, s}, {s}) @@ -97,7 +97,7 @@ describe("Pallene types", function() end) it("is false for functions with different output arity", function() - local s = types.T.String() + local s = types.T.String local f1 = types.T.Function({s}, {}) local f2 = types.T.Function({s}, {s}) local f3 = types.T.Function({s}, {s, s}) @@ -123,42 +123,42 @@ describe("Pallene types", function() describe("consistency", function() it("allows 'any' on either side", function() - assert.truthy(types.consistent(types.T.Any(), types.T.Any())) - assert.truthy(types.consistent(types.T.Any(), types.T.Integer())) - assert.truthy(types.consistent(types.T.Integer(), types.T.Any())) + assert.truthy(types.consistent(types.T.Any, types.T.Any)) + assert.truthy(types.consistent(types.T.Any, types.T.Integer)) + assert.truthy(types.consistent(types.T.Integer, types.T.Any)) end) it("allows types with same tag", function() assert.truthy(types.consistent( - types.T.Integer(), - types.T.Integer() + types.T.Integer, + types.T.Integer )) assert.truthy(types.consistent( - types.T.Array(types.T.Integer()), - types.T.Array(types.T.Integer()) + types.T.Array(types.T.Integer), + types.T.Array(types.T.Integer) )) assert.truthy(types.consistent( - types.T.Array(types.T.Integer()), - types.T.Array(types.T.String()) + types.T.Array(types.T.Integer), + types.T.Array(types.T.String) )) assert.truthy(types.consistent( - types.T.Function({types.T.Integer()}, {types.T.Integer()}), - types.T.Function({types.T.String(), types.T.String()}, {}) + types.T.Function({types.T.Integer}, {types.T.Integer}), + types.T.Function({types.T.String, types.T.String}, {}) )) end) it("forbids different tags", function() assert.falsy(types.consistent( - types.T.Integer(), - types.T.String() + types.T.Integer, + types.T.String )) assert.falsy(types.consistent( - types.T.Array(types.T.Integer()), - types.T.Function({types.T.Integer()},{types.T.Integer()}) + types.T.Array(types.T.Integer), + types.T.Function({types.T.Integer},{types.T.Integer}) )) end) end) diff --git a/src/pallene/builtins.lua b/src/pallene/builtins.lua index a433026b..ba44279b 100644 --- a/src/pallene/builtins.lua +++ b/src/pallene/builtins.lua @@ -10,38 +10,38 @@ local builtins = {} -- TODO: It will be easier to read this is we could write down the types using the normal grammar -local ipairs_itertype = T.Function({T.Any(), T.Any()}, {T.Any(), T.Any()}) +local ipairs_itertype = T.Function({T.Any, T.Any}, {T.Any, T.Any}) builtins.functions = { - type = T.Function({ T.Any() }, { T.String() }), - tostring = T.Function({ T.Any() }, { T.String() }), - ipairs = T.Function({T.Array(T.Any())}, {ipairs_itertype, T.Any(), T.Any()}) + type = T.Function({ T.Any }, { T.String }), + tostring = T.Function({ T.Any }, { T.String }), + ipairs = T.Function({T.Array(T.Any)}, {ipairs_itertype, T.Any, T.Any}) } builtins.modules = { io = { - write = T.Function({ T.String() }, {}), + write = T.Function({ T.String }, {}), }, math = { - abs = T.Function({ T.Float() }, { T.Float() }), - ceil = T.Function({ T.Float() }, { T.Integer() }), - floor = T.Function({ T.Float() }, { T.Integer() }), - fmod = T.Function({ T.Float(), T.Float() }, { T.Float() }), - exp = T.Function({ T.Float() }, { T.Float() }), - ln = T.Function({ T.Float() }, { T.Float() }), - log = T.Function({ T.Float(), T.Float() }, { T.Float() }), - modf = T.Function({ T.Float() }, { T.Integer(), T.Float() }), - pow = T.Function({ T.Float(), T.Float() }, { T.Float() }), - sqrt = T.Function({ T.Float() }, { T.Float() }), + abs = T.Function({ T.Float }, { T.Float }), + ceil = T.Function({ T.Float }, { T.Integer }), + floor = T.Function({ T.Float }, { T.Integer }), + fmod = T.Function({ T.Float, T.Float }, { T.Float }), + exp = T.Function({ T.Float }, { T.Float }), + ln = T.Function({ T.Float }, { T.Float }), + log = T.Function({ T.Float, T.Float }, { T.Float }), + modf = T.Function({ T.Float }, { T.Integer, T.Float }), + pow = T.Function({ T.Float, T.Float }, { T.Float }), + sqrt = T.Function({ T.Float }, { T.Float }), -- constant numbers - huge = T.Float(), - mininteger = T.Integer(), - maxinteger = T.Integer(), - pi = T.Float(), + huge = T.Float, + mininteger = T.Integer, + maxinteger = T.Integer, + pi = T.Float, }, string = { - char = T.Function({ T.Integer() }, { T.String() }), - sub = T.Function({ T.String(), T.Integer(), T.Integer() }, { T.String() }), + char = T.Function({ T.Integer }, { T.String }), + sub = T.Function({ T.String, T.Integer, T.Integer }, { T.String }), }, } diff --git a/src/pallene/coder.lua b/src/pallene/coder.lua index 0730f1b9..ede52bda 100644 --- a/src/pallene/coder.lua +++ b/src/pallene/coder.lua @@ -363,7 +363,7 @@ function Coder:c_value(value) return C.float(value.value) elseif tag == "ir.Value.String" then local str = value.value - return lua_value(types.T.String(), self:string_upvalue_slot(str)) + return lua_value(types.T.String, self:string_upvalue_slot(str)) elseif tag == "ir.Value.LocalVar" then return self:c_var(value.id) elseif tag == "ir.Value.Upvalue" then @@ -730,8 +730,8 @@ function Coder:init_upvalues() -- If we are using tracebacks if self.flags.use_traceback then - table.insert(self.constants, coder.Constant.DebugUserdata()) - table.insert(self.constants, coder.Constant.DebugMetatable()) + table.insert(self.constants, coder.Constant.DebugUserdata) + table.insert(self.constants, coder.Constant.DebugMetatable) end -- Metatables @@ -1353,7 +1353,7 @@ gen_cmd["SetTable"] = function(self, args) tab = tab, key = key, val = val, - init_keyv = set_stack_slot(types.T.String(), "&keyv", key), + init_keyv = set_stack_slot(types.T.String, "&keyv", key), init_valv = set_stack_slot(src_typ, "&valv", val), -- Here we use set_stack_slot slot on a heap object, because -- we call the barrier by hand outside the if statement. diff --git a/src/pallene/tagged_union.lua b/src/pallene/tagged_union.lua index e241dccd..3730d470 100644 --- a/src/pallene/tagged_union.lua +++ b/src/pallene/tagged_union.lua @@ -55,7 +55,7 @@ local function make_tag(mod_name, type_name, cons_name) end -- Create a tagged union constructor --- @param module Module table where the type is being defined +-- @param mod_table Module table where the type is being defined -- @param mod_name Name of the module -- @param type_name Name of the type -- @param constructors Name of the constructor => fields of the record @@ -63,18 +63,24 @@ local function define_union(mod_table, mod_name, type_name, constructors) mod_table[type_name] = {} for cons_name, fields in pairs(constructors) do local tag = make_tag(mod_name, type_name, cons_name) - local function cons(...) - local args = table.pack(...) - if args.n ~= #fields then - error(string.format( - "wrong number of arguments for %s. Expected %d but received %d.", - cons_name, #fields, args.n)) - end - local node = { _tag = tag } - for i, field in ipairs(fields) do - node[field] = args[i] + + local cons + if #fields == 0 then + cons = { _tag = tag } + else + cons = function(...) + local args = table.pack(...) + if args.n ~= #fields then + error(string.format( + "wrong number of arguments for %s. Expected %d but received %d.", + cons_name, #fields, args.n)) + end + local node = { _tag = tag } + for i, field in ipairs(fields) do + node[field] = args[i] + end + return node end - return node end mod_table[type_name][cons_name] = cons end diff --git a/src/pallene/to_ir.lua b/src/pallene/to_ir.lua index f2fcb236..9c268694 100644 --- a/src/pallene/to_ir.lua +++ b/src/pallene/to_ir.lua @@ -350,7 +350,7 @@ function ToIR:convert_toplevel(prog_ast) bb:append_cmd(ir.Cmd.NewTable(self.func.loc, self.module.loc_id_of_exports, ir.Value.Integer(n_exports))) - bb:append_cmd(ir.Cmd.CheckGC()) + bb:append_cmd(ir.Cmd.CheckGC) -- export the functions for _, f_id in ipairs(self.module.exported_functions) do @@ -455,8 +455,8 @@ function ToIR:convert_stat(bb, stat) local count = ir.add_local(self.func, false, v_type) local iter = ir.add_local(self.func, false, v_type) - local cond_enter = ir.add_local(self.func, false, types.T.Boolean()) - local cond_loop = ir.add_local(self.func, false, types.T.Boolean()) + local cond_enter = ir.add_local(self.func, false, types.T.Boolean) + local cond_loop = ir.add_local(self.func, false, types.T.Boolean) local init_for = ir.Cmd.ForPrep( stat.loc, v, cond_enter, iter, count, @@ -526,25 +526,25 @@ function ToIR:convert_stat(bb, stat) -- the table passed as argument to `ipairs` local arr = ipairs_args[1] - assert(types.equals(arr._type, types.T.Array(types.T.Any()))) + assert(types.equals(arr._type, types.T.Array(types.T.Any))) local v_arr = ir.add_local(self.func, "$xs", arr._type) self:exp_to_assignment(bb, v_arr, arr) -- local i_num: integer = 1 - local v_inum = ir.add_local(self.func, "$"..decls[1].name.."_num", types.T.Integer()) + local v_inum = ir.add_local(self.func, "$"..decls[1].name.."_num", types.T.Integer) local start = ir.Value.Integer(1) bb:append_cmd(ir.Cmd.Move(stat.loc, v_inum, start)) local loop_begin = bb:finish_block() -- x_dyn = xs[i_num] - local v_x_dyn = ir.add_local(self.func, "$"..decls[2].name.."_dyn", types.T.Any()) + local v_x_dyn = ir.add_local(self.func, "$"..decls[2].name.."_dyn", types.T.Any) local src_arr = ir.Value.LocalVar(v_arr) local src_i = ir.Value.LocalVar(v_inum) - bb:append_cmd(ir.Cmd.GetArr(stat.loc, types.T.Any(), v_x_dyn, src_arr, src_i)) + bb:append_cmd(ir.Cmd.GetArr(stat.loc, types.T.Any, v_x_dyn, src_arr, src_i)) -- if x_dyn == nil then break end - local cond_checknil = ir.add_local(self.func, false, types.T.Boolean()) + local cond_checknil = ir.add_local(self.func, false, types.T.Boolean) bb:append_cmd(ir.Cmd.IsNil(stat.loc, cond_checknil, ir.Value.LocalVar(v_x_dyn))) step_test_jmpIf= bb:append_cmd( ir.Cmd.JmpIf(stat.loc, ir.Value.LocalVar(cond_checknil), nil, nil)) @@ -556,7 +556,7 @@ function ToIR:convert_stat(bb, stat) if decls[1]._type._tag == "types.T.Integer" then bb:append_cmd(ir.Cmd.Move(stat.loc, v_i, ir.Value.LocalVar(v_inum))) else - bb:append_cmd(ir.Cmd.ToDyn(stat.loc, types.T.Integer(), v_i, ir.Value.LocalVar(v_inum))) + bb:append_cmd(ir.Cmd.ToDyn(stat.loc, types.T.Integer, v_i, ir.Value.LocalVar(v_inum))) end -- local x = x_dyn as T2 @@ -604,7 +604,7 @@ function ToIR:convert_stat(bb, stat) local v_lhs_dyn = {} for _, decl in ipairs(decls) do - local v = ir.add_local(self.func, "$" .. decl.name .. "_dyn", types.T.Any()) + local v = ir.add_local(self.func, "$" .. decl.name .. "_dyn", types.T.Any) table.insert(v_lhs_dyn, v) end @@ -618,7 +618,7 @@ function ToIR:convert_stat(bb, stat) bb:append_cmd(ir.Cmd.CallDyn(exps[1].loc, itertype, v_lhs_dyn, ir.Value.LocalVar(v_iter), args)) -- if i == nil then break end - local cond_checknil = ir.add_local(self.func, false, types.T.Boolean()) + local cond_checknil = ir.add_local(self.func, false, types.T.Boolean) bb:append_cmd(ir.Cmd.IsNil(stat.loc, cond_checknil, ir.Value.LocalVar(v_lhs_dyn[1]))) step_test_jmpIf = bb:append_cmd( ir.Cmd.JmpIf(stat.loc, ir.Value.LocalVar(cond_checknil), nil, nil)) @@ -984,7 +984,7 @@ end function ToIR:exp_to_value(bb, exp, is_recursive) local tag = exp._tag if tag == "ast.Exp.Nil" then - return ir.Value.Nil() + return ir.Value.Nil elseif tag == "ast.Exp.Bool" then return ir.Value.Bool(exp.value) @@ -1071,7 +1071,7 @@ function ToIR:exp_to_assignment(bb, dst, exp) if typ._tag == "types.T.Array" then local n = ir.Value.Integer(#exp.fields) bb:append_cmd(ir.Cmd.NewArr(loc, dst, n)) - bb:append_cmd(ir.Cmd.CheckGC()) + bb:append_cmd(ir.Cmd.CheckGC) for i, field in ipairs(exp.fields) do assert(field._tag == "ast.Field.List") local av = ir.Value.LocalVar(dst) @@ -1084,7 +1084,7 @@ function ToIR:exp_to_assignment(bb, dst, exp) elseif typ._tag == "types.T.Table" then local n = ir.Value.Integer(#exp.fields) bb:append_cmd(ir.Cmd.NewTable(loc, dst, n)) - bb:append_cmd(ir.Cmd.CheckGC()) + bb:append_cmd(ir.Cmd.CheckGC) for _, field in ipairs(exp.fields) do assert(field._tag == "ast.Field.Rec") local tv = ir.Value.LocalVar(dst) @@ -1102,7 +1102,7 @@ function ToIR:exp_to_assignment(bb, dst, exp) end bb:append_cmd(ir.Cmd.NewRecord(loc, typ, dst)) - bb:append_cmd(ir.Cmd.CheckGC()) + bb:append_cmd(ir.Cmd.CheckGC) for _, field_name in ipairs(typ.field_names) do local f_exp = assert(field_exps[field_name]) local dv = ir.Value.LocalVar(dst) @@ -1124,7 +1124,7 @@ function ToIR:exp_to_assignment(bb, dst, exp) assert(typ.is_upvalue_box) bb:append_cmd(ir.Cmd.NewRecord(loc, typ, dst)) - bb:append_cmd(ir.Cmd.CheckGC()) + bb:append_cmd(ir.Cmd.CheckGC) elseif tag == "ast.Exp.Lambda" then local f_id = self:register_lambda(exp, "$lambda") @@ -1221,11 +1221,11 @@ function ToIR:exp_to_assignment(bb, dst, exp) elseif bname == "string.char" then assert(#xs == 1) bb:append_cmd(ir.Cmd.BuiltinStringChar(loc, dsts, xs)) - bb:append_cmd(ir.Cmd.CheckGC()) + bb:append_cmd(ir.Cmd.CheckGC) elseif bname == "string.sub" then assert(#xs == 3) bb:append_cmd(ir.Cmd.BuiltinStringSub(loc, dsts, xs)) - bb:append_cmd(ir.Cmd.CheckGC()) + bb:append_cmd(ir.Cmd.CheckGC) elseif bname == "type" then assert(#xs == 1) bb:append_cmd(ir.Cmd.BuiltinType(loc, dsts, xs)) @@ -1308,7 +1308,7 @@ function ToIR:exp_to_assignment(bb, dst, exp) xs[i] = self:exp_to_value(bb, x_exp) end bb:append_cmd(ir.Cmd.Concat(loc, dst, xs)) - bb:append_cmd(ir.Cmd.CheckGC()) + bb:append_cmd(ir.Cmd.CheckGC) elseif tag == "ast.Exp.Binop" then local op = exp.op @@ -1391,7 +1391,7 @@ function ToIR:value_is_truthy(bb, exp, val) if typ._tag == "types.T.Boolean" then return val elseif typ._tag == "types.T.Any" then - local b = ir.add_local(self.func, false, types.T.Boolean()) + local b = ir.add_local(self.func, false, types.T.Boolean) bb:append_cmd(ir.Cmd.IsTruthy(exp.loc, b, val)) return ir.Value.LocalVar(b) elseif tagged_union.tag_is_type(typ) then diff --git a/src/pallene/typechecker.lua b/src/pallene/typechecker.lua index 954e1122..8912ad40 100644 --- a/src/pallene/typechecker.lua +++ b/src/pallene/typechecker.lua @@ -168,7 +168,7 @@ end function Typechecker:from_ast_type(ast_typ) local tag = ast_typ._tag if tag == "ast.Type.Nil" then - return types.T.Nil() + return types.T.Nil elseif tag == "ast.Type.Name" then local name = ast_typ.name @@ -229,11 +229,11 @@ function Typechecker:check_program(prog_ast) local module_name = prog_ast.module_name -- 1) Add primitive types to the symbol table - self:add_type_symbol("any", types.T.Any()) - self:add_type_symbol("boolean", types.T.Boolean()) - self:add_type_symbol("float", types.T.Float()) - self:add_type_symbol("integer", types.T.Integer()) - self:add_type_symbol("string", types.T.String()) + self:add_type_symbol("any", types.T.Any) + self:add_type_symbol("boolean", types.T.Boolean) + self:add_type_symbol("float", types.T.Float) + self:add_type_symbol("integer", types.T.Integer) + self:add_type_symbol("string", types.T.String) -- 2) Add builtins to symbol table. -- The order does not matter because they are distinct. @@ -247,7 +247,7 @@ function Typechecker:check_program(prog_ast) local id = mod_name .. "." .. fun_name symbols[fun_name] = typechecker.Symbol.Value(typ, typechecker.Def.Builtin(id)) end - local typ = (mod_name == "string") and types.T.String() or false + local typ = (mod_name == "string") and types.T.String or false self:add_module_symbol(mod_name, typ, symbols) end @@ -423,10 +423,10 @@ function Typechecker:check_stat(stat, is_toplevel) local decl_types = {} for _ = 1, #stat.decls do - table.insert(decl_types, types.T.Any()) + table.insert(decl_types, types.T.Any) end - local itertype = types.T.Function({ types.T.Any(), types.T.Any() }, decl_types) + local itertype = types.T.Function({ types.T.Any, types.T.Any }, decl_types) rhs[1] = self:check_exp_synthesize(rhs[1]) local iteratorfn = rhs[1] @@ -712,7 +712,7 @@ function Typechecker:check_var(var) "expected array but found %s in indexed expression", types.tostring(arr_type)) end - var.k = self:check_exp_verify(var.k, types.T.Integer(), "array index") + var.k = self:check_exp_verify(var.k, types.T.Integer, "array index") var._type = arr_type.elem else @@ -793,19 +793,19 @@ function Typechecker:check_exp_synthesize(exp) local tag = exp._tag if tag == "ast.Exp.Nil" then - exp._type = types.T.Nil() + exp._type = types.T.Nil elseif tag == "ast.Exp.Bool" then - exp._type = types.T.Boolean() + exp._type = types.T.Boolean elseif tag == "ast.Exp.Integer" then - exp._type = types.T.Integer() + exp._type = types.T.Integer elseif tag == "ast.Exp.Float" then - exp._type = types.T.Float() + exp._type = types.T.Float elseif tag == "ast.Exp.String" then - exp._type = types.T.String() + exp._type = types.T.String elseif tag == "ast.Exp.InitList" then type_error(exp.loc, "missing type hint for initializer") @@ -827,7 +827,7 @@ function Typechecker:check_exp_synthesize(exp) "trying to take the length of a %s instead of an array or string", types.tostring(t)) end - exp._type = types.T.Integer() + exp._type = types.T.Integer elseif op == "-" then if t._tag ~= "types.T.Integer" and t._tag ~= "types.T.Float" then type_error(exp.loc, @@ -841,10 +841,10 @@ function Typechecker:check_exp_synthesize(exp) "trying to bitwise negate a %s instead of an integer", types.tostring(t)) end - exp._type = types.T.Integer() + exp._type = types.T.Integer elseif op == "not" then check_type_is_condition(exp.exp, "'not' operator") - exp._type = types.T.Boolean() + exp._type = types.T.Boolean else tagged_union.error(op) end @@ -868,7 +868,7 @@ function Typechecker:check_exp_synthesize(exp) "cannot compare %s and %s using %s", types.tostring(t1), types.tostring(t2), op) end - exp._type = types.T.Boolean() + exp._type = types.T.Boolean elseif op == "<" or op == ">" or op == "<=" or op == ">=" then if (t1._tag == "types.T.Integer" and t2._tag == "types.T.Integer") or @@ -886,7 +886,7 @@ function Typechecker:check_exp_synthesize(exp) "cannot compare %s and %s using %s", types.tostring(t1), types.tostring(t2), op) end - exp._type = types.T.Boolean() + exp._type = types.T.Boolean elseif op == "+" or op == "-" or op == "*" or op == "%" or op == "//" then if not is_numeric_type(t1) then @@ -903,11 +903,11 @@ function Typechecker:check_exp_synthesize(exp) if t1._tag == "types.T.Integer" and t2._tag == "types.T.Integer" then - exp._type = types.T.Integer() + exp._type = types.T.Integer else exp.lhs = self:coerce_numeric_exp_to_float(exp.lhs) exp.rhs = self:coerce_numeric_exp_to_float(exp.rhs) - exp._type = types.T.Float() + exp._type = types.T.Float end elseif op == "/" or op == "^" then @@ -924,7 +924,7 @@ function Typechecker:check_exp_synthesize(exp) exp.lhs = self:coerce_numeric_exp_to_float(exp.lhs) exp.rhs = self:coerce_numeric_exp_to_float(exp.rhs) - exp._type = types.T.Float() + exp._type = types.T.Float elseif op == ".." then -- The arguments to '..' must be a strings. We do not allow "any" because Pallene does @@ -935,7 +935,7 @@ function Typechecker:check_exp_synthesize(exp) if t2._tag ~= "types.T.String" then type_error(exp.loc, "cannot concatenate with %s value", types.tostring(t2)) end - exp._type = types.T.String() + exp._type = types.T.String elseif op == "and" or op == "or" then check_type_is_condition(exp.lhs, "first operand of '%s'", op) @@ -953,7 +953,7 @@ function Typechecker:check_exp_synthesize(exp) "right-hand side of bitwise expression is a %s instead of an integer", types.tostring(t2)) end - exp._type = types.T.Integer() + exp._type = types.T.Integer else tagged_union.error(op) @@ -987,7 +987,7 @@ function Typechecker:check_exp_synthesize(exp) elseif tag == "ast.Exp.ToFloat" then assert(exp.exp._type._tag == "types.T.Integer") - exp._type = types.T.Float() + exp._type = types.T.Float else tagged_union.error(tag)