
Commit d54a59f (1 parent: f505208)

Update 'Getting Started' doc to use same example as README

1 file changed: +12 −21 lines

docs/getting_started.md: 12 additions & 21 deletions
````diff
@@ -33,48 +33,39 @@ To do so, we write subclass the [`Model`](hfppl.modeling.Model) class:
 ```python
 # examples/no_e.py
 
-from hfppl import Model, LMContext, TokenCategorical, CachedCausalLM
+from hfppl import Model, LMContext, CachedCausalLM
 
 # A LLaMPPL model subclasses the Model class
 class MyModel(Model):
 
     # The __init__ method is used to process arguments
     # and initialize instance variables.
     def __init__(self, lm, prompt, forbidden_letter):
-
-        # Always call the superclass's __init__.
         super().__init__()
 
         # A stateful context object for the LLM, initialized with the prompt
         self.context = LMContext(lm, prompt)
-
+        self.eos_token = lm.tokenizer.eos_token_id
+
         # The forbidden letter
-        self.forbidden_tokens = [i for (i, v) in enumerate(lm.vocab)
-                                 if forbidden_letter in v]
-
+        self.forbidden_tokens = set(i for (i, v) in enumerate(lm.vocab)
+                                    if forbidden_letter in v)
+
     # The step method is used to perform a single 'step' of generation.
     # This might be a single token, a single phrase, or any other division.
     # Here, we generate one token at a time.
     async def step(self):
-        # Sample a token from the LLM -- automatically extends `self.context`.
-        # We use `await` so that LLaMPPL can automatically batch language model calls.
-        token = await self.sample(self.context.next_token(),
-                                  proposal=self.proposal())
-
-        # Condition on the token not having the forbidden letter
-        self.condition(token.token_id not in self.forbidden_tokens)
+        # Condition on the next token *not* being a forbidden token.
+        await self.observe(self.context.mask_dist(self.forbidden_tokens), False)
+
+        # Sample the next token from the LLM -- automatically extends `self.context`.
+        token = await self.sample(self.context.next_token())
 
         # Check for EOS or end of sentence
-        if token.token_id == self.context.lm.tokenizer.eos_token_id or str(token) in ['.', '!', '?']:
+        if token.token_id == self.eos_token or str(token) in ['.', '!', '?']:
            # Finish generation
             self.finish()
 
-    # Helper method to define a custom proposal
-    def proposal(self):
-        logits = self.context.next_token_logprobs.copy()
-        logits[self.forbidden_tokens] = -float('inf')
-        return TokenCategorical(self.context.lm, logits)
-
     # To improve performance, a hint that `self.forbidden_tokens` is immutable
     def immutable_properties(self):
         return set(['forbidden_tokens'])
````
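The rewritten `step` folds the forbidden-letter constraint into the model itself: observing `False` from `self.context.mask_dist(self.forbidden_tokens)` both conditions on the next token avoiding those tokens and reweights `next_token()` accordingly, so the separate custom-proposal helper is no longer needed.

For context, the README pairs this model with a sequential Monte Carlo run. The sketch below follows that pattern; the model id, prompt, and particle count are illustrative assumptions, and `CachedCausalLM.from_pretrained` / `smc_standard` come from the README rather than from this diff.

```python
# Minimal usage sketch for the model above, assuming the README's
# inference API (smc_standard) and an illustrative HuggingFace model id.
import asyncio

from hfppl import CachedCausalLM, smc_standard

async def main():
    # Load a causal LM through hfppl's caching wrapper
    lm = CachedCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")

    # Generate text that avoids the letter 'e'
    model = MyModel(lm, "The weather today is expected to be", "e")

    # Run SMC with 20 particles; each particle is a weighted copy of
    # MyModel whose `context` holds the text it generated
    particles = await smc_standard(model, 20)
    for p in particles:
        print(p.context)

asyncio.run(main())
```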
