Skip to content

Commit

Permalink
[draft] prepare outlines for outlines-core v0.2
Browse files Browse the repository at this point in the history
  • Loading branch information
yvan-sraka committed Feb 3, 2025
1 parent 437ffe4 commit 875b369
Show file tree
Hide file tree
Showing 11 changed files with 90 additions and 38 deletions.
2 changes: 1 addition & 1 deletion benchmarks/bench_json_schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from outlines_core.fsm.json_schema import build_regex_from_schema
from outlines_core.json_schema import build_regex_from_schema

from outlines.caching import cache_disabled
from outlines.fsm.guide import RegexGuide
Expand Down
2 changes: 1 addition & 1 deletion docs/cookbook/chain_of_thought.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ We could generate a response using the json schema but for a change we will use

```python
from outlines.fsm.json_schema import convert_json_schema_to_str
from outlines_core.fsm.json_schema import build_regex_from_schema
from outlines_core.json_schema import build_regex_from_schema

schema_str = convert_json_schema_to_str(json_schema=json_schema)
regex_str = build_regex_from_schema(schema_str)
Expand Down
2 changes: 1 addition & 1 deletion docs/cookbook/react_agent.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ We could generate a response using the json schema but we will use the regex and

```python
from outlines.fsm.json_schema import convert_json_schema_to_str
from outlines_core.fsm.json_schema import build_regex_from_schema
from outlines_core.json_schema import build_regex_from_schema

json_schema = Decision.model_json_schema()
schema_str = convert_json_schema_to_str(json_schema=json_schema)
Expand Down
107 changes: 79 additions & 28 deletions outlines/fsm/guide.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
import collections
import copy
import warnings
from typing import TYPE_CHECKING, Any, Generator, Union
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generator, List, Optional, Union

import torch
from lark.indenter import DedentError
from lark.lexer import UnexpectedCharacters, UnexpectedToken
from outlines_core.fsm.guide import Generate
from outlines_core.fsm.guide import Guide as CoreGuide
from outlines_core.fsm.guide import RegexGuide as CoreRegexGuide
from outlines_core.fsm.guide import Write
from outlines_core.fsm.guide import (
create_states_mapping as uncached_create_states_mapping,
)
from outlines_core import Guide as CoreGuide
from outlines_core import Index, Vocabulary

from outlines import grammars
from outlines.fsm.parsing import PartialLark, PartialParserState
Expand All @@ -21,6 +17,35 @@
from outlines.models.tokenizer import Tokenizer


@dataclass(frozen=True)
class Write:
"""Write instruction.
Attributes
----------
tokens
The sequence of tokens to be added to the current sequence by the
generation process.
"""

tokens: List[int]


@dataclass(frozen=True)
class Generate:
"""Generate instruction
Attributes
----------
tokens
The tokens that lead to a valid completion if generated. A value
of ``None`` indicates that all tokens are allowed.
"""

tokens: Optional[List[int]]


Instruction = Union[Write, Generate]


Expand Down Expand Up @@ -72,29 +97,37 @@ def copy(self):
return self


def cached_create_states_mapping(regex_string, tokenizer, *args, **kwargs):
return uncached_create_states_mapping(regex_string, tokenizer, *args, **kwargs)
class RegexGuide(Guide):
"""Guide to generate text in the language of a regular expression."""


class RegexGuide(CoreRegexGuide):
"""
Guide to generate text in the language of a regular expression.
CoreRegexGuide with outlines cache
"""
def __init__(self, guide, eos_tensor):
self.eos_tensor = eos_tensor
self._guide = guide

@classmethod
def from_regex(
cls,
regex_string: str,
tokenizer,
**kwargs,
):
return super().from_regex(
regex_string,
tokenizer,
_create_states_mapping=cached_create_states_mapping,
**kwargs,
)
def from_regex(cls, regex, tokenizer):
vocabulary = Vocabulary.from_pretrained(tokenizer.name_or_path())
index = Index(regex, vocabulary)
guide = Guide(index)

eos_tensor = torch.tensor([vocabulary.get_eos_token_id()])
return cls(guide, eos_tensor)

def get_next_instruction(self, state):
if self.is_final_state(state):
return self.eos_tensor
return None

def get_next_state(self, state, token_id):
if token_id == self.eos_tensor or self.is_final_state(state):
return self._guide.final_state
return self._guide.advance(token_id)

def is_final_state(self, state):
return self._guide.is_finished()

def copy(self):
return RegexGuide(self._guide, self.eos_tensor)


CFGState = collections.namedtuple("CFGState", ["parser_state", "prev_token"])
Expand Down Expand Up @@ -274,3 +307,21 @@ def must_terminate_state(self, state: CFGState) -> bool:
def copy(self) -> "CFGGuide":
"""Create a copy of the Guide."""
return CFGGuide(self.cfg_string, self.tokenizer)


class BetterFSM:
def __init__(self):
self.finals = None
self.flat_transition_map = None


def get_token_transition_keys(*args, **kwargs):
pass


def make_deterministic_fsm(*args, **kwargs):
pass


def reduced_vocabulary(*args, **kwargs):
pass
3 changes: 2 additions & 1 deletion outlines/fsm/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@
)
from lark.parsers.lalr_interactive_parser import InteractiveParser
from lark.parsers.lalr_parser import LALR_Parser, ParseConf, ParserState, _Parser
from outlines_core.fsm.regex import (

from outlines.fsm.guide import (
BetterFSM,
get_token_transition_keys,
make_deterministic_fsm,
Expand Down
2 changes: 1 addition & 1 deletion outlines/generate/choice.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import singledispatch
from typing import Callable, List, Union

from outlines_core.fsm.json_schema import build_regex_from_schema
from outlines_core.json_schema import build_regex_from_schema

from outlines.fsm.json_schema import get_schema_from_enum
from outlines.generate.api import SequenceGeneratorAdapter
Expand Down
2 changes: 1 addition & 1 deletion outlines/generate/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Callable, Optional, Union

from genson import SchemaBuilder
from outlines_core.fsm.json_schema import build_regex_from_schema
from outlines_core.json_schema import build_regex_from_schema
from pydantic import BaseModel

from outlines.fsm.json_schema import get_schema_from_enum, get_schema_from_signature
Expand Down
2 changes: 1 addition & 1 deletion outlines/processors/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union

import torch
from outlines_core.fsm.json_schema import build_regex_from_schema
from outlines_core.json_schema import build_regex_from_schema
from pydantic import BaseModel

from outlines.fsm.guide import CFGGuide, Guide, RegexGuide
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ dependencies = [
"pycountry",
"airportsdata",
"torch",
"outlines_core==0.1.26",
"outlines_core==0.2.2",
"genson",
]
dynamic = ["version"]
Expand Down
2 changes: 1 addition & 1 deletion tests/fsm/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import List

import pytest
from outlines_core.fsm.json_schema import build_regex_from_schema
from outlines_core.json_schema import build_regex_from_schema
from pydantic import BaseModel, constr

from outlines.fsm.json_schema import get_schema_from_enum, get_schema_from_signature
Expand Down
2 changes: 1 addition & 1 deletion tests/generate/test_integration_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@

import pytest
import torch
from outlines_core.fsm.regex import reduced_vocabulary
from pydantic import BaseModel, constr

import outlines.generate as generate
import outlines.models as models
from outlines.fsm.guide import reduced_vocabulary
from outlines.models.transformers import Transformers, TransformerTokenizer
from outlines.samplers import beam_search, greedy, multinomial

Expand Down

0 comments on commit 875b369

Please sign in to comment.