Skip to content

Commit e402e27

Browse files
authored
Use referencial equality in traversal helper methods (#13895)
1 parent de4181d commit e402e27

File tree

4 files changed

+66
-65
lines changed

4 files changed

+66
-65
lines changed

crates/ruff_linter/src/rules/flake8_simplify/rules/needless_bool.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ pub(crate) fn needless_bool(checker: &mut Checker, stmt: &Stmt) {
144144
.semantic()
145145
.current_statement_parent()
146146
.and_then(|parent| traversal::suite(stmt, parent))
147-
.and_then(|suite| traversal::next_sibling(stmt, suite))
147+
.and_then(|suite| suite.next_sibling())
148148
else {
149149
return;
150150
};

crates/ruff_linter/src/rules/flake8_simplify/rules/reimplemented_builtin.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@ pub(crate) fn convert_for_loop_to_any_all(checker: &mut Checker, stmt: &Stmt) {
7272
// - `for` loop followed by `return True` or `return False`.
7373
let Some(terminal) = match_else_return(stmt).or_else(|| {
7474
let parent = checker.semantic().current_statement_parent()?;
75-
let suite = traversal::suite(stmt, parent)?;
76-
let sibling = traversal::next_sibling(stmt, suite)?;
75+
let sibling = traversal::suite(stmt, parent)?.next_sibling()?;
7776
match_sibling_return(stmt, sibling)
7877
}) else {
7978
return;

crates/ruff_linter/src/rules/refurb/rules/repeated_append.rs

+9-7
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use rustc_hash::FxHashMap;
33
use ast::traversal;
44
use ruff_diagnostics::{Diagnostic, Edit, Fix, FixAvailability, Violation};
55
use ruff_macros::{derive_message_formats, violation};
6+
use ruff_python_ast::traversal::EnclosingSuite;
67
use ruff_python_ast::{self as ast, Expr, Stmt};
78
use ruff_python_codegen::Generator;
89
use ruff_python_semantic::analyze::typing::is_list;
@@ -179,32 +180,33 @@ fn match_consecutive_appends<'a>(
179180

180181
// In order to match consecutive statements, we need to go to the tree ancestor of the
181182
// given statement, find its position there, and match all 'appends' from there.
182-
let siblings: &[Stmt] = if semantic.at_top_level() {
183+
let suite = if semantic.at_top_level() {
183184
// If the statement is at the top level, we should go to the parent module.
184185
// Module is available in the definitions list.
185-
semantic.definitions.python_ast()?
186+
EnclosingSuite::new(semantic.definitions.python_ast()?, stmt)?
186187
} else {
187188
// Otherwise, go to the parent, and take its body as a sequence of siblings.
188189
semantic
189190
.current_statement_parent()
190191
.and_then(|parent| traversal::suite(stmt, parent))?
191192
};
192193

193-
let stmt_index = siblings.iter().position(|sibling| sibling == stmt)?;
194-
195194
// We shouldn't repeat the same work for many 'appends' that go in a row. Let's check
196195
// that this statement is at the beginning of such a group.
197-
if stmt_index != 0 && match_append(semantic, &siblings[stmt_index - 1]).is_some() {
196+
if suite
197+
.previous_sibling()
198+
.is_some_and(|previous_stmt| match_append(semantic, previous_stmt).is_some())
199+
{
198200
return None;
199201
}
200202

201203
// Starting from the next statement, let's match all appends and make a vector.
202204
Some(
203205
std::iter::once(append)
204206
.chain(
205-
siblings
207+
suite
208+
.next_siblings()
206209
.iter()
207-
.skip(stmt_index + 1)
208210
.map_while(|sibling| match_append(semantic, sibling)),
209211
)
210212
.collect(),
+55-55
Original file line numberDiff line numberDiff line change
@@ -1,81 +1,81 @@
11
//! Utilities for manually traversing a Python AST.
2-
use crate::{self as ast, ExceptHandler, Stmt, Suite};
2+
use crate::{self as ast, AnyNodeRef, ExceptHandler, Stmt};
33

4-
/// Given a [`Stmt`] and its parent, return the [`Suite`] that contains the [`Stmt`].
5-
pub fn suite<'a>(stmt: &'a Stmt, parent: &'a Stmt) -> Option<&'a Suite> {
4+
/// Given a [`Stmt`] and its parent, return the [`ast::Suite`] that contains the [`Stmt`].
5+
pub fn suite<'a>(stmt: &'a Stmt, parent: &'a Stmt) -> Option<EnclosingSuite<'a>> {
66
// TODO: refactor this to work without a parent, ie when `stmt` is at the top level
77
match parent {
8-
Stmt::FunctionDef(ast::StmtFunctionDef { body, .. }) => Some(body),
9-
Stmt::ClassDef(ast::StmtClassDef { body, .. }) => Some(body),
10-
Stmt::For(ast::StmtFor { body, orelse, .. }) => {
11-
if body.contains(stmt) {
12-
Some(body)
13-
} else if orelse.contains(stmt) {
14-
Some(orelse)
15-
} else {
16-
None
17-
}
18-
}
19-
Stmt::While(ast::StmtWhile { body, orelse, .. }) => {
20-
if body.contains(stmt) {
21-
Some(body)
22-
} else if orelse.contains(stmt) {
23-
Some(orelse)
24-
} else {
25-
None
26-
}
27-
}
8+
Stmt::FunctionDef(ast::StmtFunctionDef { body, .. }) => EnclosingSuite::new(body, stmt),
9+
Stmt::ClassDef(ast::StmtClassDef { body, .. }) => EnclosingSuite::new(body, stmt),
10+
Stmt::For(ast::StmtFor { body, orelse, .. }) => [body, orelse]
11+
.iter()
12+
.find_map(|suite| EnclosingSuite::new(suite, stmt)),
13+
Stmt::While(ast::StmtWhile { body, orelse, .. }) => [body, orelse]
14+
.iter()
15+
.find_map(|suite| EnclosingSuite::new(suite, stmt)),
2816
Stmt::If(ast::StmtIf {
2917
body,
3018
elif_else_clauses,
3119
..
32-
}) => {
33-
if body.contains(stmt) {
34-
Some(body)
35-
} else {
36-
elif_else_clauses
37-
.iter()
38-
.map(|elif_else_clause| &elif_else_clause.body)
39-
.find(|body| body.contains(stmt))
40-
}
41-
}
42-
Stmt::With(ast::StmtWith { body, .. }) => Some(body),
20+
}) => [body]
21+
.into_iter()
22+
.chain(elif_else_clauses.iter().map(|clause| &clause.body))
23+
.find_map(|suite| EnclosingSuite::new(suite, stmt)),
24+
Stmt::With(ast::StmtWith { body, .. }) => EnclosingSuite::new(body, stmt),
4325
Stmt::Match(ast::StmtMatch { cases, .. }) => cases
4426
.iter()
4527
.map(|case| &case.body)
46-
.find(|body| body.contains(stmt)),
28+
.find_map(|body| EnclosingSuite::new(body, stmt)),
4729
Stmt::Try(ast::StmtTry {
4830
body,
4931
handlers,
5032
orelse,
5133
finalbody,
5234
..
53-
}) => {
54-
if body.contains(stmt) {
55-
Some(body)
56-
} else if orelse.contains(stmt) {
57-
Some(orelse)
58-
} else if finalbody.contains(stmt) {
59-
Some(finalbody)
60-
} else {
35+
}) => [body, orelse, finalbody]
36+
.into_iter()
37+
.chain(
6138
handlers
6239
.iter()
6340
.filter_map(ExceptHandler::as_except_handler)
64-
.map(|handler| &handler.body)
65-
.find(|body| body.contains(stmt))
66-
}
67-
}
41+
.map(|handler| &handler.body),
42+
)
43+
.find_map(|suite| EnclosingSuite::new(suite, stmt)),
6844
_ => None,
6945
}
7046
}
7147

72-
/// Given a [`Stmt`] and its containing [`Suite`], return the next [`Stmt`] in the [`Suite`].
73-
pub fn next_sibling<'a>(stmt: &'a Stmt, suite: &'a Suite) -> Option<&'a Stmt> {
74-
let mut iter = suite.iter();
75-
while let Some(sibling) = iter.next() {
76-
if sibling == stmt {
77-
return iter.next();
78-
}
48+
pub struct EnclosingSuite<'a> {
49+
suite: &'a [Stmt],
50+
position: usize,
51+
}
52+
53+
impl<'a> EnclosingSuite<'a> {
54+
pub fn new(suite: &'a [Stmt], stmt: &'a Stmt) -> Option<Self> {
55+
let position = suite
56+
.iter()
57+
.position(|sibling| AnyNodeRef::ptr_eq(sibling.into(), stmt.into()))?;
58+
59+
Some(EnclosingSuite { suite, position })
60+
}
61+
62+
pub fn next_sibling(&self) -> Option<&'a Stmt> {
63+
self.suite.get(self.position + 1)
64+
}
65+
66+
pub fn next_siblings(&self) -> &'a [Stmt] {
67+
self.suite.get(self.position + 1..).unwrap_or_default()
68+
}
69+
70+
pub fn previous_sibling(&self) -> Option<&'a Stmt> {
71+
self.suite.get(self.position.checked_sub(1)?)
72+
}
73+
}
74+
75+
impl std::ops::Deref for EnclosingSuite<'_> {
76+
type Target = [Stmt];
77+
78+
fn deref(&self) -> &Self::Target {
79+
self.suite
7980
}
80-
None
8181
}

0 commit comments

Comments
 (0)