
Commit 958678d
Fix condition when rolling cache (#150)
casper-hansen authored Nov 3, 2023
1 parent 92a403b commit 958678d
Showing 1 changed file with 1 addition and 1 deletion.
awq/modules/fused/attn.py (2 changes: 1 addition, 1 deletion)
@@ -143,7 +143,7 @@ def forward(self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs):
     will_cache_be_exceeded = self.start_pos + seqlen > self.max_seq_len
 
     # Reset and avoid retaining state when processing context
-    if will_cache_be_exceeded:
+    if will_cache_be_exceeded and seqlen > 1:
         self.start_pos = self.cache.roll_kv_n_steps(self.start_pos, n=self.start_pos)
     # Slowly roll out old tokens without performance hit if exceeded during decoding
     elif will_cache_be_exceeded and seqlen == 1:
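
For context, here is a minimal standalone sketch (not the AutoAWQ code itself) of why the extra `seqlen > 1` guard matters: with the old condition, the first branch fired on every cache overflow, so a single-token decode step could never reach the cheaper incremental-roll branch below it. The helper name `choose_branch` and the numbers used are illustrative assumptions, not part of the repository.

# Hedged sketch of the branch selection before and after this patch.
# `choose_branch` and the 2048-slot cache size are invented for illustration.

def choose_branch(start_pos: int, seqlen: int, max_seq_len: int, patched: bool) -> str:
    will_cache_be_exceeded = start_pos + seqlen > max_seq_len
    # Old condition: `will_cache_be_exceeded` alone; the patch adds `seqlen > 1`.
    reset_for_context = will_cache_be_exceeded and (seqlen > 1 or not patched)
    if reset_for_context:
        return "full reset (context processing)"
    elif will_cache_be_exceeded and seqlen == 1:
        return "incremental roll (decoding)"
    return "no roll needed"

# Decoding one token past a full 2048-slot cache:
print(choose_branch(start_pos=2048, seqlen=1, max_seq_len=2048, patched=False))
# -> full reset (context processing)   (the bug: decode never reaches the elif)
print(choose_branch(start_pos=2048, seqlen=1, max_seq_len=2048, patched=True))
# -> incremental roll (decoding)       (the fix: decode takes the cheap path)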
