Bump outlines-core to v0.2.3
yvan-sraka committed Feb 11, 2025
1 parent 69418da commit 3c2893d
Showing 9 changed files with 350 additions and 65 deletions.
2 changes: 1 addition & 1 deletion benchmarks/bench_json_schema.py
@@ -1,4 +1,4 @@
from outlines_core.fsm.json_schema import build_regex_from_schema
from outlines_core.json_schema import build_regex_from_schema

from outlines.caching import cache_disabled
from outlines.fsm.guide import RegexGuide
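
The only change in this benchmark is the import path: in outlines-core v0.2.3 `build_regex_from_schema` lives in the top-level `outlines_core.json_schema` module rather than `outlines_core.fsm.json_schema`. A minimal sketch of the new import in use follows; the JSON schema is a made-up example, not taken from the repository.

```python
# Minimal sketch of the new import path (outlines-core >= 0.2.3).
# The schema below is a hypothetical example.
from outlines_core.json_schema import build_regex_from_schema

schema_str = '{"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}'
regex_str = build_regex_from_schema(schema_str)
print(regex_str)  # a regular expression matching JSON documents valid under the schema
```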
2 changes: 1 addition & 1 deletion docs/cookbook/chain_of_thought.md
@@ -81,7 +81,7 @@ We could generate a response using the json schema but for a change we will use

```python
from outlines.fsm.json_schema import convert_json_schema_to_str
from outlines_core.fsm.json_schema import build_regex_from_schema
from outlines_core.json_schema import build_regex_from_schema

schema_str = convert_json_schema_to_str(json_schema=json_schema)
regex_str = build_regex_from_schema(schema_str)
2 changes: 1 addition & 1 deletion docs/cookbook/react_agent.md
@@ -121,7 +121,7 @@ We could generate a response using the json schema but we will use the regex and

```python
from outlines.fsm.json_schema import convert_json_schema_to_str
from outlines_core.fsm.json_schema import build_regex_from_schema
from outlines_core.json_schema import build_regex_from_schema

json_schema = Decision.model_json_schema()
schema_str = convert_json_schema_to_str(json_schema=json_schema)
238 changes: 186 additions & 52 deletions outlines/fsm/guide.py
@@ -1,18 +1,15 @@
import collections
import copy
import warnings
from typing import TYPE_CHECKING, Any, Generator, Union
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Generator, List, Optional, Union, Tuple, Dict, Set, cast

import torch
from lark.indenter import DedentError
from lark.lexer import UnexpectedCharacters, UnexpectedToken
from outlines_core.fsm.guide import Generate
from outlines_core.fsm.guide import Guide as CoreGuide
from outlines_core.fsm.guide import RegexGuide as CoreRegexGuide
from outlines_core.fsm.guide import Write
from outlines_core.fsm.guide import (
create_states_mapping as uncached_create_states_mapping,
)
from outlines_core import Guide as CoreGuide
from outlines_core import Index, Vocabulary

from outlines import grammars
from outlines.fsm.parsing import PartialLark, PartialParserState
@@ -21,6 +18,35 @@
from outlines.models.tokenizer import Tokenizer


@dataclass(frozen=True)
class Write:
"""Write instruction.
Attributes
----------
tokens
The sequence of tokens to be added to the current sequence by the
generation process.
"""

tokens: List[int]


@dataclass(frozen=True)
class Generate:
"""Generate instruction
Attributes
----------
tokens
The tokens that lead to a valid completion if generated. A value
of ``None`` indicates that all tokens are allowed.
"""

tokens: Optional[List[int]]


Instruction = Union[Write, Generate]
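
The commit inlines `Write` and `Generate` as plain dataclasses instead of importing them from outlines-core. The sketch below (illustrative only, not part of this diff) shows how calling code can dispatch on the `Instruction` union defined above; `apply_instruction` and its arguments are hypothetical names.

```python
import random
from typing import List

# Illustrative sketch only: dispatch on the Instruction union defined above.
def apply_instruction(instruction: Instruction, sequence: List[int], vocab_size: int) -> List[int]:
    if isinstance(instruction, Write):
        # Write: append the forced tokens verbatim.
        return sequence + instruction.tokens
    # Generate: restrict the next token to the allowed ids; None means any token.
    allowed = instruction.tokens if instruction.tokens is not None else list(range(vocab_size))
    return sequence + [random.choice(allowed)]

# Force the literal token 42, then allow only tokens {1, 2, 3}.
seq = apply_instruction(Write(tokens=[42]), [], vocab_size=100)
seq = apply_instruction(Generate(tokens=[1, 2, 3]), seq, vocab_size=100)
```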


@@ -44,13 +70,14 @@ class StopAtEOSGuide(Guide):
start_state = 0 # TODO: remove start_state, use only initial_state
initial_state = 0

def __init__(self, tokenizer: "Tokenizer"):
def __init__(self, tokenizer: "Tokenizer", index: Index):
"""Initialize the generation guide.
model
The logit generator used to generate the next token.
"""
super().__init__(index)
self.eos_token_id = tokenizer.eos_token_id
self.vocabulary = tokenizer.vocabulary.values()

@@ -72,34 +99,68 @@ def copy(self):
return self


def cached_create_states_mapping(regex_string, tokenizer, *args, **kwargs):
return uncached_create_states_mapping(regex_string, tokenizer, *args, **kwargs)
class RegexGuide(Guide):
"""Guide to generate text in the language of a regular expression."""

states_to_token_maps = None

class RegexGuide(CoreRegexGuide):
"""
Guide to generate text in the language of a regular expression.
CoreRegexGuide with outlines cache
"""
def __init__(self, guide, eos_tensor, index: Index):
super().__init__(index)
self.eos_tensor = eos_tensor
self._guide = guide

@classmethod
def from_regex(
cls,
regex_string: str,
tokenizer,
**kwargs,
):
return super().from_regex(
regex_string,
tokenizer,
_create_states_mapping=cached_create_states_mapping,
**kwargs,
)
def from_regex(cls, regex, tokenizer):
vocabulary = tokenizer.vocabulary
index = Index(regex, vocabulary)
guide = Guide(index)

eos_tensor = torch.tensor([vocabulary.get_eos_token_id()])
return cls(guide, eos_tensor, index)

def get_next_instruction(self):
if self.is_final_state():
return self.eos_tensor
return None

def get_next_state(self, token_id):
if token_id == self.eos_tensor or self.is_final_state():
return self._guide.final_state
return self._guide.advance(token_id)

def is_final_state(self):
return self._guide.is_finished()

def copy(self, index):
return RegexGuide(self._guide, self.eos_tensor, index)
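
For context, the new `RegexGuide` builds directly on the outlines-core 0.2 primitives used above (`Vocabulary`, `Index`, and the core `Guide` with its `advance`/`is_finished` methods). A rough usage sketch follows; `Vocabulary.from_pretrained` and `get_tokens()` are assumptions about the outlines-core v0.2.3 API rather than facts taken from this diff.

```python
# Rough sketch, not part of this diff. Index(regex, vocabulary), Guide(index),
# advance() and is_finished() appear in the code above; from_pretrained() and
# get_tokens() are assumed API and may differ in the actual release.
from outlines_core import Guide, Index, Vocabulary

vocabulary = Vocabulary.from_pretrained("openai-community/gpt2")  # assumed constructor
index = Index(r"[0-9]{3}", vocabulary)
guide = Guide(index)

while not guide.is_finished():
    allowed = guide.get_tokens()  # assumed accessor for the currently allowed token ids
    guide.advance(allowed[0])     # a real sampler would choose among `allowed`
```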


CFGState = collections.namedtuple("CFGState", ["parser_state", "prev_token"])


def can_terminate_state(state: CFGState) -> bool:
"""Generation is allowed to terminate"""
if state.parser_state is not None:
try:
copy.copy(state.parser_state).feed_eof()
except UnexpectedToken:
return False
return True


def is_final_state(state: CFGState) -> bool:
# TODO: remove this method, use can_terminate_state and must_terminate_state
# here and in RegexGuide per https://github.com/dottxt-ai/outlines/issues/885
return can_terminate_state(state)


def must_terminate_state(state: CFGState) -> bool:
"""Generation must terminate, no legal continuations"""
return state.parser_state is None or set(state.parser_state.accepts()).issubset(
{"$END"}
)


class CFGGuide(Guide):
"""Guide to generate text that is in the language of a context-free Lark grammar."""

@@ -164,23 +225,22 @@ def iter_valid_token_ids(
Parameters
----------
parser_state
state: CFGState
The current state of the parser, or None if complete.
token_ids
candidate_token_ids: list
The list of token ids to check for validity.
Yields
------
int
Valid token ids.
"""
if state.parser_state is None:
yield self.eos_token_id
return

for token_id in candidate_token_ids:
if token_id == self.eos_token_id:
if self.can_terminate_state(state):
if can_terminate_state(state):
yield token_id
else:
try:
@@ -251,26 +311,100 @@ def _get_parser_state_token_applied(

return parser_state

def is_final_state(self, state: CFGState) -> bool:
# TODO: remove this method, use can_terminate_state and must_terminate_state
# here and in RegexGuide per https://github.com/dottxt-ai/outlines/issues/885
return self.can_terminate_state(state)

def can_terminate_state(self, state: CFGState) -> bool:
"""Generation is allowed to terminate"""
if state.parser_state is not None:
try:
copy.copy(state.parser_state).feed_eof()
except UnexpectedToken:
return False
return True

def must_terminate_state(self, state: CFGState) -> bool:
"""Generation must terminate, no legal continuations"""
return state.parser_state is None or set(state.parser_state.accepts()).issubset(
{"$END"}
)

def copy(self) -> "CFGGuide":
"""Create a copy of the Guide."""
return CFGGuide(self.cfg_string, self.tokenizer)



def byte_symbol(byte: int) -> str:
return f"\x00{byte:02X}" if byte >= 0x80 else chr(byte)
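
`byte_symbol` keeps ASCII bytes as-is and encodes bytes at or above 0x80 as a NUL-prefixed hex pair; a couple of illustrative values (not part of this diff):

```python
# Illustrative values for byte_symbol as defined above.
assert byte_symbol(0x41) == "A"        # bytes below 0x80 map to the character itself
assert byte_symbol(0xF0) == "\x00F0"   # bytes >= 0x80 become NUL followed by two hex digits
```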



# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
@lru_cache()
def gpt2_bytes_to_unicode():
"""
Returns list of utf-8 byte and a mapping to Unicode strings. We specifically avoid mapping to whitespace/control
characters the bpe code barfs on.
The reversible bpe codes work on Unicode strings. This means you need a large # of Unicode characters in your vocab
if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
tables between utf-8 bytes and Unicode strings.
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))



@lru_cache()
def gpt2_unicode_to_bytes():
return {v: k for k, v in gpt2_bytes_to_unicode().items()}
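
The two `lru_cache`d tables above are exact inverses of each other over all 256 byte values. A quick sanity check, illustrative only:

```python
# Sanity check of the lookup tables defined above (illustrative only).
forward = gpt2_bytes_to_unicode()   # byte value -> printable unicode character
backward = gpt2_unicode_to_bytes()  # printable unicode character -> byte value

assert len(forward) == 256
assert all(backward[char] == byte for byte, char in forward.items())
```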

@lru_cache
def reduced_vocabulary(
tokenizer,
re_replacement_seq=None, re_llama_byte_token=None) -> Tuple[Dict[str, List[int]], Set[int]]:
"""Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
# TODO FIXME: See if we can get the underlying Rust tokenizers from HF and
# do all this in Rust
empty_token_ids = set()
vocabulary: Dict[str, List[int]] = {}
for token, token_idx in tokenizer.vocabulary.items():
if token in tokenizer.special_tokens:
continue

token_str: Union[str, Tuple[str, ...]] = tokenizer.convert_token_to_string(
token
)

if token_str:
if isinstance(token, bytes):
# Handle BPE tokenizers where the tokens are directly stored as bytes
# https://github.com/QwenLM/Qwen/blob/main/tokenization_note.md#regular-tokens
token_str = "".join(byte_symbol(b) for b in token)

elif "\ufffd" in token_str and not re_replacement_seq.match(token):
# invalid utf-8 sequences are replaced with � (\ufffd), but there
# might also be tokens specifically for �, ��, ���, etc.

if re_llama_byte_token.match(token):
# llama-like tokenizers have <0xXX> tokens for all
# bytes >= 0x80 and represent all incomplete utf-8
# sequences using such tokens
token_bytes = [int(token[3:5], 16)]
else:
# gpt2-like tokenizers have multibyte tokens that can
# have a mix of full and incomplete utf-8 characters,
# for example, b` \xf0` can be one token; these tokenizers
# map each byte to a valid utf-8 character
token_bytes = cast(
List[int], [gpt2_unicode_to_bytes().get(c) for c in token]
)
if None in token_bytes:
raise RuntimeError(
f"Cannot convert token `{token}` ({token_idx}) to bytes: {token_str}"
)
token_str = "".join(byte_symbol(b) for b in token_bytes)

assert isinstance(token_str, str)

vocabulary.setdefault(token_str, []).append(token_idx)
else:
empty_token_ids.add(token_idx)

return vocabulary, empty_token_ids
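
To illustrate what `reduced_vocabulary` produces, here is a sketch with a minimal stand-in tokenizer (hypothetical, not part of this diff): it needs only `vocabulary`, `special_tokens`, and `convert_token_to_string`.

```python
# Illustrative sketch only: a toy tokenizer showing the shape of
# reduced_vocabulary's output. The class and its contents are hypothetical.
class ToyTokenizer:
    vocabulary = {"hello": 0, "Ġworld": 1, "<eos>": 2, "": 3}
    special_tokens = {"<eos>"}

    def convert_token_to_string(self, token: str) -> str:
        return token.replace("Ġ", " ")  # GPT-2 style space marker

vocab_map, empty_ids = reduced_vocabulary(ToyTokenizer())
# vocab_map maps decoded strings to the token ids that produce them,
# e.g. {"hello": [0], " world": [1]}; empty_ids collects ids whose decoded
# string is empty, e.g. {3}. The special <eos> token is skipped.
```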
