diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py
index 994d5f149..d2bcbc5fe 100644
--- a/llama_cpp/_internals.py
+++ b/llama_cpp/_internals.py
@@ -786,6 +786,22 @@ def add_mirostat_v2(self, seed: int, tau: float, eta: float):
         sampler = llama_cpp.llama_sampler_init_mirostat_v2(seed, tau, eta)
         self._add_sampler(sampler)
 
+    def add_xtc(self, probability: float, threshold: float, min_keep: int, seed: int):
+        sampler = llama_cpp.llama_sampler_init_xtc(probability, threshold, min_keep, seed)
+        self._add_sampler(sampler)
+
+    def add_dry(self, model: LlamaModel, multiplier: float, base: float,
+                allowed_length: int, penalty_last_n: int, seq_breakers: list[str] = []):
+
+        # Convert Python strings to bytes
+        seq_breakers_bytes = [s.encode('utf-8') for s in seq_breakers]
+        # Create array of char*
+        arr = (ctypes.c_char_p * len(seq_breakers_bytes))(*seq_breakers_bytes)
+        sampler = llama_cpp.llama_sampler_init_dry(model.model, multiplier, base,
+                                                   allowed_length, penalty_last_n,
+                                                   arr, len(seq_breakers))
+        self._add_sampler(sampler)
+
     def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
         sampler = llama_cpp.llama_sampler_init_grammar(
             model.model, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8")
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index d15a88b00..cd2eef86b 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -677,6 +677,13 @@ def _init_sampler(
         mirostat_mode: int = 0,
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
@@ -744,12 +751,14 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
             else:
                 n_probs = 0
                 min_keep = max(1, n_probs)
+                sampler.add_dry(self._model, dry_multiplier, dry_base, dry_allowed_length, dry_range, dry_seq_breakers)
                 sampler.add_top_k(top_k)
                 sampler.add_typical(typical_p, min_keep)
                 sampler.add_top_p(top_p, min_keep)
                 sampler.add_min_p(min_p, min_keep)
                 sampler.add_temp(temp)
                 sampler.add_dist(self._seed)
+                sampler.add_xtc(xtc_probability, xtc_threshold, min_keep, self._seed)
         return sampler
 
     def sample(
@@ -766,6 +775,13 @@ def sample(
         mirostat_mode: int = 0,
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
@@ -801,6 +817,13 @@ def sample(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
             grammar=grammar,
@@ -830,6 +853,13 @@ def generate(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
@@ -869,6 +899,13 @@ def generate(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
             grammar=grammar,
@@ -921,6 +958,13 @@ def generate(
                     mirostat_mode=mirostat_mode,
                     mirostat_tau=mirostat_tau,
                     mirostat_eta=mirostat_eta,
+                    xtc_probability=xtc_probability,
+                    xtc_threshold=xtc_threshold,
+                    dry_multiplier=dry_multiplier,
+                    dry_allowed_length=dry_allowed_length,
+                    dry_base=dry_base,
+                    dry_range=dry_range,
+                    dry_seq_breakers=dry_seq_breakers,
                     logits_processor=logits_processor,
                     grammar=grammar,
                     penalize_nl=penalize_nl,
@@ -1137,6 +1181,13 @@ def _create_completion(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -1325,6 +1376,13 @@ def logit_bias_processor(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
@@ -1757,6 +1815,13 @@ def create_completion(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -1820,6 +1885,13 @@ def create_completion(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
@@ -1854,6 +1926,13 @@ def __call__(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -1917,6 +1996,13 @@ def __call__(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
@@ -1948,6 +2034,13 @@ def create_chat_completion(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         model: Optional[str] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
@@ -2021,6 +2114,13 @@ def create_chat_completion(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index 457c6dddb..75d0a9567 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -3233,6 +3233,38 @@ def llama_sampler_init_xtc(
 ) -> llama_sampler_p: ...
 
 
+# LLAMA_API struct llama_sampler * llama_sampler_init_dry(
+#         const struct llama_model * model,
+#                             float   dry_multiplier,
+#                             float   dry_base,
+#                           int32_t   dry_allowed_length,
+#                           int32_t   dry_penalty_last_n,
+#                        const char ** seq_breakers,
+#                            size_t   num_breakers);
+@ctypes_function(
+"llama_sampler_init_dry",
+    [
+        llama_model_p_ctypes,
+        ctypes.c_float,
+        ctypes.c_float,
+        ctypes.c_int32,
+        ctypes.c_int32,
+        ctypes.POINTER(ctypes.c_char_p),
+        ctypes.c_size_t
+    ],
+    llama_sampler_p_ctypes,
+)
+def llama_sampler_init_dry(
+    model: llama_model_p,
+    dry_multiplier: float,
+    dry_base: float,
+    dry_allowed_length: int,
+    dry_penalty_last_n: int,
+    seq_breakers: list[str],
+    num_breakers: int,
+) -> llama_sampler_p:
+    ...
+
+
 # /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
 # /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
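Not part of the patch, but a minimal usage sketch of the parameters it adds: the new keyword arguments flow from `create_completion` / `create_chat_completion` down through `_create_completion`, `generate`, and `sample` into `_init_sampler`, which chains `add_dry` and `add_xtc` onto the sampler. The model path, prompt, and sampler values below are placeholders chosen for illustration, not defaults taken from the diff.

```python
from llama_cpp import Llama

# Hypothetical local model path; any GGUF model file works here.
llm = Llama(model_path="./model.gguf")

output = llm.create_completion(
    "Write a short poem about rivers.",
    max_tokens=128,
    # XTC: with probability 0.5, exclude top choices above the 0.1 threshold.
    xtc_probability=0.5,
    xtc_threshold=0.1,
    # DRY: penalize verbatim repetition; dry_multiplier=0.0 (the default) leaves it disabled.
    dry_multiplier=0.8,
    dry_base=1.75,
    dry_allowed_length=2,
    dry_range=1024,  # maps to dry_penalty_last_n in the C binding
    dry_seq_breakers=["\n", ":", '"', "*"],
)
print(output["choices"][0]["text"])
```

Since every new parameter defaults to a value that leaves the corresponding sampler inactive (xtc_probability=0.0, dry_multiplier=0.0), existing callers that do not pass these arguments should see unchanged sampling behavior.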