Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion pyrefly/lib/lsp/wasm/semantic_tokens.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,18 @@
* LICENSE file in the root directory of this source tree.
*/

use std::collections::HashMap;

use lsp_types::SemanticToken;
use pyrefly_build::handle::Handle;
use pyrefly_python::module_name::ModuleName;
use pyrefly_python::short_identifier::ShortIdentifier;
use pyrefly_python::symbol_kind::SymbolKind;
use pyrefly_util::visit::Visit;
use ruff_python_ast::Stmt;
use ruff_text_size::TextRange;

use crate::binding::binding::Binding;
use crate::binding::binding::Key;
use crate::binding::bindings::Bindings;
use crate::export::exports::Export;
Expand Down Expand Up @@ -70,6 +78,7 @@ impl Transaction<'_> {
let legends = SemanticTokensLegends::new();
let disabled_ranges = disabled_ranges_for_module(ast.as_ref(), handle.sys_info());
let mut builder = SemanticTokenBuilder::new(limit_range, disabled_ranges);
let mut symbol_kinds: HashMap<ShortIdentifier, (ModuleName, SymbolKind)> = HashMap::new();
for NamedBinding {
definition_handle,
definition_export,
Expand All @@ -81,9 +90,20 @@ impl Transaction<'_> {
..
} = definition_export
{
builder.process_key(&key, definition_handle.module(), symbol_kind)
let binding = bindings.get(bindings.key_to_idx(&key));
let definition_module = match binding {
Binding::Import(module, _, _) | Binding::Module(module, ..) => *module,
_ => definition_handle.module(),
};
if let Key::Definition(short) = &key {
symbol_kinds.insert(short.clone(), (definition_module, symbol_kind));
}
builder.process_key(&key, definition_module, symbol_kind);
}
}
for stmt in &ast.body {
add_import_from_alias_tokens(&mut builder, stmt, &symbol_kinds);
}
builder.process_ast(&ast, &|range| self.get_type_trace(handle, range));
Some(legends.convert_tokens_into_lsp_semantic_tokens(
&builder.all_tokens_sorted(),
Expand All @@ -92,3 +112,21 @@ impl Transaction<'_> {
))
}
}

fn add_import_from_alias_tokens(
builder: &mut SemanticTokenBuilder,
stmt: &Stmt,
symbol_kinds: &HashMap<ShortIdentifier, (ModuleName, SymbolKind)>,
) {
if let Stmt::ImportFrom(import_from) = stmt {
for alias in &import_from.names {
if let Some(asname) = &alias.asname {
let key = ShortIdentifier::new(asname);
if let Some((definition_module, symbol_kind)) = symbol_kinds.get(&key) {
builder.process_range(alias.name.range, *definition_module, *symbol_kind);
}
}
}
}
stmt.recurse(&mut |inner| add_import_from_alias_tokens(builder, inner, symbol_kinds));
Copy link

Copilot AI Jan 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The recursive call to add_import_from_alias_tokens on line 131 is unnecessary and potentially problematic. Python ImportFrom statements cannot be nested inside other statements - they only appear at the module level or within conditional blocks. The recursion could cause the function to be called multiple times on the same import statement if it appears within nested control flow structures (like if/else blocks), leading to duplicate token generation. Since ImportFrom statements are already being processed from ast.body (lines 104-106), the recursion should be removed.

Suggested change
stmt.recurse(&mut |inner| add_import_from_alias_tokens(builder, inner, symbol_kinds));

Copilot uses AI. Check for mistakes.
}
34 changes: 21 additions & 13 deletions pyrefly/lib/state/semantic_tokens.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,12 @@ impl SemanticTokenBuilder {
.any(|disabled| disabled.contains_range(range))
}

pub fn process_key(
fn push_symbol_range(
&mut self,
key: &Key,
reference_range: TextRange,
definition_module: ModuleName,
symbol_kind: SymbolKind,
) {
let reference_range = key.range();
let (token_type, mut token_modifiers) =
symbol_kind.to_lsp_semantic_token_type_with_modifiers();
let is_default_library = {
Expand All @@ -244,6 +243,24 @@ impl SemanticTokenBuilder {
self.push_if_in_range(reference_range, token_type, token_modifiers);
}

pub fn process_key(
&mut self,
key: &Key,
definition_module: ModuleName,
symbol_kind: SymbolKind,
) {
self.push_symbol_range(key.range(), definition_module, symbol_kind);
}

pub fn process_range(
&mut self,
range: TextRange,
definition_module: ModuleName,
symbol_kind: SymbolKind,
) {
self.push_symbol_range(range, definition_module, symbol_kind);
}

fn process_arguments(&mut self, args: &Arguments) {
for keyword in &args.keywords {
if let Some(arg) = &keyword.arg {
Expand Down Expand Up @@ -341,19 +358,10 @@ impl SemanticTokenBuilder {
}
}
}
Stmt::ImportFrom(StmtImportFrom { module, names, .. }) => {
Stmt::ImportFrom(StmtImportFrom { module, .. }) => {
if let Some(module) = module {
self.push_if_in_range(module.range, SemanticTokenType::NAMESPACE, vec![]);
}
for alias in names {
if alias.asname.is_some() {
self.push_if_in_range(
alias.name.range,
SemanticTokenType::NAMESPACE,
vec![],
);
}
}
}
Stmt::AnnAssign(ann_assign) => {
if let Expr::Name(name) = &*ann_assign.target {
Expand Down
4 changes: 2 additions & 2 deletions pyrefly/lib/test/lsp/semantic_tokens.rs
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ line: 1, column: 5, length: 3, text: lib
token-type: namespace

line: 1, column: 16, length: 4, text: func
token-type: namespace
token-type: function

line: 1, column: 24, length: 4, text: func
token-type: function
Expand Down Expand Up @@ -913,7 +913,7 @@ line: 1, column: 5, length: 3, text: foo
token-type: namespace

line: 1, column: 16, length: 3, text: bar
token-type: namespace
token-type: function

line: 1, column: 23, length: 3, text: baz
token-type: function
Expand Down
Loading