test(max_words): improve coverage
g-prz committed Jan 16, 2025
1 parent b5d5226 commit e2e7318
Showing 2 changed files with 11 additions and 2 deletions.
3 changes: 3 additions & 0 deletions outlines/generate/api.py
@@ -318,6 +318,9 @@ def stream(
         if isinstance(stop_at, str):
             stop_at = [stop_at]
 
+        if max_words and max_tokens is None:
+            max_tokens = 3 * max_words
+
         stop_sequences = stop_at
         num_samples = self.num_samples
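The change to stream() adds a simple fallback: when a caller passes max_words without an explicit max_tokens, the token budget is derived as three tokens per word. A minimal standalone sketch of that heuristic (the helper name and its use outside stream() are illustrative, not part of the outlines API):

from typing import Optional

def resolve_max_tokens(max_words: Optional[int], max_tokens: Optional[int]) -> Optional[int]:
    # Mirrors the fallback added above: an explicit max_tokens always wins;
    # otherwise budget roughly three tokens per requested word.
    if max_words and max_tokens is None:
        return 3 * max_words
    return max_tokens

# max_words=3 with no explicit cap yields a 9-token budget;
# an explicit max_tokens is left untouched.
assert resolve_max_tokens(3, None) == 9
assert resolve_max_tokens(3, 40) == 40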
10 changes: 8 additions & 2 deletions tests/generate/test_generator.py
@@ -536,8 +536,11 @@ def __call__(self, biased_logits, *_):
             return torch.argmax(biased_logits, keepdims=True), torch.tensor([0]), None
 
     generator = SequenceGenerator(MockFSM(), MockModel(), sampler(), "cpu")
-    result = generator("test", max_words=5)
-    assert result == "test test test test test"
+    result = generator("test", max_words=3)
+    assert result == "test test test"
+
+    sequence = generator.stream("test", max_words=3)
+    assert "".join(sequence) == "test test test"
 
 
 def test_generator_max_tokens_from_max_words():
@@ -583,3 +586,6 @@ def __call__(self, biased_logits, *_):
     generator = SequenceGenerator(MockFSM(), MockModel(), sampler(), "cpu")
     result = generator("test", max_words=2)  # should generate max_words * 3 tokens
     assert result == "123456"
+
+    sequence = generator.stream("test", max_words=2)
+    assert "".join(sequence) == "123456"
