Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for XTC and DRY samplers #1843

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions llama_cpp/_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,22 @@ def add_mirostat_v2(self, seed: int, tau: float, eta: float):
sampler = llama_cpp.llama_sampler_init_mirostat_v2(seed, tau, eta)
self._add_sampler(sampler)

def add_xtc(self, probability: float, threshold: float, min_keep: int, seed: int):
    """Append an XTC (exclude-top-choices) sampler to this sampler chain.

    Args:
        probability: Chance that the XTC token-removal step is applied.
        threshold: Probability threshold used to decide which top tokens
            are candidates for exclusion.
        min_keep: Minimum number of candidate tokens that must remain.
        seed: RNG seed for the sampler's stochastic decision.
    """
    self._add_sampler(
        llama_cpp.llama_sampler_init_xtc(probability, threshold, min_keep, seed)
    )

def add_dry(self, model: LlamaModel, multiplier: float, base: float,
            allowed_length: int, penalty_last_n: int,
            seq_breakers: "list[str] | None" = None):
    """Append a DRY (Don't Repeat Yourself) repetition-penalty sampler.

    Args:
        model: The model whose vocabulary the sampler operates on.
        multiplier: DRY penalty multiplier (0.0 disables the sampler).
        base: Exponential base for the repetition penalty.
        allowed_length: Repetition length allowed before penalties apply.
        penalty_last_n: How many recent tokens to scan for repetitions.
        seq_breakers: Strings that reset repetition matching (e.g. "\\n").
            Defaults to no sequence breakers.
    """
    # Use a None sentinel instead of a mutable default argument, which
    # would be shared across all calls.
    if seq_breakers is None:
        seq_breakers = []
    # llama.cpp expects the breakers as a C array of UTF-8 C strings.
    seq_breakers_bytes = [s.encode("utf-8") for s in seq_breakers]
    arr = (ctypes.c_char_p * len(seq_breakers_bytes))(*seq_breakers_bytes)
    sampler = llama_cpp.llama_sampler_init_dry(
        model.model,
        multiplier,
        base,
        allowed_length,
        penalty_last_n,
        arr,
        # Count the converted array so the length always matches `arr`.
        len(seq_breakers_bytes),
    )
    self._add_sampler(sampler)

def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
sampler = llama_cpp.llama_sampler_init_grammar(
model.model, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8")
Expand Down
100 changes: 100 additions & 0 deletions llama_cpp/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,13 @@ def _init_sampler(
mirostat_mode: int = 0,
mirostat_eta: float = 0.1,
mirostat_tau: float = 5.0,
xtc_probability: float = 0.0,
xtc_threshold: float = 0.1,
dry_multiplier: float = 0.0,
dry_allowed_length: int = 2,
dry_base: float = 1.75,
dry_range: int = 0,
dry_seq_breakers: list[str] = [],
penalize_nl: bool = True,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
Expand Down Expand Up @@ -744,12 +751,14 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
else:
n_probs = 0
min_keep = max(1, n_probs)
sampler.add_dry(self._model, dry_multiplier, dry_base, dry_allowed_length, dry_range, dry_seq_breakers)
sampler.add_top_k(top_k)
sampler.add_typical(typical_p, min_keep)
sampler.add_top_p(top_p, min_keep)
sampler.add_min_p(min_p, min_keep)
sampler.add_temp(temp)
sampler.add_dist(self._seed)
sampler.add_xtc(xtc_probability, xtc_threshold, min_keep, self._seed)
return sampler

def sample(
Expand All @@ -766,6 +775,13 @@ def sample(
mirostat_mode: int = 0,
mirostat_eta: float = 0.1,
mirostat_tau: float = 5.0,
xtc_probability: float = 0.0,
xtc_threshold: float = 0.1,
dry_multiplier: float = 0.0,
dry_allowed_length: int = 2,
dry_base: float = 1.75,
dry_range: int = 0,
dry_seq_breakers: list[str] = [],
penalize_nl: bool = True,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
Expand Down Expand Up @@ -801,6 +817,13 @@ def sample(
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
xtc_probability=xtc_probability,
xtc_threshold=xtc_threshold,
dry_multiplier=dry_multiplier,
dry_allowed_length=dry_allowed_length,
dry_base=dry_base,
dry_range=dry_range,
dry_seq_breakers=dry_seq_breakers,
penalize_nl=penalize_nl,
logits_processor=logits_processor,
grammar=grammar,
Expand Down Expand Up @@ -830,6 +853,13 @@ def generate(
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
xtc_probability: float = 0.0,
xtc_threshold: float = 0.1,
dry_multiplier: float = 0.0,
dry_allowed_length: int = 2,
dry_base: float = 1.75,
dry_range: int = 0,
dry_seq_breakers: list[str] = [],
penalize_nl: bool = True,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
Expand Down Expand Up @@ -869,6 +899,13 @@ def generate(
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
xtc_probability=xtc_probability,
xtc_threshold=xtc_threshold,
dry_multiplier=dry_multiplier,
dry_allowed_length=dry_allowed_length,
dry_base=dry_base,
dry_range=dry_range,
dry_seq_breakers=dry_seq_breakers,
penalize_nl=penalize_nl,
logits_processor=logits_processor,
grammar=grammar,
Expand Down Expand Up @@ -921,6 +958,13 @@ def generate(
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
xtc_probability=xtc_probability,
xtc_threshold=xtc_threshold,
dry_multiplier=dry_multiplier,
dry_allowed_length=dry_allowed_length,
dry_base=dry_base,
dry_range=dry_range,
dry_seq_breakers=dry_seq_breakers,
logits_processor=logits_processor,
grammar=grammar,
penalize_nl=penalize_nl,
Expand Down Expand Up @@ -1137,6 +1181,13 @@ def _create_completion(
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
xtc_probability: float = 0.0,
xtc_threshold: float = 0.1,
dry_multiplier: float = 0.0,
dry_allowed_length: int = 2,
dry_base: float = 1.75,
dry_range: int = 0,
dry_seq_breakers: list[str] = [],
model: Optional[str] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None,
Expand Down Expand Up @@ -1325,6 +1376,13 @@ def logit_bias_processor(
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
xtc_probability=xtc_probability,
xtc_threshold=xtc_threshold,
dry_multiplier=dry_multiplier,
dry_allowed_length=dry_allowed_length,
dry_base=dry_base,
dry_range=dry_range,
dry_seq_breakers=dry_seq_breakers,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
repeat_penalty=repeat_penalty,
Expand Down Expand Up @@ -1757,6 +1815,13 @@ def create_completion(
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
xtc_probability: float = 0.0,
xtc_threshold: float = 0.1,
dry_multiplier: float = 0.0,
dry_allowed_length: int = 2,
dry_base: float = 1.75,
dry_range: int = 0,
dry_seq_breakers: list[str] = [],
model: Optional[str] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None,
Expand Down Expand Up @@ -1820,6 +1885,13 @@ def create_completion(
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
xtc_probability=xtc_probability,
xtc_threshold=xtc_threshold,
dry_multiplier=dry_multiplier,
dry_allowed_length=dry_allowed_length,
dry_base=dry_base,
dry_range=dry_range,
dry_seq_breakers=dry_seq_breakers,
model=model,
stopping_criteria=stopping_criteria,
logits_processor=logits_processor,
Expand Down Expand Up @@ -1854,6 +1926,13 @@ def __call__(
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
xtc_probability: float = 0.0,
xtc_threshold: float = 0.1,
dry_multiplier: float = 0.0,
dry_allowed_length: int = 2,
dry_base: float = 1.75,
dry_range: int = 0,
dry_seq_breakers: list[str] = [],
model: Optional[str] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
logits_processor: Optional[LogitsProcessorList] = None,
Expand Down Expand Up @@ -1917,6 +1996,13 @@ def __call__(
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
xtc_probability=xtc_probability,
xtc_threshold=xtc_threshold,
dry_multiplier=dry_multiplier,
dry_allowed_length=dry_allowed_length,
dry_base=dry_base,
dry_range=dry_range,
dry_seq_breakers=dry_seq_breakers,
model=model,
stopping_criteria=stopping_criteria,
logits_processor=logits_processor,
Expand Down Expand Up @@ -1948,6 +2034,13 @@ def create_chat_completion(
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
xtc_probability: float = 0.0,
xtc_threshold: float = 0.1,
dry_multiplier: float = 0.0,
dry_allowed_length: int = 2,
dry_base: float = 1.75,
dry_range: int = 0,
dry_seq_breakers: list[str] = [],
model: Optional[str] = None,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
Expand Down Expand Up @@ -2021,6 +2114,13 @@ def create_chat_completion(
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
xtc_probability=xtc_probability,
xtc_threshold=xtc_threshold,
dry_multiplier=dry_multiplier,
dry_allowed_length=dry_allowed_length,
dry_base=dry_base,
dry_range=dry_range,
dry_seq_breakers=dry_seq_breakers,
model=model,
logits_processor=logits_processor,
grammar=grammar,
Expand Down
32 changes: 32 additions & 0 deletions llama_cpp/llama_cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3233,6 +3233,38 @@ def llama_sampler_init_xtc(
) -> llama_sampler_p:
...

# LLAMA_API struct llama_sampler * llama_sampler_init_dry(
#           const struct llama_model * model,
#                              float   dry_multiplier,
#                              float   dry_base,
#                            int32_t   dry_allowed_length,
#                            int32_t   dry_penalty_last_n,
#                         const char ** seq_breakers,
#                             size_t   num_breakers);
@ctypes_function(
    "llama_sampler_init_dry",
    [
        llama_model_p_ctypes,
        ctypes.c_float,
        ctypes.c_float,
        ctypes.c_int32,
        ctypes.c_int32,
        ctypes.POINTER(ctypes.c_char_p),
        ctypes.c_size_t
    ],
    llama_sampler_p_ctypes,
)
def llama_sampler_init_dry(
    model: llama_model_p,
    dry_multiplier: float,
    dry_base: float,
    dry_allowed_length: int,
    dry_penalty_last_n: int,
    # The declared argtype is POINTER(c_char_p): callers must pass a
    # ctypes (c_char_p * N) array of UTF-8 byte strings, not list[str].
    seq_breakers: "ctypes.Array[ctypes.c_char_p]",
    num_breakers: int,
) -> llama_sampler_p:
    """Create a DRY (Don't Repeat Yourself) repetition-penalty sampler.

    ctypes binding stub for llama.cpp's ``llama_sampler_init_dry``; the
    actual implementation is attached by the ``@ctypes_function``
    decorator and dispatches into the shared library.
    """
    ...


# /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
# /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
Expand Down