Skip to content

Commit f427399

Browse files
author
Ralf Waldukat
committed
fix: critical fixes for recurrent/hybrid model support
After external code review (GPT-5.2), fixed 4 critical issues: 1. CRITICAL: Fixed tokens[:-1] bug in prefix matching - Was silently breaking prefix matching for ALL models - Caused false rewind detection and cache inefficiency - Impact: Transformers AND recurrent models 2. CRITICAL: Implement proper reset() for recurrent models - Now actually clears llama_memory backend state - Root cause fix for 'sequence positions not consecutive' crash - Without this, reset was a no-op for recurrent models 3. CRITICAL: Enforce strict append policy for recurrent models - Prevents KV cache rewinding that's impossible without state snapshots - Forces full reset on history edits instead of crashing 4. Performance: Cache _is_recurrent to avoid repeated FFI calls 5. Documentation: Simplified comments and updated docstring 6. Testing: All existing tests pass + Mistral-Small-3.2-24B validated Resolves multi-turn crashes for Nemotron-A3B, Mamba, RWKV, Jamba models. Reviewed-by: GPT-5.2 (OpenAI) Tested-by: pytest + Mistral-Small-3.2-24B Fixes: #2108 (recurrent model crashes) Compatible-with: #2109 (Granite-Docling/SmolVLM special tokens)
1 parent 831dbe5 commit f427399

File tree

1 file changed

+41
-3
lines changed

1 file changed

+41
-3
lines changed

llama_cpp/llama.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@ def __init__(
190190
type_v: KV cache data type for V (default: f16)
191191
spm_infill: Use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this.
192192
193+
Note:
194+
Recurrent and hybrid models (Mamba, RWKV, Nemotron-A3B, Jamba) cannot
195+
rewind their state and require full reset on history edits. This is handled
196+
automatically to maintain compatibility. Standard transformers are unaffected.
197+
193198
Raises:
194199
ValueError: If the model path does not exist.
195200
@@ -555,6 +560,11 @@ def free_lora_adapter():
555560

556561
self._sampler = None
557562

563+
# Cache recurrent/hybrid model detection to avoid repeated FFI calls
564+
self._is_recurrent_model = llama_cpp.llama_model_is_recurrent(
565+
self._model.model
566+
) or llama_cpp.llama_model_is_hybrid(self._model.model)
567+
558568
@property
559569
def ctx(self) -> llama_cpp.llama_context_p:
560570
return self._ctx.ctx
@@ -582,6 +592,19 @@ def eval_logits(self) -> Deque[List[float]]:
582592
maxlen=self._n_ctx if self._logits_all else 1,
583593
)
584594

595+
@property
596+
def _is_recurrent(self) -> bool:
597+
"""Check if model is recurrent (SSM) or hybrid (SSM+Attention).
598+
599+
These models (Mamba, RWKV, Nemotron, Jamba, etc.) cannot rewind their
600+
recurrent state without snapshots. Only strict forward progression or
601+
full reset is allowed.
602+
603+
Returns:
604+
True if model has recurrent state that cannot be rewound.
605+
"""
606+
return self._is_recurrent_model
607+
585608
def tokenize(
586609
self, text: bytes, add_bos: bool = True, special: bool = False
587610
) -> List[int]:
@@ -640,6 +663,11 @@ def reset(self):
640663
"""Reset the model state."""
641664
self.n_tokens = 0
642665

666+
if self._is_recurrent:
667+
mem = llama_cpp.llama_get_memory(self._ctx.ctx)
668+
if mem is not None:
669+
llama_cpp.llama_memory_clear(mem, True)
670+
643671
def eval(self, tokens: Sequence[int]):
644672
"""Evaluate a list of tokens.
645673
@@ -891,19 +919,29 @@ def generate(
891919
# Check for kv cache prefix match
892920
if reset and self.n_tokens > 0:
893921
longest_prefix = 0
894-
for a, b in zip(self._input_ids, tokens[:-1]):
922+
for a, b in zip(self._input_ids, tokens):
895923
if a == b:
896924
longest_prefix += 1
897925
else:
898926
break
927+
928+
# Recurrent models cannot rewind state; reset if needed
929+
if self._is_recurrent and longest_prefix < self.n_tokens:
930+
longest_prefix = 0
931+
reset = True
932+
if self.verbose:
933+
print(
934+
"Llama.generate: recurrent model requires full state reset",
935+
file=sys.stderr,
936+
)
937+
899938
if longest_prefix > 0:
900939
reset = False
901940
tokens = tokens[longest_prefix:]
902941
self.n_tokens = longest_prefix
903942
if self.verbose:
904943
print(
905-
f"Llama.generate: {longest_prefix} prefix-match hit, "
906-
f"remaining {len(tokens)} prompt tokens to eval",
944+
f"Llama.generate: {longest_prefix} prefix-match hit, {len(tokens)} tokens to eval",
907945
file=sys.stderr,
908946
)
909947

0 commit comments

Comments
 (0)