Skip to content

Commit cf52028

Browse files
fix
1 parent 071dc8a commit cf52028

File tree

4 files changed

+195
-17
lines changed

4 files changed

+195
-17
lines changed

pyrefly/lib/binding/function.rs

Lines changed: 178 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use ruff_python_ast::ExceptHandler;
2222
use ruff_python_ast::Expr;
2323
use ruff_python_ast::ExprCall;
2424
use ruff_python_ast::Identifier;
25+
use ruff_python_ast::Parameter;
2526
use ruff_python_ast::Parameters;
2627
use ruff_python_ast::Stmt;
2728
use ruff_python_ast::StmtExpr;
@@ -32,6 +33,7 @@ use ruff_python_ast::name::Name;
3233
use ruff_text_size::Ranged;
3334
use ruff_text_size::TextRange;
3435
use starlark_map::small_map::SmallMap;
36+
use starlark_map::small_set::SmallSet;
3537

3638
use crate::binding::binding::AnnotationTarget;
3739
use crate::binding::binding::Binding;
@@ -94,6 +96,9 @@ fn is_annotated<T>(returns: &Option<T>, params: &Parameters) -> bool {
9496
}
9597
false
9698
}
99+
100+
type ConstrainedTypeVarParams = SmallMap<Name, SmallSet<Name>>;
101+
97102
struct SelfAttrNames<'a> {
98103
self_name: &'a Name,
99104
names: SmallMap<Name, TextRange>,
@@ -349,8 +354,23 @@ impl<'a> BindingsBuilder<'a> {
349354
/// This function must not be called unless the function body statements will be bound;
350355
/// it relies on that binding to ensure we don't have a dangling `Idx<Key>` (which could lead
351356
/// to a panic).
352-
fn implicit_return(&mut self, body: &[Stmt], func_name: &Identifier) -> Idx<Key> {
353-
let last_exprs = function_last_expressions(body, self.sys_info).map(|x| {
357+
fn implicit_return(
358+
&mut self,
359+
body: &[Stmt],
360+
func_name: &Identifier,
361+
parameters: &Parameters,
362+
) -> Idx<Key> {
363+
let constrained_typevars = self.constrained_typevar_params(parameters);
364+
let last_exprs = function_last_expressions(
365+
body,
366+
self.sys_info,
367+
if constrained_typevars.is_empty() {
368+
None
369+
} else {
370+
Some(&constrained_typevars)
371+
},
372+
)
373+
.map(|x| {
354374
x.into_map(|(last, x)| (last, self.last_statement_idx_for_implicit_return(last, x)))
355375
.into_boxed_slice()
356376
});
@@ -360,6 +380,67 @@ impl<'a> BindingsBuilder<'a> {
360380
)
361381
}
362382

383+
fn constrained_typevar_params(&self, parameters: &Parameters) -> ConstrainedTypeVarParams {
384+
let mut constrained = ConstrainedTypeVarParams::new();
385+
let mut consider_param = |param: &Parameter| {
386+
let Some(annotation) = param.annotation() else {
387+
return;
388+
};
389+
let Expr::Name(typevar_name) = annotation else {
390+
return;
391+
};
392+
let Some(constraints) = self.typevar_constraints_for_name(&typevar_name.id) else {
393+
return;
394+
};
395+
constrained.insert(param.name().id.clone(), constraints);
396+
};
397+
for param in parameters.iter_non_variadic_params() {
398+
consider_param(&param.parameter);
399+
}
400+
if let Some(vararg) = &parameters.vararg {
401+
consider_param(vararg);
402+
}
403+
if let Some(kwarg) = &parameters.kwarg {
404+
consider_param(kwarg);
405+
}
406+
constrained
407+
}
408+
409+
fn typevar_constraints_for_name(&self, name: &Name) -> Option<SmallSet<Name>> {
410+
let (idx, _) = self.scopes.binding_idx_for_name(name)?;
411+
let (_, binding) = self.get_original_binding(idx)?;
412+
let binding = binding?;
413+
let Binding::TypeVar(_, _, call) = binding else {
414+
return None;
415+
};
416+
Self::typevar_constraints_from_call(call)
417+
}
418+
419+
fn typevar_constraints_from_call(call: &ExprCall) -> Option<SmallSet<Name>> {
420+
if call
421+
.arguments
422+
.keywords
423+
.iter()
424+
.any(|kw| kw.arg.as_ref().is_some_and(|id| id.id == "bound"))
425+
{
426+
return None;
427+
}
428+
let mut args = call.arguments.args.iter();
429+
let _name_arg = args.next()?;
430+
let mut constraints = SmallSet::new();
431+
for arg in args {
432+
let Expr::Name(name) = arg else {
433+
return None;
434+
};
435+
constraints.insert(name.id.clone());
436+
}
437+
if constraints.len() < 2 {
438+
None
439+
} else {
440+
Some(constraints)
441+
}
442+
}
443+
363444
/// Handles both checking yield / return expressions and binding the return type.
364445
fn analyze_return_type(
365446
&mut self,
@@ -617,7 +698,7 @@ impl<'a> BindingsBuilder<'a> {
617698
self_assignments
618699
}
619700
UntypedDefBehavior::CheckAndInferReturnAny => {
620-
let implicit_return = self.implicit_return(&body, func_name);
701+
let implicit_return = self.implicit_return(&body, func_name, parameters);
621702
let (yields_and_returns, self_assignments, unused_parameters, unused_variables) =
622703
self.function_body_scope(
623704
parameters,
@@ -647,7 +728,7 @@ impl<'a> BindingsBuilder<'a> {
647728
self_assignments
648729
}
649730
UntypedDefBehavior::CheckAndInferReturnType => {
650-
let implicit_return = self.implicit_return(&body, func_name);
731+
let implicit_return = self.implicit_return(&body, func_name, parameters);
651732
let (yields_and_returns, self_assignments, unused_parameters, unused_variables) =
652733
self.function_body_scope(
653734
parameters,
@@ -753,6 +834,47 @@ impl<'a> BindingsBuilder<'a> {
753834
}
754835
}
755836

837+
fn extract_isinstance_test(test: &Expr) -> Option<(Name, SmallSet<Name>)> {
838+
let Expr::Call(ExprCall {
839+
func, arguments, ..
840+
}) = test
841+
else {
842+
return None;
843+
};
844+
let Expr::Name(func_name) = func.as_ref() else {
845+
return None;
846+
};
847+
if func_name.id.as_str() != "isinstance" {
848+
return None;
849+
}
850+
if !arguments.keywords.is_empty() || arguments.args.len() != 2 {
851+
return None;
852+
}
853+
let Expr::Name(subject) = &arguments.args[0] else {
854+
return None;
855+
};
856+
let mut types = SmallSet::new();
857+
match &arguments.args[1] {
858+
Expr::Name(name) => {
859+
types.insert(name.id.clone());
860+
}
861+
Expr::Tuple(tuple) => {
862+
for elt in &tuple.elts {
863+
let Expr::Name(name) = elt else {
864+
return None;
865+
};
866+
types.insert(name.id.clone());
867+
}
868+
}
869+
_ => return None,
870+
}
871+
if types.is_empty() {
872+
None
873+
} else {
874+
Some((subject.id.clone(), types))
875+
}
876+
}
877+
756878
/// Given the body of a function, what are the potential expressions that
757879
/// could be the last ones to be executed, where the function then falls off the end.
758880
///
@@ -762,8 +884,14 @@ impl<'a> BindingsBuilder<'a> {
762884
fn function_last_expressions<'a>(
763885
x: &'a [Stmt],
764886
sys_info: &SysInfo,
887+
constrained_typevars: Option<&ConstrainedTypeVarParams>,
765888
) -> Option<Vec<(LastStmt, &'a Expr)>> {
766-
fn f<'a>(sys_info: &SysInfo, x: &'a [Stmt], res: &mut Vec<(LastStmt, &'a Expr)>) -> Option<()> {
889+
fn f<'a>(
890+
sys_info: &SysInfo,
891+
x: &'a [Stmt],
892+
res: &mut Vec<(LastStmt, &'a Expr)>,
893+
constrained_typevars: Option<&ConstrainedTypeVarParams>,
894+
) -> Option<()> {
767895
match x.last()? {
768896
Stmt::Expr(x) => res.push((LastStmt::Expr, &x.value)),
769897
Stmt::Return(_) | Stmt::Raise(_) => {}
@@ -773,7 +901,7 @@ fn function_last_expressions<'a>(
773901
for y in &x.items {
774902
res.push((LastStmt::With(kind), &y.context_expr));
775903
}
776-
f(sys_info, &x.body, res)?;
904+
f(sys_info, &x.body, res, constrained_typevars)?;
777905
}
778906
Stmt::While(x) => {
779907
// Infinite loops with no breaks cannot fall through
@@ -813,13 +941,46 @@ fn function_last_expressions<'a>(
813941
}
814942
Stmt::If(x) => {
815943
let mut last_test = None;
944+
let mut chain_var = None;
945+
let mut covered = SmallSet::new();
946+
let mut chain_valid = constrained_typevars.is_some();
816947
for (test, body) in sys_info.pruned_if_branches(x) {
817948
last_test = test;
818-
f(sys_info, body, res)?;
949+
if let (Some(test), true) = (test, chain_valid) {
950+
// Special-case isinstance chains over constrained TypeVars to avoid
951+
// false missing-return errors when all constraints are covered.
952+
if let Some((var, types)) = extract_isinstance_test(test) {
953+
if let Some(existing) = &chain_var {
954+
if existing != &var {
955+
chain_valid = false;
956+
}
957+
} else {
958+
chain_var = Some(var);
959+
}
960+
if chain_valid {
961+
for ty in types {
962+
covered.insert(ty);
963+
}
964+
}
965+
} else {
966+
chain_valid = false;
967+
}
968+
}
969+
f(sys_info, body, res, constrained_typevars)?;
819970
}
820971
if last_test.is_some() {
821-
// The final `if` can fall through, so the `if` itself might be the last statement.
822-
return None;
972+
let mut exhaustive = false;
973+
if chain_valid
974+
&& let (Some(var), Some(constrained_typevars)) =
975+
(&chain_var, constrained_typevars)
976+
&& let Some(constraints) = constrained_typevars.get(var)
977+
{
978+
exhaustive = constraints.iter().all(|c| covered.contains(c));
979+
}
980+
if !exhaustive {
981+
// The final `if` can fall through, so the `if` itself might be the last statement.
982+
return None;
983+
}
823984
}
824985
}
825986
Stmt::Try(x) => {
@@ -830,16 +991,18 @@ fn function_last_expressions<'a>(
830991
.iter()
831992
.any(|stmt| matches!(stmt, Stmt::Return(_)))
832993
{
833-
f(sys_info, &x.finalbody, res)?;
994+
f(sys_info, &x.finalbody, res, constrained_typevars)?;
834995
} else {
835996
if x.orelse.is_empty() {
836-
f(sys_info, &x.body, res)?;
997+
f(sys_info, &x.body, res, constrained_typevars)?;
837998
} else {
838-
f(sys_info, &x.orelse, res)?;
999+
f(sys_info, &x.orelse, res, constrained_typevars)?;
8391000
}
8401001
for handler in &x.handlers {
8411002
match handler {
842-
ExceptHandler::ExceptHandler(x) => f(sys_info, &x.body, res)?,
1003+
ExceptHandler::ExceptHandler(x) => {
1004+
f(sys_info, &x.body, res, constrained_typevars)?
1005+
}
8431006
}
8441007
}
8451008
// If we don't have a matching handler, we raise an exception, which is fine.
@@ -848,7 +1011,7 @@ fn function_last_expressions<'a>(
8481011
Stmt::Match(x) => {
8491012
let mut exhaustive = false;
8501013
for case in x.cases.iter() {
851-
f(sys_info, &case.body, res)?;
1014+
f(sys_info, &case.body, res, constrained_typevars)?;
8521015
if case.pattern.is_wildcard() || case.pattern.is_irrefutable() {
8531016
exhaustive = true;
8541017
break;
@@ -864,7 +1027,7 @@ fn function_last_expressions<'a>(
8641027
}
8651028

8661029
let mut res = Vec::new();
867-
f(sys_info, x, &mut res)?;
1030+
f(sys_info, x, &mut res, constrained_typevars)?;
8681031
Some(res)
8691032
}
8701033

pyrefly/lib/solver/subset.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1000,7 +1000,7 @@ impl<'a, Ans: LookupAnswer> Subset<'a, Ans> {
10001000
}
10011001
(t1, Type::Quantified(q)) => match q.restriction() {
10021002
// This only works for constraints and not bounds, because a TypeVar must resolve to exactly one of its constraints.
1003-
Restriction::Constraints(constraints) => all(constraints.iter(), |constraint| {
1003+
Restriction::Constraints(constraints) => any(constraints.iter(), |constraint| {
10041004
self.is_subset_eq(t1, constraint)
10051005
}),
10061006
_ => Err(SubsetError::Other),

pyrefly/lib/test/generic_restrictions.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,7 @@ testcase!(
523523
test_add_with_constraints,
524524
r#"
525525
def add[T: (int, str)](x: T, y: T) -> T:
526-
return x + y # E: `+` is not supported # E: `+` is not supported # E: `int | Unknown` is not assignable to declared return type `T`
526+
return x + y # E: `+` is not supported # E: `+` is not supported
527527
"#,
528528
);
529529

pyrefly/lib/test/returns.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,21 @@ def f(b: bool) -> int: # E: Function declared to return `int`, but one or more
136136
"#,
137137
);
138138

139+
testcase!(
140+
test_return_constrained_typevar_isinstance,
141+
r#"
142+
from typing import TypeVar
143+
144+
T = TypeVar("T", int, str)
145+
146+
def f(x: T) -> T:
147+
if isinstance(x, int):
148+
return x
149+
elif isinstance(x, str):
150+
return x
151+
"#,
152+
);
153+
139154
testcase!(
140155
test_return_if_no_else_none,
141156
r#"

0 commit comments

Comments
 (0)