Commit a80dcdc

[fix] Apply black and isort formatting

1 parent 57d73f0 commit a80dcdc

9 files changed: +152 -56 lines changed
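
Note: this commit is a pure formatting pass, so it should be mechanically reproducible. A minimal sketch of the equivalent invocation, assuming the formatters' default settings (the project's exact black/isort configuration is not recorded in this diff):

# Hypothetical reproduction of this commit's formatting pass.
import subprocess

# black reflows lines past its default 88-character limit, which is why the
# hunks below wrap long calls and chained expressions across multiple lines.
subprocess.run(["black", "roleplay/"], check=True)

# isort groups imports (stdlib / third-party / first-party) and alphabetizes
# within each group; the "black" profile keeps its output black-compatible.
subprocess.run(["isort", "--profile", "black", "roleplay/"], check=True)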

roleplay/VERSION

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-2.0.0
+2.0.1

roleplay/actions/generate_dialogues.py

Lines changed: 39 additions & 14 deletions
@@ -7,7 +7,6 @@
 from aim import Run, Text
 from omegaconf import DictConfig
 from tqdm import tqdm
-
 from urartu.common.action import Action
 from urartu.common.dataset import Dataset
 
@@ -58,7 +57,9 @@ def main(self):
         model_inquirer.aim_run = self.aim_run
         model_responder.aim_run = self.aim_run
 
-        for idx, sample in tqdm(enumerate(dataset.dataset), total=len(dataset.dataset), desc="samples"):
+        for idx, sample in tqdm(
+            enumerate(dataset.dataset), total=len(dataset.dataset), desc="samples"
+        ):
             for persona, persona_hash in tqdm(personas, desc="personas", leave=False):
                 self.aim_run["personas"][persona_hash] = persona
 
@@ -67,7 +68,10 @@ def main(self):
                 dialog = []
                 raw_dialog = []
 
-                instructions = [instruct.lstrip().rstrip() for instruct in sample[task_cfg.dataset.input_key].split("\n")]
+                instructions = [
+                    instruct.lstrip().rstrip()
+                    for instruct in sample[task_cfg.dataset.input_key].split("\n")
+                ]
 
                 if self.action_cfg.task.model_inquirer.regenerate_tries:
                     regeneratinon_idx = 0
@@ -97,7 +101,9 @@ def main(self):
                     inquirer_output, _ = model_inquirer.generate(
                         prompt=inquirer_prompt,
                         generate_cfg=(
-                            inquirer_generate_cfg if inquirer_generate_cfg else self.action_cfg.task.model_inquirer.generate
+                            inquirer_generate_cfg
+                            if inquirer_generate_cfg
+                            else self.action_cfg.task.model_inquirer.generate
                         ),
                     )
                     if not inquirer_output:
@@ -121,13 +127,20 @@ def main(self):
                     if model_inquirer.stop_dialog(inquirer_output):
                         break
 
-                    inquirer_output_extract, num_prompts = model_inquirer.extract_prompt(prompt=inquirer_output)
+                    inquirer_output_extract, num_prompts = (
+                        model_inquirer.extract_prompt(prompt=inquirer_output)
+                    )
 
                     if self.action_cfg.task.model_inquirer.regenerate_tries:
                         # --------------------- if model_inquirer failed to provide prompt ---------------------
                         if inquirer_output_extract is None:
-                            if regeneratinon_idx < self.action_cfg.task.model_inquirer.regenerate_tries:
-                                inquirer_generate_cfg = model_inquirer.get_generation_cfg()
+                            if (
+                                regeneratinon_idx
+                                < self.action_cfg.task.model_inquirer.regenerate_tries
+                            ):
+                                inquirer_generate_cfg = (
+                                    model_inquirer.get_generation_cfg()
+                                )
                                 regeneratinon_idx += 1
                                 continue
                             else:
@@ -156,11 +169,16 @@ def main(self):
 
                     # As the context for model_inquirer is getting bigger much faster -> Starts answering it's own questions
                     # To prevent this keep in the inquirer_history only the output prompt(the thing that model_responder will see).
-                    model_inquirer.update_history(prompt=inquirer_prompt, output_extract=inquirer_output_extract)
+                    model_inquirer.update_history(
+                        prompt=inquirer_prompt,
+                        output_extract=inquirer_output_extract,
+                    )
 
                     # ------------------------------------------ Model B ------------------------------------------
 
-                    responder_prompt = model_responder.get_prompt(turn=turn, response_msg=inquirer_output_extract)
+                    responder_prompt = model_responder.get_prompt(
+                        turn=turn, response_msg=inquirer_output_extract
+                    )
 
                     self.track(
                         prompt=responder_prompt,
@@ -171,9 +189,11 @@ def main(self):
                             "persona_hash": persona_hash,
                         },
                     )
-                    responder_output, responder_model_output_template = model_responder.generate(
-                        prompt=responder_prompt,
-                        generate_cfg=self.action_cfg.task.model_responder.generate,
+                    responder_output, responder_model_output_template = (
+                        model_responder.generate(
+                            prompt=responder_prompt,
+                            generate_cfg=self.action_cfg.task.model_responder.generate,
+                        )
                     )
                     if not responder_output:
                         break
@@ -192,7 +212,10 @@ def main(self):
                         self.aim_run["num_non_coherent_model_responder"] += 1
                         break
 
-                    model_responder.update_history(prompt=responder_prompt, output_extract=responder_model_output_template)
+                    model_responder.update_history(
+                        prompt=responder_prompt,
+                        output_extract=responder_model_output_template,
+                    )
 
                     # --------------------------------------- Save the dialog ---------------------------------------
                     dialog.append(
@@ -209,7 +232,9 @@ def main(self):
                     turn += 1
                     pbar.update(1)
 
-                with jsonlines.open(records_dir.joinpath(f"{self.cfg.seed}.jsonl"), mode="a") as writer:
+                with jsonlines.open(
+                    records_dir.joinpath(f"{self.cfg.seed}.jsonl"), mode="a"
+                ) as writer:
                     writer.write(
                         {
                             "persona": persona,

roleplay/common/model.py

Lines changed: 20 additions & 5 deletions
@@ -50,8 +50,15 @@ def extract_prompt(self, prompt: str) -> str:
 
     def stop_dialog(self, prompt):
         translator = str.maketrans("", "", string.punctuation)
-        prompt_first_token = re.split(r"\s+|\n", prompt.strip())[0].strip().translate(translator).strip()
-        prompt_last_token = re.split(r"\s+|\n", prompt.strip())[-1].strip().translate(translator).strip()
+        prompt_first_token = (
+            re.split(r"\s+|\n", prompt.strip())[0].strip().translate(translator).strip()
+        )
+        prompt_last_token = (
+            re.split(r"\s+|\n", prompt.strip())[-1]
+            .strip()
+            .translate(translator)
+            .strip()
+        )
         if (
             self.spec_tokens.conv_stop_token == prompt_first_token
             or self.spec_tokens.conv_stop_token == prompt_last_token
@@ -74,10 +81,16 @@ def is_non_coherent(self, text):
             if n_grams and len(n_grams) >= max(self.cfg.non_coherent_r, n):
                 if n_grams[-1] == n_gram or n_grams[-n] == n_gram:
                     last_rs = n_grams[-self.cfg.non_coherent_r :]
-                    if len(last_rs) == self.cfg.non_coherent_r and len(set(last_rs)) == 1:
+                    if (
+                        len(last_rs) == self.cfg.non_coherent_r
+                        and len(set(last_rs)) == 1
+                    ):
                         return True
                     last_rs = n_grams[-n::-n][: self.cfg.non_coherent_r]
-                    if len(last_rs) == self.cfg.non_coherent_r and len(set(last_rs)) == 1:
+                    if (
+                        len(last_rs) == self.cfg.non_coherent_r
+                        and len(set(last_rs)) == 1
+                    ):
                         return True
             n_grams.append(n_gram)
         return False
@@ -102,5 +115,7 @@ def collate_tokenize(data, tokenizer, input_key):
         else:
             input_text = element[input_key]
         input_batch.append(input_text)
-    tokenized = tokenizer(input_batch, padding="longest", truncation=True, return_tensors="pt").to(Device.get_device())
+    tokenized = tokenizer(
+        input_batch, padding="longest", truncation=True, return_tensors="pt"
+    ).to(Device.get_device())
     return tokenized
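
The is_non_coherent hunks above flag degenerate generations by testing whether the last non_coherent_r n-grams are identical. A self-contained sketch of that check; the whitespace tokenization and parameter defaults are assumptions, and the real method also scans strided n-grams:

# Hedged sketch of the repeated-n-gram test reformatted above.
def looks_non_coherent(text, n=3, r=4):
    tokens = text.split()
    # All overlapping n-grams of the text, oldest first.
    n_grams = [tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)]
    last_rs = n_grams[-r:]
    # Degenerate loop: the last r n-grams are all the same.
    return len(last_rs) == r and len(set(last_rs)) == 1

print(looks_non_coherent("stop stop stop stop stop stop"))   # True: pure repetition
print(looks_non_coherent("the quick brown fox jumps over"))  # False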

roleplay/common/persona.py

Lines changed: 3 additions & 1 deletion
@@ -13,7 +13,9 @@ def get_personas(cfg) -> List[Tuple[str, Dict[str, str]]]:
         features = person["person"]
 
         for feature_name in features.keys():
-            persona = persona.replace(f"<{feature_name.upper()}>", features[feature_name])
+            persona = persona.replace(
+                f"<{feature_name.upper()}>", features[feature_name]
+            )
 
         persona_hash = hashlib.md5(str(features).encode()).hexdigest()
         personas.append((persona, persona_hash))
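
The loop above fills <FEATURE> placeholders in a persona template and fingerprints the feature set with an MD5 hash. A small worked example with an invented template and feature dict (the real ones come from the config):

import hashlib

persona = "A <AGE>-year-old <PROFESSION> living in <CITY>."
features = {"age": "34", "profession": "nurse", "city": "Yerevan"}

for feature_name in features.keys():
    persona = persona.replace(f"<{feature_name.upper()}>", features[feature_name])

persona_hash = hashlib.md5(str(features).encode()).hexdigest()
print(persona)       # A 34-year-old nurse living in Yerevan.
print(persona_hash)  # stable fingerprint for this feature combination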

roleplay/datasets/hf_datasets.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 from typing import Any, Dict, List
 
-from urartu.common.dataset import Dataset
 from datasets import Dataset as HFDataset
+from urartu.common.dataset import Dataset
 
 
 class HFDatasets(Dataset):

roleplay/models/causal_lm_model.py

Lines changed: 32 additions & 12 deletions
@@ -3,9 +3,9 @@
 
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
-
 from urartu.common.device import Device
 from urartu.utils.dtype import eval_dtype
+
 from roleplay.common.model import Model
 
 
@@ -40,7 +40,9 @@ def get_prompt(self, turn, response_msg=None, persona=None, instructions=None):
 
             if turn == 0:
                 return (
-                    self.conv_template.first_turn_input.replace(self.spec_tokens.persona_placeholder, persona)
+                    self.conv_template.first_turn_input.replace(
+                        self.spec_tokens.persona_placeholder, persona
+                    )
                     .replace(
                         self.spec_tokens.objective_placeholder,
                         f"{instructions[0]}",
@@ -55,17 +57,23 @@ def get_prompt(self, turn, response_msg=None, persona=None, instructions=None):
                 assert response_msg is not None, "response_msg cannot be None"
 
                 if len(instructions) > 1 and turn < len(instructions):
-                    response_forwarding = self.conv_template.mid_response_forwarding.replace(
-                        self.spec_tokens.next_prompt, instructions[turn]
+                    response_forwarding = (
+                        self.conv_template.mid_response_forwarding.replace(
+                            self.spec_tokens.next_prompt, instructions[turn]
+                        )
                     )
                 else:
-                    response_forwarding = self.conv_template.response_forwarding.replace(
-                        self.spec_tokens.next_prompt, ""
+                    response_forwarding = (
+                        self.conv_template.response_forwarding.replace(
+                            self.spec_tokens.next_prompt, ""
+                        )
                     )
 
                 return self.conv_template.n_th_turn_input.replace(
                     self.spec_tokens.user_msg,
-                    response_forwarding.replace(self.spec_tokens.response_placeholder, response_msg).replace(
+                    response_forwarding.replace(
+                        self.spec_tokens.response_placeholder, response_msg
+                    ).replace(
                         self.spec_tokens.conv_stop_placeholder,
                         self.spec_tokens.conv_stop_token,
                     ),
@@ -79,7 +87,9 @@ def get_prompt(self, turn, response_msg=None, persona=None, instructions=None):
                     response_msg,
                 )
             else:
-                return self.conv_template.n_th_turn_input.replace(self.spec_tokens.user_msg, response_msg)
+                return self.conv_template.n_th_turn_input.replace(
+                    self.spec_tokens.user_msg, response_msg
+                )
         else:
             raise NotImplementedError(f"unknown role: {self.role}")
 
@@ -88,17 +98,25 @@ def generate(self, prompt: str, generate_cfg):
         model_prompt = prompt
         if self.history:
             model_prompt = f'{"".join(self.history)}{prompt}'
-        prompt_tokenized = self.tokenizer.encode(model_prompt, return_tensors="pt").to(self.model.device)
+        prompt_tokenized = self.tokenizer.encode(model_prompt, return_tensors="pt").to(
+            self.model.device
+        )
 
         with torch.no_grad():
            output_tokenized = self.model.generate(prompt_tokenized, **generate_cfg)
 
         output = self.tokenizer.decode(output_tokenized[0], skip_special_tokens=True)
 
-        output_o = output.replace(str(self.tokenizer.bos_token), "").replace(str(self.tokenizer.eos_token), "").strip()
+        output_o = (
+            output.replace(str(self.tokenizer.bos_token), "")
+            .replace(str(self.tokenizer.eos_token), "")
+            .strip()
+        )
 
         model_prompt_o = (
-            model_prompt.replace(str(self.tokenizer.bos_token), "").replace(str(self.tokenizer.eos_token), "").strip()
+            model_prompt.replace(str(self.tokenizer.bos_token), "")
+            .replace(str(self.tokenizer.eos_token), "")
+            .strip()
        )
 
         turn_response = output_o.replace(model_prompt_o, "", 1)
@@ -110,7 +128,9 @@ def generate(self, prompt: str, generate_cfg):
             self.aim_run["num_self_replies"] += 1
 
         turn_response = turn_response.lstrip()
-        model_output_template = self.conv_template.model_output.replace(self.spec_tokens.model_answer, turn_response)
+        model_output_template = self.conv_template.model_output.replace(
+            self.spec_tokens.model_answer, turn_response
+        )
 
         del output_tokenized
 
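
The generate hunks above decode the full output sequence, strip bos/eos tokens from both the decoded output and the prompt, then remove the prompt prefix so only the new turn remains. A toy illustration of that extraction step with invented strings:

# Hedged illustration of the turn extraction in generate() above.
output_o = "USER: How are you? ASSISTANT: Doing well, thanks."
model_prompt_o = "USER: How are you?"

# Remove the normalized prompt prefix once; what remains is the new turn.
turn_response = output_o.replace(model_prompt_o, "", 1).lstrip()
print(turn_response)  # ASSISTANT: Doing well, thanks.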

roleplay/models/openai_model.py

Lines changed: 25 additions & 9 deletions
@@ -42,17 +42,23 @@ def get_prompt(self, turn, response_msg, persona=None, instructions=None):
                 return prompt
             else:
                 if len(instructions) > 1 and turn < len(instructions):
-                    response_forwarding = self.conv_template.mid_response_forwarding.replace(
-                        self.spec_tokens.next_prompt, instructions[turn]
+                    response_forwarding = (
+                        self.conv_template.mid_response_forwarding.replace(
+                            self.spec_tokens.next_prompt, instructions[turn]
+                        )
                     )
                 else:
-                    response_forwarding = self.conv_template.response_forwarding.replace(
-                        self.spec_tokens.next_prompt, ""
+                    response_forwarding = (
+                        self.conv_template.response_forwarding.replace(
+                            self.spec_tokens.next_prompt, ""
+                        )
                     )
 
                 return self.conv_template.n_th_turn_input.replace(
                     self.spec_tokens.user_msg,
-                    response_forwarding.replace(self.spec_tokens.response_placeholder, response_msg),
+                    response_forwarding.replace(
+                        self.spec_tokens.response_placeholder, response_msg
+                    ),
                 )
         elif self.role == "model_responder":
             if turn == 0:
@@ -61,7 +67,9 @@ def get_prompt(self, turn, response_msg, persona=None, instructions=None):
                     response_msg,
                 )
             else:
-                return self.conv_template.n_th_turn_input.replace(self.spec_tokens.user_msg, response_msg)
+                return self.conv_template.n_th_turn_input.replace(
+                    self.spec_tokens.user_msg, response_msg
+                )
         else:
             raise NotImplementedError(f"unknown role: {self.role}")
 
@@ -74,13 +82,21 @@ def generate(self, prompt: Union[str, Tuple[str, str]], generate_cfg):
         else:
             self.history.append(HumanMessage(content=prompt))
 
-        num_history_words = sum([self._get_num_tokens(item.content) for item in self.history])
+        num_history_words = sum(
+            [self._get_num_tokens(item.content) for item in self.history]
+        )
         if generate_cfg.max_new_tokens + num_history_words > self.cfg.context_length:
-            delta = generate_cfg.max_new_tokens + num_history_words - self.cfg.context_length
+            delta = (
+                generate_cfg.max_new_tokens
+                + num_history_words
+                - self.cfg.context_length
+            )
             i = 1
             while delta > 0:
                 len_human_utterance = self._get_num_tokens(self.history[i].content)
-                len_aiassistant_utterance = self._get_num_tokens(self.history[i + 1].content)
+                len_aiassistant_utterance = self._get_num_tokens(
+                    self.history[i + 1].content
+                )
                 delta -= len_human_utterance + len_aiassistant_utterance
                 i += 2
             del self.history[1:i]
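
The last hunk reflows the context-window bookkeeping: when max_new_tokens plus the token count of the history would exceed context_length, the oldest human/assistant pairs are dropped, starting at index 1 so the first message (presumably the system/persona prompt) survives. A standalone sketch under those assumptions, with a whitespace count standing in for _get_num_tokens:

# Hedged sketch of the trimming loop above; index 0 is assumed to be the
# system/persona message and is never dropped.
def trim_history(history, max_new_tokens, context_length):
    def num_tokens(text):
        return len(text.split())  # stand-in for _get_num_tokens

    delta = max_new_tokens + sum(num_tokens(m) for m in history) - context_length
    i = 1
    while delta > 0 and i + 1 < len(history):
        # Each step frees one human/assistant utterance pair from the budget.
        delta -= num_tokens(history[i]) + num_tokens(history[i + 1])
        i += 2
    del history[1:i]  # drop the oldest pairs in one slice

history = ["system", "hi there friend", "hello hello hello", "q two", "a two"]
trim_history(history, max_new_tokens=6, context_length=14)
print(history)  # ['system', 'q two', 'a two']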
