diff --git a/crates/pgt_completions/src/context/base_parser.rs b/crates/pgt_completions/src/context/base_parser.rs index 93333679..83b31582 100644 --- a/crates/pgt_completions/src/context/base_parser.rs +++ b/crates/pgt_completions/src/context/base_parser.rs @@ -1,6 +1,5 @@ -use std::iter::Peekable; - use pgt_text_size::{TextRange, TextSize}; +use std::iter::Peekable; pub(crate) struct TokenNavigator { tokens: Peekable>, @@ -101,73 +100,139 @@ impl WordWithIndex { } } -/// Note: A policy name within quotation marks will be considered a single word. -pub(crate) fn sql_to_words(sql: &str) -> Result, String> { - let mut words = vec![]; - - let mut start_of_word: Option = None; - let mut current_word = String::new(); - let mut in_quotation_marks = false; - - for (current_position, current_char) in sql.char_indices() { - if (current_char.is_ascii_whitespace() || current_char == ';') - && !current_word.is_empty() - && start_of_word.is_some() - && !in_quotation_marks - { - words.push(WordWithIndex { - word: current_word, - start: start_of_word.unwrap(), - end: current_position, - }); - - current_word = String::new(); - start_of_word = None; - } else if (current_char.is_ascii_whitespace() || current_char == ';') - && current_word.is_empty() - { - // do nothing - } else if current_char == '"' && start_of_word.is_none() { - in_quotation_marks = true; - current_word.push(current_char); - start_of_word = Some(current_position); - } else if current_char == '"' && start_of_word.is_some() { - current_word.push(current_char); - in_quotation_marks = false; - } else if start_of_word.is_some() { - current_word.push(current_char) +pub(crate) struct SubStatementParser { + start_of_word: Option, + current_word: String, + in_quotation_marks: bool, + is_fn_call: bool, + words: Vec, +} + +impl SubStatementParser { + pub(crate) fn parse(sql: &str) -> Result, String> { + let mut parser = SubStatementParser { + start_of_word: None, + current_word: String::new(), + in_quotation_marks: false, + is_fn_call: false, + words: vec![], + }; + + parser.collect_words(sql); + + if parser.in_quotation_marks { + Err("String was not closed properly.".into()) } else { - start_of_word = Some(current_position); - current_word.push(current_char); + Ok(parser.words) } } - if let Some(start_of_word) = start_of_word { - if !current_word.is_empty() { - words.push(WordWithIndex { - word: current_word, - start: start_of_word, - end: sql.len(), - }); + pub fn collect_words(&mut self, sql: &str) { + for (pos, c) in sql.char_indices() { + match c { + '"' => { + if !self.has_started_word() { + self.in_quotation_marks = true; + self.add_char(c); + self.start_word(pos); + } else { + self.in_quotation_marks = false; + self.add_char(c); + } + } + + '(' => { + if !self.has_started_word() { + self.push_char_as_word(c, pos); + } else { + self.add_char(c); + self.is_fn_call = true; + } + } + + ')' => { + if self.is_fn_call { + self.add_char(c); + self.is_fn_call = false; + } else { + if self.has_started_word() { + self.push_word(pos); + } + self.push_char_as_word(c, pos); + } + } + + _ => { + if c.is_ascii_whitespace() || c == ';' { + if self.in_quotation_marks { + self.add_char(c); + } else if !self.is_empty() && self.has_started_word() { + self.push_word(pos); + } + } else if self.has_started_word() { + self.add_char(c); + } else { + self.start_word(pos); + self.add_char(c) + } + } + } + } + + if self.has_started_word() && !self.is_empty() { + self.push_word(sql.len()) } } - if in_quotation_marks { - Err("String was not closed properly.".into()) - } else { - Ok(words) + fn is_empty(&self) -> bool { + self.current_word.is_empty() + } + + fn add_char(&mut self, c: char) { + self.current_word.push(c) + } + + fn start_word(&mut self, pos: usize) { + self.start_of_word = Some(pos); + } + + fn has_started_word(&self) -> bool { + self.start_of_word.is_some() + } + + fn push_char_as_word(&mut self, c: char, pos: usize) { + self.words.push(WordWithIndex { + word: String::from(c), + start: pos, + end: pos + 1, + }); + } + + fn push_word(&mut self, current_position: usize) { + self.words.push(WordWithIndex { + word: self.current_word.clone(), + start: self.start_of_word.unwrap(), + end: current_position, + }); + self.current_word = String::new(); + self.start_of_word = None; } } +/// Note: A policy name within quotation marks will be considered a single word. +pub(crate) fn sql_to_words(sql: &str) -> Result, String> { + SubStatementParser::parse(sql) +} + #[cfg(test)] mod tests { - use crate::context::base_parser::{WordWithIndex, sql_to_words}; + use crate::context::base_parser::{SubStatementParser, WordWithIndex, sql_to_words}; #[test] fn determines_positions_correctly() { - let query = "\ncreate policy \"my cool pol\"\n\ton auth.users\n\tas permissive\n\tfor select\n\t\tto public\n\t\tusing (true);".to_string(); + let query = "\ncreate policy \"my cool pol\"\n\ton auth.users\n\tas permissive\n\tfor select\n\t\tto public\n\t\tusing (auth.uid());".to_string(); - let words = sql_to_words(query.as_str()).unwrap(); + let words = SubStatementParser::parse(query.as_str()).unwrap(); assert_eq!(words[0], to_word("create", 1, 7)); assert_eq!(words[1], to_word("policy", 8, 14)); @@ -181,7 +246,9 @@ mod tests { assert_eq!(words[9], to_word("to", 73, 75)); assert_eq!(words[10], to_word("public", 78, 84)); assert_eq!(words[11], to_word("using", 87, 92)); - assert_eq!(words[12], to_word("(true)", 93, 99)); + assert_eq!(words[12], to_word("(", 93, 94)); + assert_eq!(words[13], to_word("auth.uid()", 94, 104)); + assert_eq!(words[14], to_word(")", 104, 105)); } #[test] diff --git a/crates/pgt_completions/src/context/mod.rs b/crates/pgt_completions/src/context/mod.rs index 996ec6be..01e563b0 100644 --- a/crates/pgt_completions/src/context/mod.rs +++ b/crates/pgt_completions/src/context/mod.rs @@ -47,6 +47,15 @@ pub enum WrappingClause<'a> { SetStatement, AlterRole, DropRole, + + /// `PolicyCheck` refers to either the `WITH CHECK` or the `USING` clause + /// in a policy statement. + /// ```sql + /// CREATE POLICY "my pol" ON PUBLIC.USERS + /// FOR SELECT + /// USING (...) -- this one! + /// ``` + PolicyCheck, } #[derive(PartialEq, Eq, Hash, Debug, Clone)] @@ -78,6 +87,7 @@ pub(crate) enum NodeUnderCursor<'a> { text: NodeText, range: TextRange, kind: String, + previous_node_kind: Option, }, } @@ -222,6 +232,7 @@ impl<'a> CompletionContext<'a> { text: revoke_context.node_text.into(), range: revoke_context.node_range, kind: revoke_context.node_kind.clone(), + previous_node_kind: None, }); if revoke_context.node_kind == "revoke_table" { @@ -249,6 +260,7 @@ impl<'a> CompletionContext<'a> { text: grant_context.node_text.into(), range: grant_context.node_range, kind: grant_context.node_kind.clone(), + previous_node_kind: None, }); if grant_context.node_kind == "grant_table" { @@ -276,6 +288,7 @@ impl<'a> CompletionContext<'a> { text: policy_context.node_text.into(), range: policy_context.node_range, kind: policy_context.node_kind.clone(), + previous_node_kind: Some(policy_context.previous_node_kind), }); if policy_context.node_kind == "policy_table" { @@ -295,7 +308,13 @@ impl<'a> CompletionContext<'a> { } "policy_role" => Some(WrappingClause::ToRoleAssignment), "policy_table" => Some(WrappingClause::From), - _ => None, + _ => { + if policy_context.in_check_or_using_clause { + Some(WrappingClause::PolicyCheck) + } else { + None + } + } }; } @@ -785,7 +804,11 @@ impl<'a> CompletionContext<'a> { .is_some_and(|sib| kinds.contains(&sib.kind())) } - NodeUnderCursor::CustomNode { .. } => false, + NodeUnderCursor::CustomNode { + previous_node_kind, .. + } => previous_node_kind + .as_ref() + .is_some_and(|k| kinds.contains(&k.as_str())), } }) } diff --git a/crates/pgt_completions/src/context/policy_parser.rs b/crates/pgt_completions/src/context/policy_parser.rs index 58619502..bcc60499 100644 --- a/crates/pgt_completions/src/context/policy_parser.rs +++ b/crates/pgt_completions/src/context/policy_parser.rs @@ -22,6 +22,10 @@ pub(crate) struct PolicyContext { pub node_text: String, pub node_range: TextRange, pub node_kind: String, + pub previous_node_text: String, + pub previous_node_range: TextRange, + pub previous_node_kind: String, + pub in_check_or_using_clause: bool, } /// Simple parser that'll turn a policy-related statement into a context object required for @@ -32,6 +36,7 @@ pub(crate) struct PolicyParser { navigator: TokenNavigator, context: PolicyContext, cursor_position: usize, + in_check_or_using_clause: bool, } impl CompletionStatementParser for PolicyParser { @@ -63,6 +68,7 @@ impl CompletionStatementParser for PolicyParser { navigator: tokens.into(), context: PolicyContext::default(), cursor_position, + in_check_or_using_clause: false, } } } @@ -73,6 +79,8 @@ impl PolicyParser { return; } + self.context.in_check_or_using_clause = self.in_check_or_using_clause; + let previous = self.navigator.previous_token.take().unwrap(); match previous @@ -84,6 +92,8 @@ impl PolicyParser { self.context.node_range = token.get_range(); self.context.node_kind = "policy_name".into(); self.context.node_text = token.get_word(); + + self.context.previous_node_kind = "keyword_policy".into(); } "on" => { if token.get_word_without_quotes().contains('.') { @@ -112,17 +122,35 @@ impl PolicyParser { self.context.node_text = token.get_word(); self.context.node_kind = "policy_table".into(); } + + self.context.previous_node_kind = "keyword_on".into(); } "to" => { self.context.node_range = token.get_range(); self.context.node_kind = "policy_role".into(); self.context.node_text = token.get_word(); + + self.context.previous_node_kind = "keyword_to".into(); } - _ => { + + other => { self.context.node_range = token.get_range(); self.context.node_text = token.get_word(); + + self.context.previous_node_range = previous.get_range(); + self.context.previous_node_text = previous.get_word(); + + match other { + "(" | "=" => self.context.previous_node_kind = other.into(), + "and" => self.context.previous_node_kind = "keyword_and".into(), + + _ => self.context.previous_node_kind = "".into(), + } } } + + self.context.previous_node_range = previous.get_range(); + self.context.previous_node_text = previous.get_word(); } fn handle_token(&mut self, token: WordWithIndex) { @@ -142,6 +170,13 @@ impl PolicyParser { } "on" => self.table_with_schema(), + "(" if self.navigator.prev_matches(&["using", "check"]) => { + self.in_check_or_using_clause = true; + } + ")" => { + self.in_check_or_using_clause = false; + } + // skip the "to" so we don't parse it as the TO rolename when it's under the cursor "rename" if self.navigator.next_matches(&["to"]) => { self.navigator.advance(); @@ -218,7 +253,11 @@ mod tests { statement_kind: PolicyStmtKind::Create, node_text: "REPLACED_TOKEN".into(), node_range: TextRange::new(TextSize::new(25), TextSize::new(39)), - node_kind: "policy_name".into() + node_kind: "policy_name".into(), + in_check_or_using_clause: false, + previous_node_kind: "keyword_policy".into(), + previous_node_range: TextRange::new(18.into(), 24.into()), + previous_node_text: "policy".into(), } ); @@ -241,6 +280,10 @@ mod tests { node_text: "REPLACED_TOKEN".into(), node_kind: "".into(), node_range: TextRange::new(TextSize::new(42), TextSize::new(56)), + in_check_or_using_clause: false, + previous_node_kind: "".into(), + previous_node_range: TextRange::new(25.into(), 41.into()), + previous_node_text: "\"my cool policy\"".into(), } ); @@ -263,6 +306,10 @@ mod tests { node_text: "REPLACED_TOKEN".into(), node_kind: "policy_table".into(), node_range: TextRange::new(TextSize::new(45), TextSize::new(59)), + in_check_or_using_clause: false, + previous_node_kind: "keyword_on".into(), + previous_node_range: TextRange::new(42.into(), 44.into()), + previous_node_text: "on".into(), } ); @@ -285,6 +332,10 @@ mod tests { node_text: "REPLACED_TOKEN".into(), node_kind: "policy_table".into(), node_range: TextRange::new(TextSize::new(50), TextSize::new(64)), + in_check_or_using_clause: false, + previous_node_kind: "keyword_on".into(), + previous_node_range: TextRange::new(42.into(), 44.into()), + previous_node_text: "on".into(), } ); @@ -308,6 +359,10 @@ mod tests { node_text: "REPLACED_TOKEN".into(), node_kind: "".into(), node_range: TextRange::new(TextSize::new(72), TextSize::new(86)), + in_check_or_using_clause: false, + previous_node_kind: "".into(), + previous_node_range: TextRange::new(69.into(), 71.into()), + previous_node_text: "as".into(), } ); @@ -332,6 +387,10 @@ mod tests { node_text: "REPLACED_TOKEN".into(), node_kind: "".into(), node_range: TextRange::new(TextSize::new(95), TextSize::new(109)), + in_check_or_using_clause: false, + previous_node_kind: "".into(), + previous_node_range: TextRange::new(72.into(), 82.into()), + previous_node_text: "permissive".into(), } ); @@ -356,6 +415,10 @@ mod tests { node_text: "REPLACED_TOKEN".into(), node_kind: "policy_role".into(), node_range: TextRange::new(TextSize::new(98), TextSize::new(112)), + in_check_or_using_clause: false, + previous_node_kind: "keyword_to".into(), + previous_node_range: TextRange::new(95.into(), 97.into()), + previous_node_text: "to".into(), } ); } @@ -383,7 +446,11 @@ mod tests { statement_kind: PolicyStmtKind::Create, node_text: "REPLACED_TOKEN".into(), node_range: TextRange::new(TextSize::new(57), TextSize::new(71)), - node_kind: "policy_table".into() + node_kind: "policy_table".into(), + in_check_or_using_clause: false, + previous_node_kind: "keyword_on".into(), + previous_node_range: TextRange::new(54.into(), 56.into()), + previous_node_text: "on".into(), } ) } @@ -411,7 +478,11 @@ mod tests { statement_kind: PolicyStmtKind::Create, node_text: "REPLACED_TOKEN".into(), node_range: TextRange::new(TextSize::new(62), TextSize::new(76)), - node_kind: "policy_table".into() + node_kind: "policy_table".into(), + in_check_or_using_clause: false, + previous_node_kind: "keyword_on".into(), + previous_node_range: TextRange::new(54.into(), 56.into()), + previous_node_text: "on".into(), } ) } @@ -436,7 +507,11 @@ mod tests { statement_kind: PolicyStmtKind::Drop, node_text: "REPLACED_TOKEN".into(), node_range: TextRange::new(TextSize::new(23), TextSize::new(37)), - node_kind: "policy_name".into() + node_kind: "policy_name".into(), + in_check_or_using_clause: false, + previous_node_kind: "keyword_policy".into(), + previous_node_range: TextRange::new(16.into(), 22.into()), + previous_node_text: "policy".into(), } ); @@ -459,7 +534,11 @@ mod tests { statement_kind: PolicyStmtKind::Drop, node_text: "\"REPLACED_TOKEN\"".into(), node_range: TextRange::new(TextSize::new(23), TextSize::new(39)), - node_kind: "policy_name".into() + node_kind: "policy_name".into(), + in_check_or_using_clause: false, + previous_node_kind: "keyword_policy".into(), + previous_node_range: TextRange::new(16.into(), 22.into()), + previous_node_text: "policy".into(), } ); } @@ -477,4 +556,100 @@ mod tests { assert_eq!(context, PolicyContext::default()); } + + #[test] + fn correctly_determines_we_are_inside_checks() { + { + let (pos, query) = with_pos(format!( + r#" + create policy "my cool policy" + on auth.users + to all + using (id = {}) + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: Some(r#""my cool policy""#.into()), + table_name: Some("users".into()), + schema_name: Some("auth".into()), + statement_kind: PolicyStmtKind::Create, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(112), TextSize::new(126)), + node_kind: "".into(), + in_check_or_using_clause: true, + previous_node_kind: "=".into(), + previous_node_range: TextRange::new(110.into(), 111.into()), + previous_node_text: "=".into(), + } + ); + } + + { + let (pos, query) = with_pos(format!( + r#" + create policy "my cool policy" + on auth.users + to all + using ({} + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: Some(r#""my cool policy""#.into()), + table_name: Some("users".into()), + schema_name: Some("auth".into()), + statement_kind: PolicyStmtKind::Create, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(106), TextSize::new(120)), + node_kind: "".into(), + in_check_or_using_clause: true, + previous_node_kind: "(".into(), + previous_node_range: TextRange::new(105.into(), 106.into()), + previous_node_text: "(".into(), + } + ) + } + + { + let (pos, query) = with_pos(format!( + r#" + create policy "my cool policy" + on auth.users + to all + with check ({} + "#, + CURSOR_POS + )); + + let context = PolicyParser::get_context(query.as_str(), pos); + + assert_eq!( + context, + PolicyContext { + policy_name: Some(r#""my cool policy""#.into()), + table_name: Some("users".into()), + schema_name: Some("auth".into()), + statement_kind: PolicyStmtKind::Create, + node_text: "REPLACED_TOKEN".into(), + node_range: TextRange::new(TextSize::new(111), TextSize::new(125)), + node_kind: "".into(), + in_check_or_using_clause: true, + previous_node_kind: "(".into(), + previous_node_range: TextRange::new(110.into(), 111.into()), + previous_node_text: "(".into(), + } + ) + } + } } diff --git a/crates/pgt_completions/src/providers/columns.rs b/crates/pgt_completions/src/providers/columns.rs index 4299973b..04d0af65 100644 --- a/crates/pgt_completions/src/providers/columns.rs +++ b/crates/pgt_completions/src/providers/columns.rs @@ -817,4 +817,52 @@ mod tests { .await; } } + + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn suggests_columns_policy_using_clause(pool: PgPool) { + let setup = r#" + create table instruments ( + id bigint primary key generated always as identity, + name text not null, + z text, + created_at timestamp with time zone default now() + ); + "#; + + pool.execute(setup).await.unwrap(); + + let col_queries = vec![ + format!( + r#"create policy "my_pol" on public.instruments for select using ({})"#, + CURSOR_POS + ), + format!( + r#"create policy "my_pol" on public.instruments for insert with check ({})"#, + CURSOR_POS + ), + format!( + r#"create policy "my_pol" on public.instruments for update using (id = 1 and {})"#, + CURSOR_POS + ), + format!( + r#"create policy "my_pol" on public.instruments for insert with check (id = 1 and {})"#, + CURSOR_POS + ), + ]; + + for query in col_queries { + assert_complete_results( + query.as_str(), + vec![ + CompletionAssertion::Label("created_at".into()), + CompletionAssertion::Label("id".into()), + CompletionAssertion::Label("name".into()), + CompletionAssertion::Label("z".into()), + ], + None, + &pool, + ) + .await; + } + } } diff --git a/crates/pgt_completions/src/providers/functions.rs b/crates/pgt_completions/src/providers/functions.rs index 2bc4f331..615e4f95 100644 --- a/crates/pgt_completions/src/providers/functions.rs +++ b/crates/pgt_completions/src/providers/functions.rs @@ -65,11 +65,14 @@ fn get_completion_text(ctx: &CompletionContext, func: &Function) -> CompletionTe #[cfg(test)] mod tests { - use sqlx::PgPool; + use sqlx::{Executor, PgPool}; use crate::{ CompletionItem, CompletionItemKind, complete, - test_helper::{CURSOR_POS, get_test_deps, get_test_params}, + test_helper::{ + CURSOR_POS, CompletionAssertion, assert_complete_results, get_test_deps, + get_test_params, + }, }; #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] @@ -201,4 +204,84 @@ mod tests { assert_eq!(label, "cool"); assert_eq!(kind, CompletionItemKind::Function); } + + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn only_allows_functions_and_procedures_in_policy_checks(pool: PgPool) { + let setup = r#" + create table coos ( + id serial primary key, + name text + ); + + create or replace function my_cool_foo() + returns trigger + language plpgsql + security invoker + as $$ + begin + raise exception 'dont matter'; + end; + $$; + + create or replace procedure my_cool_proc() + language plpgsql + security invoker + as $$ + begin + raise exception 'dont matter'; + end; + $$; + + create or replace function string_concat_state( + state text, + value text, + separator text) + returns text + language plpgsql + as $$ + begin + if state is null then + return value; + else + return state || separator || value; + end if; + end; + $$; + + create aggregate string_concat(text, text) ( + sfunc = string_concat_state, + stype = text, + initcond = '' + ); + "#; + + pool.execute(setup).await.unwrap(); + + let query = format!( + r#"create policy "my_pol" on public.instruments for insert with check (id = {})"#, + CURSOR_POS + ); + + assert_complete_results( + query.as_str(), + vec![ + CompletionAssertion::LabelNotExists("string_concat".into()), + CompletionAssertion::LabelAndKind( + "my_cool_foo".into(), + CompletionItemKind::Function, + ), + CompletionAssertion::LabelAndKind( + "my_cool_proc".into(), + CompletionItemKind::Function, + ), + CompletionAssertion::LabelAndKind( + "string_concat_state".into(), + CompletionItemKind::Function, + ), + ], + None, + &pool, + ) + .await; + } } diff --git a/crates/pgt_completions/src/relevance/filtering.rs b/crates/pgt_completions/src/relevance/filtering.rs index a020d2e8..beea6ddb 100644 --- a/crates/pgt_completions/src/relevance/filtering.rs +++ b/crates/pgt_completions/src/relevance/filtering.rs @@ -1,3 +1,5 @@ +use pgt_schema_cache::ProcKind; + use crate::context::{CompletionContext, NodeUnderCursor, WrappingClause, WrappingNode}; use super::CompletionRelevanceData; @@ -137,17 +139,27 @@ impl CompletionFilter<'_> { && ctx.parent_matches_one_of_kind(&["field"])) } + WrappingClause::PolicyCheck => { + ctx.before_cursor_matches_kind(&["keyword_and", "("]) + } + _ => false, } } - CompletionRelevanceData::Function(_) => matches!( - clause, + CompletionRelevanceData::Function(f) => match clause { WrappingClause::From - | WrappingClause::Select - | WrappingClause::Where - | WrappingClause::Join { .. } - ), + | WrappingClause::Select + | WrappingClause::Where + | WrappingClause::Join { .. } => true, + + WrappingClause::PolicyCheck => { + ctx.before_cursor_matches_kind(&["="]) + && matches!(f.kind, ProcKind::Function | ProcKind::Procedure) + } + + _ => false, + }, CompletionRelevanceData::Schema(_) => match clause { WrappingClause::Select diff --git a/crates/pgt_completions/src/sanitization.rs b/crates/pgt_completions/src/sanitization.rs index ddc9563e..bf4d9816 100644 --- a/crates/pgt_completions/src/sanitization.rs +++ b/crates/pgt_completions/src/sanitization.rs @@ -257,10 +257,15 @@ fn cursor_between_parentheses(sql: &str, position: TextSize) -> bool { .find(|c| !c.is_whitespace()) .unwrap_or_default(); - let before_matches = before == ',' || before == '('; - let after_matches = after == ',' || after == ')'; + // (.. and |) + let after_and_keyword = &sql[position.saturating_sub(4)..position] == "and " && after == ')'; + let after_eq_sign = before == '=' && after == ')'; - before_matches && after_matches + let head_of_list = before == '(' && after == ','; + let end_of_list = before == ',' && after == ')'; + let between_list_items = before == ',' && after == ','; + + head_of_list || end_of_list || between_list_items || after_and_keyword || after_eq_sign } #[cfg(test)] @@ -444,5 +449,22 @@ mod tests { "insert into instruments (name) values (a_function(name, ))", TextSize::new(56) )); + + // will sanitize after = + assert!(cursor_between_parentheses( + // create policy my_pol on users using (id = |), + "create policy my_pol on users using (id = )", + TextSize::new(42) + )); + + // will sanitize after and + assert!(cursor_between_parentheses( + // create policy my_pol on users using (id = 1 and |), + "create policy my_pol on users using (id = 1 and )", + TextSize::new(48) + )); + + // does not break if sql is really short + assert!(!cursor_between_parentheses("(a)", TextSize::new(2))); } } diff --git a/crates/pgt_schema_cache/src/lib.rs b/crates/pgt_schema_cache/src/lib.rs index 9beb2f8a..b67f9412 100644 --- a/crates/pgt_schema_cache/src/lib.rs +++ b/crates/pgt_schema_cache/src/lib.rs @@ -14,7 +14,7 @@ mod types; mod versions; pub use columns::*; -pub use functions::{Behavior, Function, FunctionArg, FunctionArgs}; +pub use functions::{Behavior, Function, FunctionArg, FunctionArgs, ProcKind}; pub use policies::{Policy, PolicyCommand}; pub use roles::*; pub use schema_cache::SchemaCache;