Skip to content

Commit 8ffd905

Browse files
committed
Fix varlen generation by passing seq_idx to causal_conv1d
1 parent ddce0c1 commit 8ffd905

File tree

3 files changed

+4
-2
lines changed

3 files changed

+4
-2
lines changed

mamba_ssm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "2.2.1"
1+
__version__ = "2.2.2"
22

33
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
44
from mamba_ssm.modules.mamba_simple import Mamba

mamba_ssm/modules/mamba2.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param
226226
conv_state.copy_(conv_varlen_states)
227227
assert self.activation in ["silu", "swish"]
228228
if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
229+
assert seq_idx is None, "varlen conv1d requires the causal_conv1d package"
229230
xBC = self.act(
230231
self.conv1d(xBC.transpose(1, 2)).transpose(1, 2)[:, -(self.dconv - 1):]
231232
) # (B, L, self.d_ssm + 2 * ngroups * d_state)
@@ -235,6 +236,7 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param
235236
rearrange(self.conv1d.weight, "d 1 w -> d w"),
236237
bias=self.conv1d.bias,
237238
activation=self.activation,
239+
seq_idx=seq_idx,
238240
).transpose(1, 2)
239241
x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
240242
y = mamba_chunk_scan_combined(

tests/test_generation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,4 +110,4 @@ def test_generation_varlen():
110110
sequences.append(sampled_tokens)
111111
out_varlen = torch.cat(scores, dim=1)
112112
print(f"Max diff: {(out_varlen - out_ref).abs().max()}")
113-
assert (out_varlen - out_ref).abs().max() < 5 * (out_loop - out_ref).abs().max()
113+
assert (out_varlen - out_ref).abs().max() < 2 * (out_loop - out_ref).abs().max()

0 commit comments

Comments
 (0)