From 750d25ab59c7bc25b589424ce27e4b1e6f8556a9 Mon Sep 17 00:00:00 2001
From: PoJu Chen
Date: Thu, 14 Dec 2023 17:54:37 -0600
Subject: [PATCH] fix-issue-415-integer-type-error

---
 src/llm_vm/guided_completion.py | 104 ++++++++++++++++----------
 1 file changed, 53 insertions(+), 51 deletions(-)

diff --git a/src/llm_vm/guided_completion.py b/src/llm_vm/guided_completion.py
index 463f6ba3..4db91cbc 100644
--- a/src/llm_vm/guided_completion.py
+++ b/src/llm_vm/guided_completion.py
@@ -3,9 +3,9 @@
 import torch
 from lark import Lark, Transformer, v_args
 from lark.indenter import PythonIndenter
-from transformers import (AutoModelForCausalLM, LogitsProcessorList, LogitsProcessor, AutoTokenizer)
+from transformers import AutoModelForCausalLM, LogitsProcessorList, LogitsProcessor, AutoTokenizer
 import re
-from abc import ABC,abstractmethod
+from abc import ABC, abstractmethod
 
 model = models.transformers("gpt2")
@@ -37,12 +37,12 @@ def create(regex, type, choices, grammar_type, *, default=None):
 class GenerativeCompletion(Completion):
     def __init__(self, generator, *generator_args):
         """
-        Parameters: 
+        Parameters:
         -----------
-        
+
         generator (Callable[[Transformers, ...generator_args], None]): Generator function to be used on the complete
         *generator_args (Any): Generator arguments (without model)
-        
+
         """
         self.generator = generator
         self.generator_args = generator_args
@@ -62,7 +62,7 @@ def choices_completion(choices):
     def type_completion(type_name):
         if type_name not in ["float", "integer"]:
             raise Exception("type must be float or integer")
-        return GenerativeCompletion(getattr(generate, type_name))
+        return GenerativeCompletion(generate.format, type_name)
 
     @staticmethod
     def response_completion():
@@ -78,7 +78,7 @@ def __init__(self, model_uri, tokenizer, grammar_type='python'):
         # Load model from HuggingFace
         self.model = AutoModelForCausalLM.from_pretrained(self.model_identifier)
         print(f"{self.model_identifier} model is ready for use on {self.model.device}", flush=True)
-        
+
         self.constraint = GrammarConstraint.create(grammar_type, model_uri, tokenizer)
 
         # Initialize dict to store terminal symbols and their corresponding token mappings
@@ -87,7 +87,7 @@
     @property
     def eos_token_id(self):
         return self.model.config.eos_token_id
-    
+
     def complete(self, prompt, **model_kwargs):
         # Check if the specified grammar type is supported
@@ -115,12 +115,18 @@ def complete(self, prompt, **model_kwargs):
             model_kwargs["logits_processor"] = logits_processor
 
         # Generate completion using the model
-        res = self.model.generate(**model_kwargs, input_ids=token_ids['input_ids'], attention_mask=token_ids['attention_mask'], eos_token_id=self.eos_token_id, pad_token_id=self.eos_token_id)
+        res = self.model.generate(
+            **model_kwargs,
+            input_ids=token_ids['input_ids'],
+            attention_mask=token_ids['attention_mask'],
+            eos_token_id=self.eos_token_id,
+            pad_token_id=self.eos_token_id,
+        )
 
         # Decode completion and return the generated text, including input tokens
         res_text = self.tokenizer.batch_decode(res.sequences, skip_special_tokens=True)[0]
         return res_text.strip()
-    
+
 
 class GrammarConstraint(ABC):
     # Initialize the GrammarConstraint class with the model URI and tokenizer
@@ -153,8 +159,7 @@ def create(grammar_type, model_uri, tokenizer):
         return grammars[grammar_type](model_uri, tokenizer)
 
 
-class PythonConstraint(GrammarConstraint):
-
+class PythonConstraint(GrammarConstraint):
     # Construct a filter set for valid tokens based on a given expression
     def construct_filter_set(self, expression):
         vocab = self.tokenizer.vocab
@@ -164,24 +169,26 @@ def construct_filter_set(self, expression):
         # Preprocess expression to handle special cases
         if expression[0] in specials:
             expression = f'\{expression}'
-        elif len(expression) == 1 and (expression == '[' or expression == '('): 
+        elif len(expression) == 1 and (expression == '[' or expression == '('):
             expression = f'\{expression}'
-        
+
         try:
             # Compile the regex pattern and use it to match valid tokens in the vocabulary
             pattern = re.compile(expression, re.UNICODE)
             for token, id in vocab.items():
                 if pattern.match(token) is not None:
-                    valid_tokens.append((token, id)) 
-        except Exception as e: 
-            print("Regex Compiling Error: ", f"{e} - {expression}") 
-        
+                    valid_tokens.append((token, id))
+        except Exception as e:
+            print("Regex Compiling Error: ", f"{e} - {expression}")
+
         # Return a set of valid tokens based on the pattern
         return set(valid_tokens)
-    
+
     def parse_grammar(self):
         # Create the Python parser with Lark, using the LALR algorithm
-        self.parser = Lark.open_from_package('lark', 'python.lark', ['grammars'], parser='lalr', regex=True, lexer='contextual', postlex=PythonIndenter(), start='file_input')
+        self.parser = Lark.open_from_package(
+            'lark', 'python.lark', ['grammars'], parser='lalr', regex=True, lexer='contextual', postlex=PythonIndenter(), start='file_input'
+        )
         terminals = self.parser.terminals
         t_map = {}
         for t in terminals:
@@ -193,12 +200,12 @@ def parse_grammar(self):
     def _prefix_state(self, prefix_str=None, last_token=None):
         valid_next = []
         if self._parser_state is None:
-            try: 
+            try:
                 # Parse the entire token sequence
                 interactive_tree = self.parser.parse_interactive(prefix_str)
                 interactive_tree.exhaust_lexer()
                 self._parser_state = interactive_tree.copy()
-                
+
                 # Get the valid next states
                 valid_next = list(interactive_tree.accepts())
             except Exception as e:
@@ -220,7 +227,7 @@ def construct_final_filter_set(self, prefix_ids, terminals_map):
         valid_next_ids = []
-        
+
         if self._copy_state == True:
             # Decode only the last prefix ID
             last_token = self.tokenizer.batch_decode(prefix_ids[:, -1], skip_special_tokens=True)[0]
@@ -239,13 +246,12 @@
             for t in token_set:
                 # Add valid token IDs to the list
                 valid_next_ids.append(t[-1])
-        
+
         # Return a set of valid next token IDs
         return set(valid_next_ids)
 
 
-class JSONConstraint(GrammarConstraint):
-
+class JSONConstraint(GrammarConstraint):
     # Construct a filter set for valid tokens based on a given expression
     def construct_filter_set(self, expression):
         vocab = self.tokenizer.vocab
@@ -255,7 +261,7 @@
         # Preprocess expression to handle special cases
         if expression[0] in specials:
             expression = f'\{expression}'
-        elif len(expression) == 1 and (expression == '[' or expression == '('): 
+        elif len(expression) == 1 and (expression == '[' or expression == '('):
             expression = f'\{expression}'
 
         try:
@@ -263,13 +269,13 @@
             # Compile the regex pattern and use it to match valid tokens in the vocabulary
             pattern = re.compile(expression, re.UNICODE)
             for token, id in vocab.items():
                 if pattern.match(token) is not None:
-                    valid_tokens.append((token, id)) 
-        except Exception as e: 
-            print(e, expression) 
-        
+                    valid_tokens.append((token, id))
+        except Exception as e:
+            print(e, expression)
+
         # Return a set of valid tokens based on the pattern
         return set(valid_tokens)
-    
+
     def parse_grammar(self):
         # Define JSON grammar
         json_grammar = r"""
@@ -295,8 +301,8 @@ def parse_grammar(self):
             %ignore WS
         """
-        
-        # Create Lark internal transformer to make parsing faster and more memory efficient 
+
+        # Create Lark internal transformer to make parsing faster and more memory efficient
         class TreeToJson(Transformer):
             @v_args(inline=True)
             def string(self, s):
@@ -311,28 +317,25 @@ def string(self, s):
         true = lambda self, _: True
         false = lambda self, _: False
 
         # Create the JSON parser with Lark, using the LALR algorithm
-        self.parser = Lark(json_grammar, parser='lalr',
-                           lexer='contextual',
-                           transformer=TreeToJson())
+        self.parser = Lark(json_grammar, parser='lalr', lexer='contextual', transformer=TreeToJson())
         terminals = self.parser.terminals
         t_map = {}
         for t in terminals:
             t_map[t.name] = t.pattern.value
-        
+
         # Return a map of terminal tokens and their corresponding regex patterns
         return t_map
 
     def _prefix_state(self, prefix_str=None, last_token=None):
         valid_next = []
         if self._parser_state is None:
-            try: 
+            try:
                 # Parse the entire token sequence
                 interactive_tree = self.parser.parse_interactive(prefix_str)
                 interactive_tree.exhaust_lexer()
                 self._parser_state = interactive_tree.copy()
-                
+
                 # Get the valid next states
                 valid_next = list(interactive_tree.accepts())
             except Exception as e:
@@ -354,7 +357,7 @@ def _prefix_state(self, prefix_str=None, last_token=None):
     def construct_final_filter_set(self, prefix_ids, terminals_map):
         valid_next_ids = []
-        
+
         if self._copy_state == True:
             # Decode only the last prefix ID
             last_token = self.tokenizer.batch_decode(prefix_ids[:, -1], skip_special_tokens=True)[0]
@@ -373,33 +376,32 @@ def construct_final_filter_set(self, prefix_ids, terminals_map):
             for t in token_set:
                 # Add valid token IDs to the list
                 valid_next_ids.append(t[-1])
-        
+
         # Return a set of valid next token IDs
         return set(valid_next_ids)
 
 
 class GrammarLogitsProcessor(LogitsProcessor):
-
     def __init__(self, constraint_class, terminals_map):
         # Initialize the GrammarLogitsProcessor class with a constraint class and terminals map.
-        self.constraint_class = constraint_class 
-        self.terminals_map = terminals_map 
+        self.constraint_class = constraint_class
+        self.terminals_map = terminals_map
 
     def __call__(self, input_ids, scores):
         # This method is called for each generation step
-        
+
         # Initialize a boolean bias tensor with the same shape as scores
         bias = torch.zeros_like(scores, dtype=torch.bool)
-        
+
         # Get the set of valid next token IDs for current step
         valid_next_ids = self.constraint_class.construct_final_filter_set(input_ids, self.terminals_map)
-        
+
         # Set the bias to True for valid next token IDs
         for id in valid_next_ids:
             bias[0, id] = True
-        
+
        # Add the bias to the scores tensor to zero-out invalid next tokens
         scores += bias
-        
+
         # Return the modified scores tensor
         return scores
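
Notes (reviewer sketches; illustrative only, not part of the patch above):

The type_completion() change is the actual issue-415 fix: it appears to
track the outlines API, which replaced the per-type generators that
getattr(generate, type_name) relied on with a format-based generator. A
minimal sketch of the new call shape, assuming an outlines version that
ships generate.format (model and prompt here are placeholders, and how
GenerativeCompletion forwards type_name is not shown in this diff):

    import outlines.models as models
    from outlines import generate

    model = models.transformers("gpt2")

    # Constrain sampling so the output parses as a Python int.
    generator = generate.format(model, int)
    answer = generator("How many legs does a spider have? ")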
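Both PythonConstraint and JSONConstraint lean on the same Lark pattern:
parse the prefix with an interactive parser, then ask the parser state
which terminals may legally come next. A self-contained sketch of that
pattern with a toy JSON grammar (the grammar and prefix are illustrative,
not the ones in this file):

    from lark import Lark

    json_grammar = r"""
        ?value: dict | list | ESCAPED_STRING | SIGNED_NUMBER
              | "true" | "false" | "null"
        dict : "{" [pair ("," pair)*] "}"
        pair : ESCAPED_STRING ":" value
        list : "[" [value ("," value)*] "]"
        %import common.ESCAPED_STRING
        %import common.SIGNED_NUMBER
        %import common.WS
        %ignore WS
    """

    parser = Lark(json_grammar, parser='lalr', start='value')
    interactive = parser.parse_interactive('{"a": [1, 2')
    interactive.exhaust_lexer()   # feed every token lexed from the prefix
    print(interactive.accepts())  # terminal names valid next,
                                  # e.g. {'COMMA', 'RSQB'}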
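One caveat worth flagging: GrammarLogitsProcessor adds a 0/1 bias to the
scores, which only nudges generation toward grammar-valid tokens rather
than guaranteeing them, despite the "zero-out invalid next tokens"
comment. The usual hard-constraint variant sets every disallowed logit to
-inf instead; a sketch of that alternative (class and argument names
mirror this file, but this is not what the patch implements):

    import torch
    from transformers import LogitsProcessor

    class HardMaskLogitsProcessor(LogitsProcessor):
        """Force every token the grammar disallows down to -inf."""

        def __init__(self, constraint_class, terminals_map):
            self.constraint_class = constraint_class
            self.terminals_map = terminals_map

        def __call__(self, input_ids, scores):
            valid_next_ids = self.constraint_class.construct_final_filter_set(
                input_ids, self.terminals_map)
            # Start from an all -inf mask, then re-enable the allowed IDs.
            mask = torch.full_like(scores, float('-inf'))
            mask[0, list(valid_next_ids)] = 0.0
            return scores + mask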