Skip to content

Commit 2c57c2d

Browse files
sharkdp and AlexWaygood authored
[red-knot] Type narrowing for isinstance checks (#13894)
## Summary

Add type narrowing for `isinstance(object, classinfo)` [1] checks:

```py
x = 1 if flag else "a"

if isinstance(x, int):
    reveal_type(x)  # revealed: Literal[1]
```

closes #13893

[1] https://docs.python.org/3/library/functions.html#isinstance

## Test Plan

New Markdown-based tests in `narrow/isinstance.md`.

---------

Co-authored-by: Alex Waygood <[email protected]>
1 parent 72c18c8 commit 2c57c2d

File tree

5 files changed

+241
-12
lines changed

5 files changed

+241
-12
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# Narrowing for `isinstance` checks
2+
3+
Narrowing for `isinstance(object, classinfo)` expressions.
4+
5+
## `classinfo` is a single type
6+
7+
```py
8+
x = 1 if flag else "a"
9+
10+
if isinstance(x, int):
11+
reveal_type(x) # revealed: Literal[1]
12+
13+
if isinstance(x, str):
14+
reveal_type(x) # revealed: Literal["a"]
15+
if isinstance(x, int):
16+
reveal_type(x) # revealed: Never
17+
18+
if isinstance(x, (int, object)):
19+
reveal_type(x) # revealed: Literal[1] | Literal["a"]
20+
```
21+
22+
## `classinfo` is a tuple of types
23+
24+
Note: `isinstance(x, (int, str))` should not be confused with
25+
`isinstance(x, tuple[(int, str)])`. The former is equivalent to
26+
`isinstance(x, int | str)`:
27+
28+
```py
29+
x = 1 if flag else "a"
30+
31+
if isinstance(x, (int, str)):
32+
reveal_type(x) # revealed: Literal[1] | Literal["a"]
33+
34+
if isinstance(x, (int, bytes)):
35+
reveal_type(x) # revealed: Literal[1]
36+
37+
if isinstance(x, (bytes, str)):
38+
reveal_type(x) # revealed: Literal["a"]
39+
40+
# No narrowing should occur if a larger type is also
41+
# one of the possibilities:
42+
if isinstance(x, (int, object)):
43+
reveal_type(x) # revealed: Literal[1] | Literal["a"]
44+
45+
y = 1 if flag1 else "a" if flag2 else b"b"
46+
if isinstance(y, (int, str)):
47+
reveal_type(y) # revealed: Literal[1] | Literal["a"]
48+
49+
if isinstance(y, (int, bytes)):
50+
reveal_type(y) # revealed: Literal[1] | Literal[b"b"]
51+
52+
if isinstance(y, (str, bytes)):
53+
reveal_type(y) # revealed: Literal["a"] | Literal[b"b"]
54+
```
55+
56+
## `classinfo` is a nested tuple of types
57+
58+
```py
59+
x = 1 if flag else "a"
60+
61+
if isinstance(x, (bool, (bytes, int))):
62+
reveal_type(x) # revealed: Literal[1]
63+
```
64+
65+
## Class types
66+
67+
```py
68+
class A: ...
69+
70+
71+
class B: ...
72+
73+
74+
def get_object() -> object: ...
75+
76+
77+
x = get_object()
78+
79+
if isinstance(x, A):
80+
reveal_type(x) # revealed: A
81+
if isinstance(x, B):
82+
reveal_type(x) # revealed: A & B
83+
```
84+
85+
## No narrowing for instances of `builtins.type`
86+
87+
```py
88+
t = type("t", (), {})
89+
90+
# This isn't testing what we want it to test if we infer anything more precise here:
91+
reveal_type(t) # revealed: type
92+
x = 1 if flag else "foo"
93+
94+
if isinstance(x, t):
95+
reveal_type(x) # revealed: Literal[1] | Literal["foo"]
96+
```
97+
98+
## Do not use custom `isinstance` for narrowing
99+
100+
```py
101+
def isinstance(x, t):
102+
return True
103+
104+
105+
x = 1 if flag else "a"
106+
if isinstance(x, int):
107+
reveal_type(x) # revealed: Literal[1] | Literal["a"]
108+
```
109+
110+
## Do support narrowing if `isinstance` is aliased
111+
112+
```py
113+
isinstance_alias = isinstance
114+
115+
x = 1 if flag else "a"
116+
if isinstance_alias(x, int):
117+
reveal_type(x) # revealed: Literal[1]
118+
```
119+
120+
## Do support narrowing if `isinstance` is imported
121+
122+
```py
123+
from builtins import isinstance as imported_isinstance
124+
125+
x = 1 if flag else "a"
126+
if imported_isinstance(x, int):
127+
reveal_type(x) # revealed: Literal[1]
128+
```
129+
130+
## Do not narrow if second argument is not a type
131+
132+
```py
133+
x = 1 if flag else "a"
134+
135+
# TODO: this should cause us to emit a diagnostic during
136+
# type checking
137+
if isinstance(x, "a"):
138+
reveal_type(x) # revealed: Literal[1] | Literal["a"]
139+
140+
# TODO: this should cause us to emit a diagnostic during
141+
# type checking
142+
if isinstance(x, "int"):
143+
reveal_type(x) # revealed: Literal[1] | Literal["a"]
144+
```
145+
146+
## Do not narrow if there are keyword arguments
147+
148+
```py
149+
x = 1 if flag else "a"
150+
151+
# TODO: this should cause us to emit a diagnostic
152+
# (`isinstance` has no `foo` parameter)
153+
if isinstance(x, int, foo="bar"):
154+
reveal_type(x) # revealed: Literal[1] | Literal["a"]
155+
```

crates/red_knot_python_semantic/src/semantic_index/definition.rs

+7-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,13 @@ impl<'db> Definition<'db> {
4747
self.kind(db).category().is_binding()
4848
}
4949

50-
/// Return true if this is a symbol was defined in the `typing` or `typing_extensions` modules
50+
pub(crate) fn is_builtin_definition(self, db: &'db dyn Db) -> bool {
51+
file_to_module(db, self.file(db)).is_some_and(|module| {
52+
module.search_path().is_standard_library() && matches!(&**module.name(), "builtins")
53+
})
54+
}
55+
56+
/// Return true if this symbol was defined in the `typing` or `typing_extensions` modules
5157
pub(crate) fn is_typing_definition(self, db: &'db dyn Db) -> bool {
5258
file_to_module(db, self.file(db)).is_some_and(|module| {
5359
module.search_path().is_standard_library()

crates/red_knot_python_semantic/src/types.rs

+16-7
Original file line numberDiff line numberDiff line change
@@ -868,13 +868,16 @@ impl<'db> Type<'db> {
868868
fn call(self, db: &'db dyn Db, arg_types: &[Type<'db>]) -> CallOutcome<'db> {
869869
match self {
870870
// TODO validate typed call arguments vs callable signature
871-
Type::FunctionLiteral(function_type) => match function_type.known(db) {
872-
None => CallOutcome::callable(function_type.return_type(db)),
873-
Some(KnownFunction::RevealType) => CallOutcome::revealed(
874-
function_type.return_type(db),
875-
*arg_types.first().unwrap_or(&Type::Unknown),
876-
),
877-
},
871+
Type::FunctionLiteral(function_type) => {
872+
if function_type.is_known(db, KnownFunction::RevealType) {
873+
CallOutcome::revealed(
874+
function_type.return_type(db),
875+
*arg_types.first().unwrap_or(&Type::Unknown),
876+
)
877+
} else {
878+
CallOutcome::callable(function_type.return_type(db))
879+
}
880+
}
878881

879882
// TODO annotated return type on `__new__` or metaclass `__call__`
880883
Type::ClassLiteral(class) => {
@@ -1595,6 +1598,10 @@ impl<'db> FunctionType<'db> {
15951598
})
15961599
.unwrap_or(Type::Unknown)
15971600
}
1601+
1602+
/// Return true if this function has been recognised as the given `known_function`.
///
/// Functions that are not known at all (`self.known(db)` is `None`) never match.
pub fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool {
    self.known(db)
        .is_some_and(|known| known == known_function)
}
15981605
}
15991606

16001607
/// Non-exhaustive enumeration of known functions (e.g. `builtins.reveal_type`, ...) that might
@@ -1603,6 +1610,8 @@ impl<'db> FunctionType<'db> {
16031610
pub enum KnownFunction {
16041611
/// `builtins.reveal_type`, `typing.reveal_type` or `typing_extensions.reveal_type`
16051612
RevealType,
1613+
/// `builtins.isinstance`
1614+
IsInstance,
16061615
}
16071616

16081617
#[salsa::interned]

crates/red_knot_python_semantic/src/types/infer.rs

+3
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,9 @@ impl<'db> TypeInferenceBuilder<'db> {
779779
"reveal_type" if definition.is_typing_definition(self.db) => {
780780
Some(KnownFunction::RevealType)
781781
}
782+
"isinstance" if definition.is_builtin_definition(self.db) => {
783+
Some(KnownFunction::IsInstance)
784+
}
782785
_ => None,
783786
};
784787
let function_ty = Type::FunctionLiteral(FunctionType::new(

crates/red_knot_python_semantic/src/types/narrow.rs

+60-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ use crate::semantic_index::definition::Definition;
44
use crate::semantic_index::expression::Expression;
55
use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable};
66
use crate::semantic_index::symbol_table;
7-
use crate::types::{infer_expression_types, IntersectionBuilder, Type};
7+
use crate::types::{
8+
infer_expression_types, IntersectionBuilder, KnownFunction, Type, UnionBuilder,
9+
};
810
use crate::Db;
911
use itertools::Itertools;
1012
use ruff_python_ast as ast;
@@ -60,6 +62,28 @@ fn all_narrowing_constraints_for_expression<'db>(
6062
NarrowingConstraintsBuilder::new(db, Constraint::Expression(expression)).finish()
6163
}
6264

65+
/// Generate a constraint from the *type* of the second argument of an `isinstance` call.
66+
///
67+
/// Example: for `isinstance(…, str)`, we would infer `Type::ClassLiteral(str)` from the
68+
/// second argument, but we need to generate a `Type::Instance(str)` constraint that can
69+
/// be used to narrow down the type of the first argument.
70+
fn generate_isinstance_constraint<'db>(
71+
db: &'db dyn Db,
72+
classinfo: &Type<'db>,
73+
) -> Option<Type<'db>> {
74+
match classinfo {
75+
Type::ClassLiteral(class) => Some(Type::Instance(*class)),
76+
Type::Tuple(tuple) => {
77+
let mut builder = UnionBuilder::new(db);
78+
for element in tuple.elements(db) {
79+
builder = builder.add(generate_isinstance_constraint(db, element)?);
80+
}
81+
Some(builder.build())
82+
}
83+
_ => None,
84+
}
85+
}
86+
6387
type NarrowingConstraints<'db> = FxHashMap<ScopedSymbolId, Type<'db>>;
6488

6589
struct NarrowingConstraintsBuilder<'db> {
@@ -88,10 +112,15 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
88112
}
89113

90114
fn evaluate_expression_constraint(&mut self, expression: Expression<'db>) {
91-
if let ast::Expr::Compare(expr_compare) = expression.node_ref(self.db).node() {
92-
self.add_expr_compare(expr_compare, expression);
115+
match expression.node_ref(self.db).node() {
116+
ast::Expr::Compare(expr_compare) => {
117+
self.add_expr_compare(expr_compare, expression);
118+
}
119+
ast::Expr::Call(expr_call) => {
120+
self.add_expr_call(expr_call, expression);
121+
}
122+
_ => {} // TODO other test expression kinds
93123
}
94-
// TODO other test expression kinds
95124
}
96125

97126
fn evaluate_pattern_constraint(&mut self, pattern: PatternConstraint<'db>) {
@@ -194,6 +223,33 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
194223
}
195224
}
196225

226+
fn add_expr_call(&mut self, expr_call: &ast::ExprCall, expression: Expression<'db>) {
227+
let scope = self.scope();
228+
let inference = infer_expression_types(self.db, expression);
229+
230+
if let Some(func_type) = inference
231+
.expression_ty(expr_call.func.scoped_ast_id(self.db, scope))
232+
.into_function_literal_type()
233+
{
234+
if func_type.is_known(self.db, KnownFunction::IsInstance)
235+
&& expr_call.arguments.keywords.is_empty()
236+
{
237+
if let [ast::Expr::Name(ast::ExprName { id, .. }), rhs] = &*expr_call.arguments.args
238+
{
239+
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
240+
241+
let rhs_type = inference.expression_ty(rhs.scoped_ast_id(self.db, scope));
242+
243+
// TODO: add support for PEP 604 union types on the right hand side:
244+
// isinstance(x, str | (int | float))
245+
if let Some(constraint) = generate_isinstance_constraint(self.db, &rhs_type) {
246+
self.constraints.insert(symbol, constraint);
247+
}
248+
}
249+
}
250+
}
251+
}
252+
197253
fn add_match_pattern_singleton(
198254
&mut self,
199255
subject: &ast::Expr,

0 commit comments

Comments
 (0)