Skip to content

Commit

Permalink
fix subtype check of generic record against generic interface
Browse files Browse the repository at this point in the history
Fixes #859.
  • Loading branch information
hishamhm committed Jan 12, 2025
1 parent 91fe47a commit b2176c4
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 30 deletions.
59 changes: 59 additions & 0 deletions spec/lang/subtyping/interface_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,65 @@ describe("subtyping of interfaces:", function()
f(MyRecord)
]]))

it("record <: interface (regression test for #859, example without generics)", util.check([[
local interface IFoo
get_value: function(self): integer
end
local record Foo is IFoo
_value: integer
end
function Foo:get_value():integer
return self._value
end
function Foo.new(value: integer):Foo
local fields = { _value = value }
return setmetatable(fields, { __index = Foo })
end
local function create_foo(value: integer):IFoo
local foo = Foo.new(value)
return foo
end
local foo = create_foo(5)
print(foo:get_value())
]]))

it("generic record <: generic interface (regression test for #859)", util.check([[
local interface IFoo<T>
get_value: function(self): T
end
local record Foo<T> is IFoo<T>
_value: T
end
function Foo:get_value():T
return self._value
end
function Foo.new(value: T):Foo<T>
local fields = { _value = value }
return setmetatable(fields, { __index = Foo })
end
local function create_foo<T>(value: T):IFoo<T>
local foo = Foo.new(value)
return foo
-- Have to do this instead for now:
-- return foo as IFoo<T>
end
------------------------
local foo = create_foo(5)
print(foo:get_value())
]]))

it("regression test for #830", util.check_lines([[
local interface IFoo
end
Expand Down
50 changes: 35 additions & 15 deletions tl.lua
Original file line number Diff line number Diff line change
Expand Up @@ -8636,19 +8636,18 @@ do
return true
end

local function find_in_interface_list(a, f)
if not a.interface_list then
return nil
function TypeChecker:in_interface_list(r, iface)
if not r.interface_list then
return false
end

for _, t in ipairs(a.interface_list) do
local ret = f(t)
if ret then
return ret
for _, t in ipairs(r.interface_list) do
if self:is_a(t, iface) then
return true
end
end

return nil
return false
end

function TypeChecker:subtype_record(a, b)
Expand Down Expand Up @@ -8830,6 +8829,23 @@ do
end
end

local function a_is_interface_b(self, a, b)
assert(a.found)
assert(b.found)
local af = a.found.def
if af.typename == "generic" then
af = self:apply_generic(a, af, a.typevals)
end

if af.fields then
if self:in_interface_list(af, b) then
return true
end
end

return self:is_a(a, self:resolve_nominal(b))
end


local emptytable_relations = {
["emptytable"] = compare_true,
Expand Down Expand Up @@ -9056,7 +9072,7 @@ do


if rb.typename == "interface" then
return self:is_a(a, rb)
return a_is_interface_b(self, a, b)
end


Expand Down Expand Up @@ -9094,7 +9110,7 @@ do
},
["interface"] = {
["interface"] = function(self, a, b)
if find_in_interface_list(a, function(t) return (self:is_a(t, b)) end) then
if self:in_interface_list(a, b) then
return true
end
return self:same_type(a, b)
Expand Down Expand Up @@ -9152,7 +9168,7 @@ a.types[i], b.types[i]), }
["record"] = {
["record"] = TypeChecker.subtype_record,
["interface"] = function(self, a, b)
if find_in_interface_list(a, function(t) return (self:is_a(t, b)) end) then
if self:in_interface_list(a, b) then
return true
end
if not a.declname then
Expand Down Expand Up @@ -12924,10 +12940,14 @@ self:expand_type(node, values, elements) })
local t = tn and a_type(node, tn, {})

if not t and ra.fields then
t = find_in_interface_list(ra, function(ty)
local tname = types_op[ty.typename]
return tname and a_type(node, tname, {})
end)
if ra.interface_list then
for _, it in ipairs(ra.interface_list) do
if types_op[it.typename] then
t = a_type(node, types_op[it.typename], {})
break
end
end
end
end

local meta_on_operator
Expand Down
50 changes: 35 additions & 15 deletions tl.tl
Original file line number Diff line number Diff line change
Expand Up @@ -8636,19 +8636,18 @@ do
return true
end

local function find_in_interface_list<T>(a: RecordLikeType, f: function(Type): T): T
if not a.interface_list then
return nil
function TypeChecker:in_interface_list(r: RecordLikeType, iface: Type): boolean
if not r.interface_list then
return false
end

for _, t in ipairs(a.interface_list) do
local ret = f(t)
if ret then
return ret
for _, t in ipairs(r.interface_list) do
if self:is_a(t, iface) then
return true
end
end

return nil
return false
end

function TypeChecker:subtype_record(a: RecordLikeType, b: RecordLikeType): boolean, {Error}
Expand Down Expand Up @@ -8830,6 +8829,23 @@ do
end
end

local function a_is_interface_b(self: TypeChecker, a: NominalType, b: NominalType): boolean, {Error}
assert(a.found)
assert(b.found)
local af = a.found.def
if af is GenericType then
af = self:apply_generic(a, af, a.typevals)
end

if af is RecordLikeType then
if self:in_interface_list(af, b) then
return true
end
end

return self:is_a(a, self:resolve_nominal(b))
end

-- emptytable rules are the same in eqtype_relations and subtype_relations
local emptytable_relations: {TypeName:CompareTypes} = {
["emptytable"] = compare_true,
Expand Down Expand Up @@ -9056,7 +9072,7 @@ do

-- match interface subtyping
if rb is InterfaceType then
return self:is_a(a, rb)
return a_is_interface_b(self, a, b)
end

-- all other types nominally
Expand Down Expand Up @@ -9094,7 +9110,7 @@ do
},
["interface"] = {
["interface"] = function(self: TypeChecker, a: InterfaceType, b: InterfaceType): boolean, {Error}
if find_in_interface_list(a, function(t: Type): boolean return (self:is_a(t, b)) end) then
if self:in_interface_list(a, b) then
return true
end
return self:same_type(a, b)
Expand Down Expand Up @@ -9152,7 +9168,7 @@ do
["record"] = {
["record"] = TypeChecker.subtype_record,
["interface"] = function(self: TypeChecker, a: RecordType, b: InterfaceType): boolean, {Error}
if find_in_interface_list(a, function(t: Type): boolean return (self:is_a(t, b)) end) then
if self:in_interface_list(a, b) then
return true
end
if not a.declname then
Expand Down Expand Up @@ -12924,10 +12940,14 @@ do
local t = tn and a_type(node, tn, {})

if not t and ra is RecordLikeType then
t = find_in_interface_list(ra, function(ty: Type): Type
local tname = types_op[ty.typename]
return tname and a_type(node, tname, {})
end)
if ra.interface_list then
for _, it in ipairs(ra.interface_list) do
if types_op[it.typename] then
t = a_type(node, types_op[it.typename], {})
break
end
end
end
end

local meta_on_operator: integer
Expand Down

0 comments on commit b2176c4

Please sign in to comment.