diff --git a/outlines/generate/api.py b/outlines/generate/api.py
index e8afa526d..edf4d0fbb 100644
--- a/outlines/generate/api.py
+++ b/outlines/generate/api.py
@@ -318,6 +318,9 @@ def stream(
         if isinstance(stop_at, str):
             stop_at = [stop_at]
 
+        if max_words and max_tokens is None:
+            max_tokens = 3 * max_words
+
         stop_sequences = stop_at
         num_samples = self.num_samples
 
diff --git a/tests/generate/test_generator.py b/tests/generate/test_generator.py
index 4a7e7da1f..f83ee62b7 100644
--- a/tests/generate/test_generator.py
+++ b/tests/generate/test_generator.py
@@ -536,8 +536,11 @@ def __call__(self, biased_logits, *_):
             return torch.argmax(biased_logits, keepdims=True), torch.tensor([0]), None
 
     generator = SequenceGenerator(MockFSM(), MockModel(), sampler(), "cpu")
-    result = generator("test", max_words=5)
-    assert result == "test test test test test"
+    result = generator("test", max_words=3)
+    assert result == "test test test"
+
+    sequence = generator.stream("test", max_words=3)
+    assert "".join(sequence) == "test test test"
 
 
 def test_generator_max_tokens_from_max_words():
@@ -583,3 +586,6 @@ def __call__(self, biased_logits, *_):
     generator = SequenceGenerator(MockFSM(), MockModel(), sampler(), "cpu")
     result = generator("test", max_words=2)  # should generate max_words * 3 tokens
     assert result == "123456"
+
+    sequence = generator.stream("test", max_words=2)
+    assert "".join(sequence) == "123456"
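
A minimal usage sketch of the behavior this patch enables, assuming the rest of the PR also exposes max_words in stream()'s signature (as the new tests exercise). The model name and prompt below are illustrative only, not part of the patch:

    import outlines

    # Illustrative model choice; any model supported by outlines works here.
    model = outlines.models.transformers("gpt2")
    generator = outlines.generate.text(model)

    # With the patch above, passing max_words alone sets max_tokens = 3 * max_words,
    # so streaming stops after a rough budget of three tokens per requested word.
    for chunk in generator.stream("Write a haiku about the sea.", max_words=10):
        print(chunk, end="")

The 3-tokens-per-word factor mirrors the existing heuristic already asserted in test_generator_max_tokens_from_max_words ("should generate max_words * 3 tokens").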