Skip to content

Commit

Permalink
[draft] prepare outlines for outlines-core v0.2
Browse files Browse the repository at this point in the history
  • Loading branch information
yvan-sraka committed Feb 3, 2025
1 parent 437ffe4 commit 0aab081
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 29 deletions.
89 changes: 61 additions & 28 deletions outlines/fsm/guide.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
import collections
import copy
import warnings
from typing import TYPE_CHECKING, Any, Generator, Union
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generator, List, Optional, Union

import torch
from lark.indenter import DedentError
from lark.lexer import UnexpectedCharacters, UnexpectedToken
from outlines_core.fsm.guide import Generate
from outlines_core.fsm.guide import Guide as CoreGuide
from outlines_core.fsm.guide import RegexGuide as CoreRegexGuide
from outlines_core.fsm.guide import Write
from outlines_core.fsm.guide import (
create_states_mapping as uncached_create_states_mapping,
)
from outlines_core.guide import Guide as CoreGuide
from outlines_core.guide import Index, Vocabulary

from outlines import grammars
from outlines.fsm.parsing import PartialLark, PartialParserState
Expand All @@ -21,6 +17,35 @@
from outlines.models.tokenizer import Tokenizer


@dataclass(frozen=True)
class Write:
"""Write instruction.
Attributes
----------
tokens
The sequence of tokens to be added to the current sequence by the
generation process.
"""

tokens: List[int]


@dataclass(frozen=True)
class Generate:
"""Generate instruction
Attributes
----------
tokens
The tokens that lead to a valid completion if generated. A value
of ``None`` indicates that all tokens are allowed.
"""

tokens: Optional[List[int]]


Instruction = Union[Write, Generate]


Expand Down Expand Up @@ -72,29 +97,37 @@ def copy(self):
return self


def cached_create_states_mapping(regex_string, tokenizer, *args, **kwargs):
return uncached_create_states_mapping(regex_string, tokenizer, *args, **kwargs)
class RegexGuide(Guide):
"""Guide to generate text in the language of a regular expression."""


class RegexGuide(CoreRegexGuide):
"""
Guide to generate text in the language of a regular expression.
CoreRegexGuide with outlines cache
"""
def __init__(self, guide, eos_tensor):
self.eos_tensor = eos_tensor
self._guide = guide

@classmethod
def from_regex(
cls,
regex_string: str,
tokenizer,
**kwargs,
):
return super().from_regex(
regex_string,
tokenizer,
_create_states_mapping=cached_create_states_mapping,
**kwargs,
)
def from_regex(cls, regex, tokenizer):
vocabulary = Vocabulary.from_pretrained(tokenizer.name_or_path())
index = Index(regex, vocabulary)
guide = Guide(index)

eos_tensor = torch.tensor([vocabulary.get_eos_token_id()])
return cls(guide, eos_tensor)

def get_next_instruction(self, state):
if self.is_final_state(state):
return self.eos_tensor
return None

def get_next_state(self, state, token_id):
if token_id == self.eos_tensor or self.is_final_state(state):
return self._guide.final_state
return self._guide.advance(token_id)

def is_final_state(self, state):
return self._guide.is_finished()

def copy(self):
return RegexGuide(self._guide, self.eos_tensor)


CFGState = collections.namedtuple("CFGState", ["parser_state", "prev_token"])
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ dependencies = [
"pycountry",
"airportsdata",
"torch",
"outlines_core==0.1.26",
"outlines_core==0.2.2",
"genson",
]
dynamic = ["version"]
Expand Down

0 comments on commit 0aab081

Please sign in to comment.