Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature request] High-level API support for DRY and XTC samplers #1813

Open
ddh0 opened this issue Oct 28, 2024 · 3 comments
Open

[Feature request] High-level API support for DRY and XTC samplers #1813

ddh0 opened this issue Oct 28, 2024 · 3 comments

Comments

@ddh0
Copy link
Contributor

ddh0 commented Oct 28, 2024

Is your feature request related to a problem? Please describe.

Recently, llama.cpp added support for the DRY and XTC samplers, which can help reduce repetition and increase creativity without losing coherence. It would be wonderful if users of llama-cpp-python could take advantage of these advanced samplers.

Describe the solution you'd like

Ideally the high-level API would expose parameters so that the end user / developer may use XTC and DRY, in the same way that we can currently use temperature, top-p, min-p, etc. Functions like Llama.create_completion() would be updated with these new parameters.

Additional context

I would be happy to help in any way I can with the implementation of these samplers, but I'm not sure where to start. @abetlen If there is anything I can do to help get this supported as quickly as possible, please point me in the right direction.

Thank you!

@zpin
Copy link

zpin commented Nov 12, 2024

This patch adds high-level API support for the DRY and XTC samplers:

diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py
index 0aff348..fb78d31 100644
--- a/llama_cpp/_internals.py
+++ b/llama_cpp/_internals.py
@@ -800,6 +800,22 @@ class LlamaSampler:
         sampler = llama_cpp.llama_sampler_init_mirostat_v2(seed, tau, eta)
         self._add_sampler(sampler)
 
+    def add_xtc(self, probability: float, threshold: float, min_keep: int, seed: int):
+        sampler = llama_cpp.llama_sampler_init_xtc(probability, threshold, min_keep, seed)
+        self._add_sampler(sampler)
+
+    def add_dry(self, model: LlamaModel, multiplier: float, base: float,
+                allowed_length: int, penalty_last_n: int, seq_breakers: Optional[list[str]] = None):
+        # None means "no sequence breakers" (avoids a shared mutable [] default).
+        # Encode to UTF-8 bytes; llama_sampler_init_dry takes a `const char **` array.
+        seq_breakers_bytes = [s.encode('utf-8') for s in (seq_breakers or [])]
+        # Build the char* array; pass its own length so array and count always agree.
+        arr = (ctypes.c_char_p * len(seq_breakers_bytes))(*seq_breakers_bytes)
+        sampler = llama_cpp.llama_sampler_init_dry(model.model, multiplier, base,
+                                                allowed_length, penalty_last_n,
+                                                arr, len(seq_breakers_bytes))
+        self._add_sampler(sampler)
+
     def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
         sampler = llama_cpp.llama_sampler_init_grammar(
             model.model, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8")
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index babb30c..42febcd 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -677,6 +677,13 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = -1,
+        dry_seq_breakers: Optional[list[str]] = None,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
@@ -744,6 +751,7 @@ class Llama:
             else:
                 n_probs = 0
                 min_keep = max(1, n_probs)
+                sampler.add_dry(self._model, dry_multiplier, dry_base, dry_allowed_length, dry_range, dry_seq_breakers or [])
                 sampler.add_top_k(top_k)
                 sampler.add_tail_free(tfs_z, min_keep)
                 sampler.add_typical(typical_p, min_keep)
@@ -751,6 +759,7 @@ class Llama:
                 sampler.add_min_p(min_p, min_keep)
+                sampler.add_xtc(xtc_probability, xtc_threshold, min_keep, self._seed)
                 sampler.add_temp(temp)
                 sampler.add_dist(self._seed)
         return sampler
 
     def sample(
@@ -767,6 +776,13 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = -1,
+        dry_seq_breakers: Optional[list[str]] = None,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
@@ -802,6 +818,13 @@ class Llama:
                 mirostat_mode=mirostat_mode,
                 mirostat_tau=mirostat_tau,
                 mirostat_eta=mirostat_eta,
+                xtc_probability=xtc_probability,
+                xtc_threshold=xtc_threshold,
+                dry_multiplier=dry_multiplier,
+                dry_allowed_length=dry_allowed_length,
+                dry_base=dry_base,
+                dry_range=dry_range,
+                dry_seq_breakers=dry_seq_breakers,
                 penalize_nl=penalize_nl,
                 logits_processor=logits_processor,
                 grammar=grammar,
@@ -831,6 +854,13 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = -1,
+        dry_seq_breakers: Optional[list[str]] = None,
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
@@ -870,6 +900,13 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
             grammar=grammar,
@@ -922,6 +959,13 @@ class Llama:
                     mirostat_mode=mirostat_mode,
                     mirostat_tau=mirostat_tau,
                     mirostat_eta=mirostat_eta,
+                    xtc_probability=xtc_probability,
+                    xtc_threshold=xtc_threshold,
+                    dry_multiplier=dry_multiplier,
+                    dry_allowed_length=dry_allowed_length,
+                    dry_base=dry_base,
+                    dry_range=dry_range,
+                    dry_seq_breakers=dry_seq_breakers,
                     logits_processor=logits_processor,
                     grammar=grammar,
                     penalize_nl=penalize_nl,
@@ -1138,6 +1182,13 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = -1,
+        dry_seq_breakers: Optional[list[str]] = None,
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -1326,6 +1377,13 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
@@ -1758,6 +1816,13 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = -1,
+        dry_seq_breakers: Optional[list[str]] = None,
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -1821,6 +1886,13 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
@@ -1855,6 +1927,13 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = -1,
+        dry_seq_breakers: Optional[list[str]] = None,
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -1918,6 +1997,13 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
@@ -1949,6 +2035,13 @@ class Llama:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = -1,
+        dry_seq_breakers: Optional[list[str]] = None,
         model: Optional[str] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
@@ -2022,6 +2115,13 @@ class Llama:
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index 66feed8..dbc9cea 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -3244,6 +3244,38 @@ def llama_sampler_init_xtc(
 ) -> llama_sampler_p:
     ...
 
+#    LLAMA_API struct llama_sampler *    llama_sampler_init_dry(
+#            const struct llama_model *  model,
+#                               float    dry_multiplier,
+#                               float    dry_base,
+#                             int32_t    dry_allowed_length,
+#                             int32_t    dry_penalty_last_n,
+#                          const char ** seq_breakers,
+#                              size_t    num_breakers);
+@ctypes_function(
+    "llama_sampler_init_dry",
+    [
+        llama_model_p_ctypes,
+        ctypes.c_float,
+        ctypes.c_float,
+        ctypes.c_int32,
+        ctypes.c_int32,
+        ctypes.POINTER(ctypes.c_char_p),
+        ctypes.c_size_t,
+    ],
+    llama_sampler_p_ctypes,
+)
+def llama_sampler_init_dry(
+    model: llama_model_p,
+    dry_multiplier: float,
+    dry_base: float,
+    dry_allowed_length: int,
+    dry_penalty_last_n: int,
+    seq_breakers: CtypesArray[ctypes.c_char_p],
+    num_breakers: int,
+) -> llama_sampler_p:
+    ...
+
 
 # /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
 # /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.

@ExtReMLapin
Copy link
Contributor

@zpin feel free to open a PR!

@zpin
Copy link

zpin commented Nov 25, 2024

Done: #1843

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants