Commit c41db95

test(max_words): improve coverage
1 parent cb65326 commit c41db95

File tree

2 files changed: +11 -2 lines

outlines/generate/api.py

Lines changed: 3 additions & 0 deletions
@@ -318,6 +318,9 @@ def stream(
         if isinstance(stop_at, str):
             stop_at = [stop_at]
 
+        if max_words and max_tokens is None:
+            max_tokens = 3 * max_words
+
         stop_sequences = stop_at
         num_samples = self.num_samples
 
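The new branch above applies when a caller sets max_words but not max_tokens: the token budget is derived with a rough heuristic of three tokens per word. A minimal usage sketch, assuming a transformers model loaded through the public outlines API and the max_words keyword exposed on the generator as exercised by the tests below (the model name and prompt are illustrative, not part of this commit):

from outlines import generate, models

# Illustrative setup; any transformers-compatible model would do here.
model = models.transformers("gpt2")
generator = generate.text(model)

# Passing max_words without max_tokens triggers the new branch in stream():
# max_tokens is derived as 3 * max_words, i.e. a budget of ~15 tokens here.
for chunk in generator.stream("Name a colour.", max_words=5):
    print(chunk, end="")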

tests/generate/test_generator.py

Lines changed: 8 additions & 2 deletions
@@ -536,8 +536,11 @@ def __call__(self, biased_logits, *_):
             return torch.argmax(biased_logits, keepdims=True), torch.tensor([0]), None
 
     generator = SequenceGenerator(MockFSM(), MockModel(), sampler(), "cpu")
-    result = generator("test", max_words=5)
-    assert result == "test test test test test"
+    result = generator("test", max_words=3)
+    assert result == "test test test"
+
+    sequence = generator.stream("test", max_words=3)
+    assert "".join(sequence) == "test test test"
 
 
 def test_generator_max_tokens_from_max_words():
@@ -583,3 +586,6 @@ def __call__(self, biased_logits, *_):
     generator = SequenceGenerator(MockFSM(), MockModel(), sampler(), "cpu")
     result = generator("test", max_words=2)  # should generate max_words * 3 tokens
     assert result == "123456"
+
+    sequence = generator.stream("test", max_words=2)
+    assert "".join(sequence) == "123456"
