@@ -33,48 +33,39 @@ To do so, we write a subclass of the [`Model`](hfppl.modeling.Model) class:
 ``` python
 # examples/no_e.py

-from hfppl import Model, LMContext, TokenCategorical, CachedCausalLM
+from hfppl import Model, LMContext, CachedCausalLM

 # A LLaMPPL model subclasses the Model class
 class MyModel(Model):

     # The __init__ method is used to process arguments
     # and initialize instance variables.
     def __init__(self, lm, prompt, forbidden_letter):
-
-        # Always call the superclass's __init__.
         super().__init__()

         # A stateful context object for the LLM, initialized with the prompt
         self.context = LMContext(lm, prompt)
-
+        self.eos_token = lm.tokenizer.eos_token_id
+
         # The forbidden letter
-        self.forbidden_tokens = [i for (i, v) in enumerate(lm.vocab)
-                                 if forbidden_letter in v]
-
+        self.forbidden_tokens = set(i for (i, v) in enumerate(lm.vocab)
+                                    if forbidden_letter in v)
+
     # The step method is used to perform a single 'step' of generation.
     # This might be a single token, a single phrase, or any other division.
     # Here, we generate one token at a time.
     async def step(self):
-        # Sample a token from the LLM -- automatically extends `self.context`.
-        # We use `await` so that LLaMPPL can automatically batch language model calls.
-        token = await self.sample(self.context.next_token(),
-                                  proposal=self.proposal())
-
-        # Condition on the token not having the forbidden letter
-        self.condition(token.token_id not in self.forbidden_tokens)
+        # Condition on the next token *not* being a forbidden token.
+        await self.observe(self.context.mask_dist(self.forbidden_tokens), False)
+
+        # Sample the next token from the LLM -- automatically extends `self.context`.
+        token = await self.sample(self.context.next_token())

         # Check for EOS or end of sentence
-        if token.token_id == self.context.lm.tokenizer.eos_token_id or str(token) in ['.', '!', '?']:
+        if token.token_id == self.eos_token or str(token) in ['.', '!', '?']:
             # Finish generation
             self.finish()

-    # Helper method to define a custom proposal
-    def proposal(self):
-        logits = self.context.next_token_logprobs.copy()
-        logits[self.forbidden_tokens] = -float('inf')
-        return TokenCategorical(self.context.lm, logits)
-
     # To improve performance, a hint that `self.forbidden_tokens` is immutable
     def immutable_properties(self):
         return set(['forbidden_tokens'])
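
The revised model observes `False` on `mask_dist(self.forbidden_tokens)` *before* sampling, so the constraint is folded into each particle's weight rather than enforced by a hand-written proposal. For context, here is a minimal sketch of how such a model might be run with the library's SMC inference; the diff itself does not include this driver code, and the model name `"gpt2"`, the prompt, and the particle count are illustrative assumptions.

``` python
# Sketch (not part of this commit): running MyModel with SMC inference.
import asyncio

from hfppl import CachedCausalLM, smc_standard


async def main():
    # Load a cached causal LM; "gpt2" is an arbitrary choice for illustration.
    lm = CachedCausalLM.from_pretrained("gpt2")

    # Build the model from the diff above: continue the prompt while
    # avoiding any token that contains the letter "e".
    model = MyModel(lm, "The weather today is expected to be", "e")

    # Run standard sequential Monte Carlo with a handful of particles.
    particles = await smc_standard(model, 5)

    # Print each surviving particle's context.
    for p in particles:
        print(p.context)


asyncio.run(main())
```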