diff --git a/Cargo.lock b/Cargo.lock
index aa5f50f..d6a2323 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -20,6 +20,8 @@ dependencies = [
  "tracing",
  "tracing-subscriber",
  "tree-sitter",
+ "tree-sitter-python",
+ "tree-sitter-rust",
  "unicode-width",
 ]
 
@@ -770,6 +772,26 @@
 version = "0.1.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e8ddffe35a0e5eeeadf13ff7350af564c6e73993a24db62caee1822b185c2600"
 
+[[package]]
+name = "tree-sitter-python"
+version = "0.23.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3d065aaa27f3aaceaf60c1f0e0ac09e1cb9eb8ed28e7bcdaa52129cffc7f4b04"
+dependencies = [
+ "cc",
+ "tree-sitter-language",
+]
+
+[[package]]
+name = "tree-sitter-rust"
+version = "0.23.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a4d64d449ca63e683c562c7743946a646671ca23947b9c925c0cfbe65051a4af"
+dependencies = [
+ "cc",
+ "tree-sitter-language",
+]
+
 [[package]]
 name = "unicode-ident"
 version = "1.0.13"
diff --git a/Cargo.toml b/Cargo.toml
index e0f717f..0880717 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -54,3 +54,5 @@ toml = "0.8.19"
 [dev-dependencies]
 simple_test_case = "1.2.0"
 criterion = "0.5"
+tree-sitter-python = "0.23.6"
+tree-sitter-rust = "0.23.2"
diff --git a/data/config.toml b/data/config.toml
index e763b5e..3cbc283 100644
--- a/data/config.toml
+++ b/data/config.toml
@@ -72,6 +72,7 @@ function = { fg = "#957FB8" }
 keyword = { fg = "#Bf616A" }
 module = { fg = "#2D4F67" }
 number = { fg = "#D27E99" }
+operator = { fg = "#E6C384" }
 punctuation = { fg = "#9CABCA" }
 string = { fg = "#61DCA5" }
 type = { fg = "#7E9CD8" }
diff --git a/src/buffer/internal.rs b/src/buffer/internal.rs
index 19b93f4..fcde7c5 100644
--- a/src/buffer/internal.rs
+++ b/src/buffer/internal.rs
@@ -414,7 +414,7 @@ impl GapBuffer {
     }
 
-    /// Convert a byte index to a character index
-    pub fn byte_to_char(&self, byte_idx: usize) -> usize {
+    /// Convert a raw byte index to a character index
+    pub fn raw_byte_to_char(&self, byte_idx: usize) -> usize {
         self.chars_in_raw_range(0, byte_idx)
     }
 
@@ -739,7 +739,7 @@ impl GapBuffer {
     }
 
     #[inline]
-    fn byte_to_raw_byte(&self, byte: usize) -> usize {
+    pub fn byte_to_raw_byte(&self, byte: usize) -> usize {
         if byte > self.gap_start {
             byte + self.gap()
         } else {
diff --git a/src/lsp/capabilities.rs b/src/lsp/capabilities.rs
index 01ed6c5..1abfd67 100644
--- a/src/lsp/capabilities.rs
+++ b/src/lsp/capabilities.rs
@@ -86,7 +86,7 @@ impl PositionEncoding {
             Self::Utf8 => {
                 let line_start = b.txt.line_to_char(pos.line as usize);
                 let byte_idx = b.txt.char_to_byte(line_start + pos.character as usize);
-                let col = b.txt.byte_to_char(byte_idx);
+                let col = b.txt.raw_byte_to_char(byte_idx);
 
                 (pos.line as usize, col)
             }
diff --git a/src/ts.rs b/src/ts.rs
index 23b7317..7e947b5 100644
--- a/src/ts.rs
+++ b/src/ts.rs
@@ -26,6 +26,7 @@ use crate::{
 use libloading::{Library, Symbol};
 use std::{
     cmp::{max, min, Ord, Ordering, PartialOrd},
+    collections::HashSet,
     fmt, fs,
     iter::Peekable,
     ops::{Deref, DerefMut},
@@ -33,12 +34,14 @@
     slice,
 };
 use streaming_iterator::StreamingIterator;
+use tracing::{error, info};
 use tree_sitter::{self as ts, ffi::TSLanguage};
 
 pub const TK_DEFAULT: &str = "default";
 pub const TK_DOT: &str = "dot";
 pub const TK_LOAD: &str = "load";
 pub const TK_EXEC: &str = "exec";
+pub const SUPPORTED_PREDICATES: [&str; 0] = [];
 
 /// Buffer level tree-sitter state for parsing and highlighting
 #[derive(Debug)]
@@ -61,14 +64,32 @@ impl TsState {
             Err(e) => return Err(format!("unable to read tree-sitter query file: {e}")),
         };
 
-        let mut p = Parser::try_new(so_dir, lang)?;
+        let p = Parser::try_new(so_dir, lang)?;
+
+        Self::try_new_explicit(p, &query, gb)
+    }
+
+    #[cfg(test)]
+    fn try_new_from_language(
+        lang_name: &str,
+        lang: ts::Language,
+        query: &str,
+        gb: &GapBuffer,
+    ) -> Result<Self, String> {
+        let p = Parser::try_new_from_language(lang_name, lang)?;
+
+        Self::try_new_explicit(p, query, gb)
+    }
+
+    fn try_new_explicit(mut p: Parser, query: &str, gb: &GapBuffer) -> Result<Self, String> {
         let tree = p.parse_with(
             &mut |byte_offset, _| gb.maximal_slice_from_offset(byte_offset),
             None,
         );
+
         match tree {
             Some(tree) => {
-                let mut t = p.new_tokenizer(&query)?;
+                let mut t = p.new_tokenizer(query)?;
                 t.update(tree.root_node(), gb);
                 Ok(Self { p, t, tree })
             }
@@ -139,8 +160,8 @@ impl<'a> ts::TextProvider<&'a [u8]> for &'a GapBuffer {
             end_byte,
             ..
         } = node.range();
-        let char_from = self.byte_to_char(start_byte);
-        let char_to = self.byte_to_char(end_byte);
+        let char_from = self.raw_byte_to_char(self.byte_to_raw_byte(start_byte));
+        let char_to = self.raw_byte_to_char(self.byte_to_raw_byte(end_byte));
 
         self.slice(char_from, char_to).slice_iter()
     }
@@ -151,7 +172,9 @@ pub struct Parser {
     lang_name: String,
     inner: ts::Parser,
     lang: ts::Language,
-    _lib: Library, // Need to prevent drop while the parser is in use
+    // Need to prevent drop while the parser is in use
+    // Stored as an Option to allow for crate-based parsers that are not backed by a .so file
+    _lib: Option<Library>,
 }
 
 impl Deref for Parser {
@@ -204,15 +227,53 @@ impl Parser {
                lang_name: lang_name.to_owned(),
                inner,
                lang,
-               _lib: lib,
+               _lib: Some(lib),
            })
        }
    }
 
+    /// Construct a new parser directly from a ts::Language provided by a crate
+    #[cfg(test)]
+    fn try_new_from_language(lang_name: &str, lang: ts::Language) -> Result<Self, String> {
+        let mut inner = ts::Parser::new();
+        inner.set_language(&lang).map_err(|e| e.to_string())?;
+
+        Ok(Self {
+            lang_name: lang_name.to_owned(),
+            inner,
+            lang,
+            _lib: None,
+        })
+    }
+
     pub fn new_tokenizer(&self, query: &str) -> Result<Tokenizer, String> {
         let q = ts::Query::new(&self.lang, query).map_err(|e| format!("{e:?}"))?;
         let cur = ts::QueryCursor::new();
 
+        // If a query has been copied from another text editor there is a chance that it
+        // makes use of custom predicates that we don't know how to handle. In that case the
+        // highlights won't behave as the user expects, so we log an error and fail the
+        // setup of syntax highlighting as a whole.
+        let mut unsupported_predicates = HashSet::new();
+        for i in 0..q.pattern_count() {
+            for p in q.general_predicates(i) {
+                if !SUPPORTED_PREDICATES.contains(&p.operator.as_ref()) {
+                    unsupported_predicates.insert(p.operator.clone());
+                }
+            }
+        }
+
+        if !unsupported_predicates.is_empty() {
+            error!("Unsupported custom tree-sitter predicates found: {unsupported_predicates:?}");
+            info!("Supported custom tree-sitter predicates: {SUPPORTED_PREDICATES:?}");
+            info!("Please modify the highlights.scm file to remove the unsupported predicates");
+
+            return Err(format!(
+                "{} highlights query contained unsupported custom predicates",
+                self.lang_name
+            ));
+        }
+
         Ok(Tokenizer {
             q,
             cur,
@@ -236,36 +297,29 @@ impl fmt::Debug for Tokenizer {
 }
 
 impl Tokenizer {
-    // Compound queries such as the example below can result in duplicate nodes being returned
-    // from the caputures iterator in both the init and update methods. As such, we need to sort
-    // and dedupe the list of resulting syntax ranges in order to correctly ensure that we have
-    // no overlapping or duplicated tokens emitted.
-    //
-    // (macro_invocation
-    //   macro: (identifier) @function.macro
-    //   "!" @function.macro)
-
     pub fn update(&mut self, root: ts::Node<'_>, gb: &GapBuffer) {
         // This is a streaming-iterator not an iterator, hence the odd while-let that follows
         let mut it = self.cur.captures(&self.q, root, gb);
 
         // FIXME: this is really inefficient. Ideally we should be able to apply a diff here
         self.ranges.clear();
 
-        while let Some((m, _)) = it.next() {
-            for cap_idx in 0..self.q.capture_names().len() {
-                for node in m.nodes_for_capture_index(cap_idx as u32) {
-                    let r = ByteRange::from(node.range());
-                    if let Some(prev) = self.ranges.last() {
-                        if r.from < prev.r.to && prev.r.from < r.to {
-                            continue;
-                        }
-                    }
-                    self.ranges.push(SyntaxRange {
-                        r,
-                        cap_idx: Some(cap_idx),
-                    });
+        while let Some((m, idx)) = it.next() {
+            let cap = m.captures[*idx];
+            let r = ByteRange::from(cap.node.range());
+            if let Some(prev) = self.ranges.last_mut() {
+                if r == prev.r {
+                    // preferring the last capture found so that precedence ordering
+                    // in query files matches Neovim & the tree-sitter CLI
+                    prev.cap_idx = Some(cap.index as usize);
+                    continue;
+                } else if r.from < prev.r.to && prev.r.from < r.to {
+                    continue;
                 }
             }
+            self.ranges.push(SyntaxRange {
+                r,
+                cap_idx: Some(cap.index as usize),
+            });
         }
 
         self.ranges.sort_unstable();
@@ -290,6 +344,19 @@ impl Tokenizer {
             &self.ranges,
         )
     }
+
+    #[cfg(test)]
+    fn range_tokens(&self) -> Vec<RangeToken<'_>> {
+        let names = self.q.capture_names();
+
+        self.ranges
+            .iter()
+            .map(|sr| RangeToken {
+                tag: sr.cap_idx.map(|i| names[i]).unwrap_or(TK_DEFAULT),
+                r: sr.r,
+            })
+            .collect()
+    }
 }
 
 /// Byte offsets within a Buffer
@@ -1100,49 +1167,128 @@ mod tests {
         assert_eq!(held, expected);
     }
 
+    fn rt(tag: &str, from: usize, to: usize) -> RangeToken<'_> {
+        RangeToken {
+            tag,
+            r: ByteRange { from, to },
+        }
+    }
+
     #[test]
-    #[ignore = "this test requires installed parsers and queries"]
     fn char_delete_correctly_update_state() {
+        // minimal query for the fn keyword and punctuation
+        let query = r#"
+"fn" @keyword
+
+[ "(" ")" "{" "}" ] @punctuation"#;
+
         let s = "fn main() {}";
         let mut b = Buffer::new_unnamed(0, s);
+        let gb = &b.txt;
 
         b.ts_state = Some(
-            TsState::try_new(
-                "rust",
-                "/home/sminez/.local/share/nvim/lazy/nvim-treesitter/parser",
-                "data/tree-sitter/queries",
-                &b.txt,
-            )
-            .unwrap(),
+            TsState::try_new_from_language("rust", tree_sitter_rust::LANGUAGE.into(), query, gb)
+                .unwrap(),
         );
 
-        let ranges = b.ts_state.as_ref().unwrap().t.ranges.clone();
-        let sr = |idx, from, to| SyntaxRange {
-            cap_idx: Some(idx),
-            r: ByteRange { from, to },
-        };
-
         assert_eq!(b.str_contents(), "fn main() {}\n");
         assert_eq!(
-            ranges,
+            b.ts_state.as_ref().unwrap().t.range_tokens(),
             vec![
-                sr(5, 0, 2),    // fn
-                sr(14, 7, 8),   // (
-                sr(14, 8, 9),   // )
-                sr(14, 10, 11), // {
-                sr(14, 11, 12), // }
+                rt("keyword", 0, 2),       // fn
+                rt("punctuation", 7, 8),   // (
+                rt("punctuation", 8, 9),   // )
+                rt("punctuation", 10, 11), // {
+                rt("punctuation", 11, 12), // }
             ]
         );
 
         b.dot = Dot::Cur { c: Cur { idx: 9 } };
        b.handle_action(Action::Delete, Source::Fsys);
        b.ts_state.as_mut().unwrap().update(&b.txt);
-        let ranges = b.ts_state.as_ref().unwrap().t.ranges.clone();
+        let ranges = b.ts_state.as_ref().unwrap().t.range_tokens();
 
         assert_eq!(b.str_contents(), "fn main(){}\n");
         assert_eq!(ranges.len(), 5);
 
         // these two should have moved left one character
-        assert_eq!(ranges[3], sr(14, 9, 10), "opening curly");
-        assert_eq!(ranges[4], sr(14, 10, 11), "closing curly");
+        assert_eq!(ranges[3], rt("punctuation", 9, 10), "opening curly");
+ assert_eq!(ranges[4], rt("punctuation", 10, 11), "closing curly"); + } + + #[test] + fn overlapping_tokens_prefer_previous_matches() { + // Minimal query extracted from the full query in gh#88 that resulted in + // overlapping tokens being produced + let query = r#" +(identifier) @variable + +(import_statement + name: (dotted_name + (identifier) @module)) + +(import_statement + name: (aliased_import + name: (dotted_name + (identifier) @module) + alias: (identifier) @module)) + +(import_from_statement + module_name: (dotted_name + (identifier) @module))"#; + + let s = "import builtins as _builtins"; + let b = Buffer::new_unnamed(0, s); + let gb = &b.txt; + let ts = TsState::try_new_from_language( + "python", + tree_sitter_python::LANGUAGE.into(), + query, + gb, + ) + .unwrap(); + + assert_eq!( + ts.t.range_tokens(), + vec![ + rt("module", 7, 15), // builtins + rt("module", 19, 28) // _builtins + ] + ); + } + + #[test] + fn built_in_predicates_work() { + let query = r#" +(identifier) @variable + +; Assume all-caps names are constants +((identifier) @constant + (#match? @constant "^[A-Z][A-Z%d_]*$")) + +((identifier) @constant.builtin + (#any-of? @constant.builtin "Some" "None" "Ok" "Err")) + +[ "(" ")" "{" "}" ] @punctuation"#; + + let s = "Ok(Some(42)) foo BAR"; + let b = Buffer::new_unnamed(0, s); + let gb = &b.txt; + let ts = + TsState::try_new_from_language("rust", tree_sitter_rust::LANGUAGE.into(), query, gb) + .unwrap(); + + assert_eq!( + ts.t.range_tokens(), + vec![ + rt("constant.builtin", 0, 2), // Ok + rt("punctuation", 2, 3), // ( + rt("constant.builtin", 3, 7), // Some + rt("punctuation", 7, 8), // ( + rt("punctuation", 10, 11), // ) + rt("punctuation", 11, 12), // ) + rt("variable", 13, 16), // foo + rt("constant", 17, 20), // BAR + ] + ); } }
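Reviewer note on the predicate handling above (commentary, not part of the diff): the `built_in_predicates_work` test passes even though `SUPPORTED_PREDICATES` is empty because the tree-sitter Rust bindings parse `#eq?`, `#match?` and `#any-of?` natively as text predicates and apply them inside `QueryCursor::captures` via the `ts::TextProvider` impl on `GapBuffer`. Only predicates the bindings do not recognise are surfaced through `Query::general_predicates` and hit the new rejection path in `new_tokenizer`. Below is a minimal sketch of that split, assuming the tree-sitter 0.23 API and the `tree-sitter-rust` dev-dependency added in this diff; `#lua-match?` is used purely as an example of an editor-specific predicate that ad does not handle.

```rust
use tree_sitter as ts;

fn main() {
    let lang: ts::Language = tree_sitter_rust::LANGUAGE.into();

    // `#match?` is compiled into the query's text predicates, so nothing is
    // surfaced via `general_predicates` and the empty SUPPORTED_PREDICATES
    // whitelist still lets it through.
    let q = ts::Query::new(&lang, r#"((identifier) @c (#match? @c "^[A-Z]+$"))"#).unwrap();
    assert!(q.general_predicates(0).is_empty());

    // An editor-specific predicate is unknown to the bindings, so it shows up
    // as a general predicate and `new_tokenizer` now reports and rejects it.
    let q = ts::Query::new(&lang, r#"((identifier) @c (#lua-match? @c "^%u+$"))"#).unwrap();
    assert_eq!(q.general_predicates(0)[0].operator.as_ref(), "lua-match?");
}
```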