
Commit 7d71e06

Use black code formatter
1 parent d54a59f commit 7d71e06

19 files changed: +545 −351 lines

examples/grammar_constraint.py
Lines changed: 1 addition & 0 deletions

@@ -9,6 +9,7 @@
 
 Requires synchromesh (github.com/kanishkg/synchromesh)
 """
+
 import asyncio
 import os
 from typing import List

examples/haiku.py
Lines changed: 28 additions & 15 deletions

@@ -4,26 +4,37 @@
 import os
 
 # download the CMU pronunciation dictionary (if we haven't already)
-nltk.download('cmudict')
+nltk.download("cmudict")
 
 # Load the CMU pronunciation dictionary and use it for syllable counting
 from nltk.corpus import cmudict
+
 CMUDICT = cmudict.dict()
 
+
 def count_syllables(word, unknown_word_syllables=100):
-
+
     # Use the dictionary to get the list of possible phonetic representations for the word
     phonetic_transcriptions = CMUDICT.get(word.strip().lower(), [])
-
+
     # Count the number of syllables based on the number of phonetic transcriptions
-    syllable_count = min([len([ph for ph in transcription if ph[-1].isdigit()]) for transcription in phonetic_transcriptions], default=unknown_word_syllables)
+    syllable_count = min(
+        [
+            len([ph for ph in transcription if ph[-1].isdigit()])
+            for transcription in phonetic_transcriptions
+        ],
+        default=unknown_word_syllables,
+    )
 
     return syllable_count
 
+
 # Load the language model (llama2 if authorized, else mistral-7b).
-if 'HF_AUTH_TOKEN' in os.environ:
-    HF_AUTH_TOKEN = os.environ['HF_AUTH_TOKEN']
-    LLM = CachedCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", auth_token=HF_AUTH_TOKEN)
+if "HF_AUTH_TOKEN" in os.environ:
+    HF_AUTH_TOKEN = os.environ["HF_AUTH_TOKEN"]
+    LLM = CachedCausalLM.from_pretrained(
+        "meta-llama/Llama-2-7b-hf", auth_token=HF_AUTH_TOKEN
+    )
 else:
     LLM = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
 
@@ -74,21 +85,22 @@ def count_syllables(word, unknown_word_syllables=100):
 # Useful constants
 NEWLINE_TOKEN, EOS_TOKEN = 13, LLM.tokenizer.eos_token_id
 
+
 # LLaMPPL model
 class Haiku(Model):
-
+
     def __init__(self, prompt, syllable_pattern=[5, 7, 5]):
         super().__init__()
         self.context = LMContext(LLM, prompt, 0.7)
         self.syllable_pattern = syllable_pattern
-
+
     async def step(self):
         # Get the number of syllables required in the next line
         syllables_remaining = self.syllable_pattern.pop(0)
-
+
         # Loop to sample words until this line is over
         while syllables_remaining > 0:
-
+
             # Sample a word
             word, punctuation = await self.call(sample_word(self.context))
 
@@ -103,18 +115,19 @@ async def step(self):
             await self.observe(self.context.next_token(), EOS_TOKEN)
             self.finish()
             return
-
+
         # Otherwise, observe a line break
         await self.observe(self.context.next_token(), NEWLINE_TOKEN)
 
         # Print current result
         print(str(self.context))
 
+
 # Run inference
-SYLLABLES_PER_LINE = [5, 7, 5] # [5, 3, 5] for a Lune
+SYLLABLES_PER_LINE = [5, 7, 5]  # [5, 3, 5] for a Lune
 particles = asyncio.run(smc_standard(Haiku(poem_prompt, SYLLABLES_PER_LINE), 120))
 
 print("--------")
-for (i,particle) in enumerate(particles):
+for i, particle in enumerate(particles):
     print(f"Poem {i} (weight {particle.weight}):")
-    print(f"{particle.context}")
\ No newline at end of file
+    print(f"{particle.context}")

examples/hard_constraints.py
Lines changed: 27 additions & 16 deletions

@@ -4,22 +4,30 @@
 
 import os
 
-if 'HF_AUTH_TOKEN' in os.environ:
-    HF_AUTH_TOKEN = os.environ['HF_AUTH_TOKEN']
+if "HF_AUTH_TOKEN" in os.environ:
+    HF_AUTH_TOKEN = os.environ["HF_AUTH_TOKEN"]
 
-# Load the language model.
+# Load the language model.
 # Mistral and Vicuna are open models; to use a model with restricted access, like LLaMA 2,
 # pass your HuggingFace API key as the optional `auth_token` argument:
 # LLM = CachedCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", auth_token=HF_AUTH_TOKEN)
-# LLM = CachedCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5")
-LLM = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
+LLM = CachedCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5")
+# LLM = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
 LLM.batch_size = 40
 
-MASKS = {i : set(j for (j,v) in enumerate(LLM.vocab)
-                 if j != LLM.tokenizer.eos_token_id and '\n' not in v and
-                 any(c.isalpha() or c in string.punctuation for c in v) and
-                 len(v.strip()) <= 5 and (not v[0].isalpha() or i+len(v) <= 5))
-         for i in range(6)}
+MASKS = {
+    i: set(
+        j
+        for (j, v) in enumerate(LLM.vocab)
+        if j != LLM.tokenizer.eos_token_id
+        and "\n" not in v
+        and any(c.isalpha() or c in string.punctuation for c in v)
+        and len(v.strip()) <= 5
+        and (not v[0].isalpha() or i + len(v) <= 5)
+    )
+    for i in range(6)
+}
+
 
 class ConstraintModel(Model):
     def __init__(self, prompt, max_tokens):
@@ -33,26 +41,27 @@ async def step(self):
 
         # Condition on next token being from mask
         await self.observe(self.context.mask_dist(mask), True)
-
+
         # Generate proposed token.
         token = await self.sample(self.context.next_token())
-
+
         # Reduce number of max tokens remaining
         self.max_tokens -= 1
-
+
         print(f"{self.context}")
 
         # Check if done
         if token == LLM.tokenizer.eos_token_id or self.max_tokens == 0:
             self.finish()
             return
-
+
     def active_constraint_mask(self):
         string_so_far = str(self.context)
         words = string_so_far.split()
         last_word = words[-1] if len(words) > 0 else ""
         return MASKS[min(5, len(last_word))]
-
+
+
 # From Politico.com
 prompt = """3 things to watch …
@@ -64,10 +73,12 @@ def active_constraint_mask(self):
 
 LLM.cache_kv(LLM.tokenizer.encode(prompt))
 
+
 async def main():
     constraint_model = ConstraintModel(prompt, 50)
     particles = await smc_standard(constraint_model, 40)
     for p in particles:
         print(f"{p.context}")
 
-asyncio.run(main())
\ No newline at end of file
+
+asyncio.run(main())
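
The MASKS comprehension reformatted above precomputes, for each current-word length i from 0 to 5, the set of token ids that cannot push any word past five letters: a token that starts with a letter is only allowed when i + len(token) <= 5, while tokens beginning with whitespace or punctuation reset the count. A toy illustration of what it builds, using a hypothetical six-token vocabulary (toy_vocab and EOS_ID are stand-ins, not part of this repo):

    import string

    toy_vocab = [" the", "cat", "s", "</s>", " jumped", "!"]  # hypothetical vocabulary
    EOS_ID = 3  # stand-in for LLM.tokenizer.eos_token_id

    MASKS = {
        i: set(
            j
            for (j, v) in enumerate(toy_vocab)
            if j != EOS_ID
            and "\n" not in v
            and any(c.isalpha() or c in string.punctuation for c in v)
            and len(v.strip()) <= 5
            and (not v[0].isalpha() or i + len(v) <= 5)
        )
        for i in range(6)
    }

    print(MASKS[0])  # {0, 1, 2, 5}: no partial word yet, any short token is allowed
    print(MASKS[5])  # {0, 5}: a 5-letter word can only be followed by a word-initial or punctuation token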

hfppl/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -6,4 +6,4 @@
 from .distributions import *
 from .modeling import *
 from .inference import *
-from .chunks import *
\ No newline at end of file
+from .chunks import *

hfppl/chunks.py
Lines changed: 22 additions & 9 deletions

@@ -1,38 +1,51 @@
 import string
 from .modeling import submodel
 
+
 @submodel
 async def sample_word(self, context, max_tokens=5, allow_punctuation=True):
     """Sample a word from the `LMContext` object `context`."""
     last_token = context.lm.vocab[context.tokens[-1]] if len(context.tokens) > 0 else ""
     last_character = last_token[-1] if len(last_token) > 0 else ""
-    needs_space = last_character not in string.whitespace and last_character not in ['-', "'", '"']
+    needs_space = last_character not in string.whitespace and last_character not in [
+        "-",
+        "'",
+        '"',
+    ]
     if needs_space:
         starts_word_mask = context.lm.masks.STARTS_NEW_WORD
     else:
         starts_word_mask = context.lm.masks.CONTINUES_CURRENT_WORD
-
+
     # Force model to start a new word
     await self.observe(context.mask_dist(starts_word_mask), True)
 
     word = ""
     num_tokens = 0
     while True:
-        token = await self.sample(context.next_token())
-        word += context.lm.vocab[token.token_id]
+        token = await self.sample(context.next_token())
+        word += context.lm.vocab[token.token_id]
         num_tokens += 1
 
         if num_tokens == max_tokens:
-            await self.observe(context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD), False)
+            await self.observe(
+                context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD), False
+            )
             break
 
-        if not (await self.sample(context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD))):
+        if not (
+            await self.sample(
+                context.mask_dist(context.lm.masks.CONTINUES_CURRENT_WORD)
+            )
+        ):
             break
-
+
     # Sample punctuation, if desired
     punctuation = ""
-    if allow_punctuation and await self.sample(context.mask_dist(context.lm.masks.PUNCTUATION)):
+    if allow_punctuation and await self.sample(
+        context.mask_dist(context.lm.masks.PUNCTUATION)
+    ):
        punctuation_token = await self.sample(context.next_token())
        punctuation = context.lm.vocab[punctuation_token.token_id]
 
-    return word, punctuation
\ No newline at end of file
+    return word, punctuation
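
For context, sample_word is consumed elsewhere in this commit via Model.call; examples/haiku.py does word, punctuation = await self.call(sample_word(self.context)). A minimal sketch of a model built around it (import paths and the flat re-exports are assumptions based on this package's __init__ files; the LMContext, smc_standard, and self.call usage mirrors the examples above):

    import asyncio

    from hfppl import CachedCausalLM, LMContext, Model, sample_word, smc_standard

    LLM = CachedCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")


    class FiveWords(Model):
        """Generate exactly five words, one sample_word call per word."""

        def __init__(self, prompt):
            super().__init__()
            self.context = LMContext(LLM, prompt, 0.7)
            self.words_left = 5

        async def step(self):
            # Delegate one word's worth of token-level sampling to the submodel
            word, punctuation = await self.call(sample_word(self.context))
            self.words_left -= 1
            if self.words_left == 0:
                self.finish()


    particles = asyncio.run(smc_standard(FiveWords("The weather today is"), 10))
    for p in particles:
        print(f"{p.context}")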

hfppl/distributions/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -17,4 +17,4 @@
 from .tokencategorical import TokenCategorical
 from .transformer import Transformer
 from .lmcontext import LMContext
-from .bernoulli import Bernoulli
\ No newline at end of file
+from .bernoulli import Bernoulli

hfppl/distributions/bernoulli.py
Lines changed: 6 additions & 6 deletions

@@ -2,13 +2,13 @@
 
 import numpy as np
 
+
 class Bernoulli(Distribution):
-    """A Bernoulli distribution.
-    """
-
+    """A Bernoulli distribution."""
+
     def __init__(self, p):
         """Create a Bernoulli distribution.
-
+
         Args:
             p: the probability-of-True for the Bernoulli distribution.
         """

@@ -20,6 +20,6 @@ async def sample(self):
 
     async def log_prob(self, value):
         return np.log(self.p) if value else np.log1p(-self.p)
-
+
     async def argmax(self, idx):
-        return ((self.p > 0.5) if idx == 0 else (self.p < 0.5))
+        return (self.p > 0.5) if idx == 0 else (self.p < 0.5)

hfppl/distributions/distribution.py
Lines changed: 7 additions & 8 deletions

@@ -1,29 +1,28 @@
 class Distribution:
     """Abstract base class for a distribution."""
 
-
     async def sample(self):
         """Generate a random sample from the distribution.
-
+
         Returns:
             x: a value randomly sampled from the distribution."""
         raise NotImplementedError()
-
+
     async def log_prob(self, x):
         """Compute the log probability of a value under this distribution,
         or the log probability density if the distribution is continuous.
-
+
         Args:
             x: the point at which to evaluate the log probability.
         Returns:
-            logprob (float): the log probability of `x`."""
+            logprob (float): the log probability of `x`."""
         raise NotImplementedError()
-
+
     async def argmax(self, n):
         """Return the nth most probable outcome under this distribution (assuming this is a discrete distribution).
-
+
         Args:
             n (int): which value to return to, indexed from most probable (n=0) to least probable (n=|support|).
         Returns:
             x: the nth most probable outcome from this distribution."""
-        raise NotImplementedError()
\ No newline at end of file
+        raise NotImplementedError()
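
As a reading of this interface (a sketch, not code from the repo): concrete subclasses in this package implement all three async methods, with sample returning the value together with its own log-probability, following the pattern of Geometric.sample below. A minimal hypothetical subclass:

    import numpy as np

    from hfppl.distributions.distribution import Distribution  # path per this repo's layout


    class FairCoin(Distribution):
        """A fair coin flip: True or False with probability 1/2 each."""

        async def sample(self):
            # Return the sampled value together with its log-probability,
            # as Geometric.sample does below.
            x = bool(np.random.rand() < 0.5)
            return x, await self.log_prob(x)

        async def log_prob(self, value):
            return np.log(0.5)

        async def argmax(self, idx):
            # Both outcomes are equally probable; break the tie arbitrarily.
            return idx == 0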

hfppl/distributions/geometric.py
Lines changed: 8 additions & 7 deletions

@@ -1,12 +1,13 @@
 from .distribution import Distribution
+import numpy as np
+
 
 class Geometric(Distribution):
-    """A Geometric distribution.
-    """
-
+    """A Geometric distribution."""
+
     def __init__(self, p):
         """Create a Geometric distribution.
-
+
         Args:
             p: the rate of the Geometric distribution.
         """

@@ -17,7 +18,7 @@ async def sample(self):
         return n, await self.log_prob(n)
 
     async def log_prob(self, value):
-        return np.log(self.p) + np.log(1 - self.p)*(value - 1)
-
+        return np.log(self.p) + np.log(1 - self.p) * (value - 1)
+
     async def argmax(self, idx):
-        return idx - 1 # Most likely outcome is 0, then 1, etc.
+        return idx - 1  # Most likely outcome is 0, then 1, etc.
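
For reference, the log_prob reformatted above computes the log of the textbook Geometric pmf with success probability p:

    \log P(X = k) = \log p + (k - 1)\log(1 - p), \qquad P(X = k) = p\,(1 - p)^{k - 1}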
