
Commit f6db63d

Add ModelGenerator for LLama2, driver proxy and Vicuna (databrickslabs#9)
1 parent 96ec0c4 commit f6db63d

5 files changed: +646 additions, -317 deletions


databricks/labs/doc_qa/evaluators/templated_evaluator.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
 
 
 logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
+logger = logging.getLogger(__name__.split(".")[0])
 
 
 class ParameterType(Enum):
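
Note on the logger change: splitting __name__ on "." keeps only the top-level package segment, so every module in the package logs under one shared logger name instead of its full dotted path. A minimal sketch of the effect (the dotted module path below is illustrative):

import logging

logging.basicConfig(level=logging.INFO)

# Previous pattern: one logger per module, named after the full dotted path.
per_module = logging.getLogger("databricks.labs.doc_qa.evaluators.templated_evaluator")

# New pattern: every module resolves to the shared top-level logger ("databricks" here),
# so a single level/handler configuration applies across the whole package.
shared = logging.getLogger("databricks.labs.doc_qa.evaluators.templated_evaluator".split(".")[0])
shared.info("logger name: %s", shared.name)  # logs under the name "databricks"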

databricks/labs/doc_qa/llm_providers/openai_provider.py

Lines changed: 0 additions & 2 deletions
@@ -10,7 +10,6 @@
 
 
 openai_token = os.getenv('OPENAI_API_KEY')
-openai_org = os.getenv('OPENAI_ORGANIZATION')
 
 class StatusCode429Error(Exception):
     pass
@@ -24,7 +23,6 @@ def request_openai(messages, functions=[], temperature=0.0, model="gpt-4"):
     headers = {
         "Content-Type": "application/json",
         "Authorization": f"Bearer {openai_token}",
-        "OpenAI-Organization": openai_org,
     }
     data = {
         "model": model,

databricks/labs/doc_qa/model_generators/model_generator.py

Lines changed: 212 additions & 10 deletions
@@ -5,8 +5,8 @@
 import concurrent.futures
 
 logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
+# Instead of using full name, only use the module name
+logger = logging.getLogger(__name__.split(".")[0])
 
 class RowGenerateResult:
     """
@@ -78,6 +78,7 @@ def __init__(
         self._prompt_formatter = prompt_formatter
         self._batch_size = batch_size
         self._concurrency = concurrency
+        self.input_variables = prompt_formatter.variables
 
     def _generate(
         self, prompts: list, temperature: float, max_tokens=256, system_prompt=None
@@ -95,7 +96,7 @@ def run_tasks(
         Returns:
             EvalResult: the evaluation result
         """
-        prompt_batches = []
+        task_batches = []
         # First, traverse the input dataframe using batch size
         for i in range(0, len(input_df), self._batch_size):
             # Get the current batch
@@ -107,9 +108,13 @@ def run_tasks(
                 # Format the input dataframe into prompts
                 prompt = self._prompt_formatter.format(**row)
                 prompts.append(prompt)
-            prompt_batches.append(prompts)
+            task = {
+                "prompts": prompts,
+                "df": batch_df,
+            }
+            task_batches.append(task)
         logger.info(
-            f"Generated total number of batches for prompts: {len(prompt_batches)}"
+            f"Generated total number of batches for prompts: {len(task_batches)}"
         )
 
         # Call the _generate in parallel using multiple threads, each call with a batch of prompts
@@ -118,15 +123,28 @@ def run_tasks(
         ) as executor:
             future_to_batch = {
                 executor.submit(
-                    self._generate, prompts, temperature, max_tokens, system_prompt
-                ): prompts
-                for prompts in prompt_batches
+                    self._generate,
+                    task["prompts"],
+                    temperature,
+                    max_tokens,
+                    system_prompt,
+                ): task
+                for task in task_batches
             }
             batch_generate_results = []
             for future in concurrent.futures.as_completed(future_to_batch):
-                prompts = future_to_batch[future]
+                task = future_to_batch[future]
                 try:
                     result = future.result()
+                    batch_df = task["df"]
+                    # Add the columns from batch_df where the column name is in the input_variables, add as attribute and value to the RowEvalResult
+                    for index, row in enumerate(result.rows):
+                        for input_variable in self.input_variables:
+                            setattr(
+                                row,
+                                input_variable,
+                                batch_df[input_variable].iloc[index],
+                            )
                     batch_generate_results.append(result)
                 except Exception as exc:
                     logger.error(f"Exception occurred when running the task: {exc}")
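
The added block above copies each input column onto its matching result row, so downstream consumers can read the original inputs next to the generated answer. A small standalone sketch of the same setattr pattern, using a stand-in class since RowGenerateResult's full constructor is not shown here:

import pandas as pd

class RowResultStub:
    # Stand-in for RowGenerateResult; any object that accepts dynamic attributes works.
    def __init__(self, answer):
        self.answer = answer

batch_df = pd.DataFrame({"question": ["What is DBFS?", "What is MLflow?"]})
input_variables = ["question"]  # normally taken from prompt_formatter.variables
rows = [RowResultStub("answer 1"), RowResultStub("answer 2")]

# Same pattern as run_tasks: attach each input column value to the matching result row.
for index, row in enumerate(rows):
    for input_variable in input_variables:
        setattr(row, input_variable, batch_df[input_variable].iloc[index])

print(rows[0].question)  # -> What is DBFS?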
@@ -268,7 +286,7 @@ def __init__(
         self._tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
 
     def _format_prompt(self, message: str, system_prompt_opt: str) -> str:
-        if system_prompt_opt is None:
+        if system_prompt_opt is not None:
             texts = [f"[INST] <<SYS>>\n{system_prompt_opt}\n<</SYS>>\n\n"]
             texts.append(f"{message.strip()} [/INST]")
             return "".join(texts)
@@ -323,3 +341,187 @@ def _generate(
             is_successful=True,
             error_msg=None,
         )
+
+
+class VicunaModelGenerator(BaseModelGenerator):
+    def __init__(
+        self,
+        prompt_formatter: PromptTemplate,
+        model_name_or_path: str,
+        batch_size: int = 1,
+        concurrency: int = 1,
+    ) -> None:
+        """
+        Args:
+            prompt_formatter (PromptTemplate): the prompt format to format the input dataframe into prompts row by row according to the column names
+            model_name (str): the model name
+            batch_size (int, optional): Batch size that will be used to run tasks. Defaults to 1, which means it's sequential.
+
+        Recommendations:
+            - for A100 80GB, use batch_size 1 for vicuna-33b
+            - for A100 80GB x 2, use batch_size 64 for vicuna-33b
+        """
+        super().__init__(prompt_formatter, batch_size, concurrency)
+        # require the concurrency to be 1 to avoid race condition during inference
+        if concurrency != 1:
+            raise ValueError(
+                "VicunaModelGenerator currently only supports concurrency 1"
+            )
+        self._model_name_or_path = model_name_or_path
+        import torch
+        from transformers import (
+            AutoModelForCausalLM,
+            AutoTokenizer,
+            TextIteratorStreamer,
+        )
+
+        if torch.cuda.is_available():
+            self._model = AutoModelForCausalLM.from_pretrained(
+                model_name_or_path, torch_dtype=torch.float16, device_map="auto"
+            )
+        else:
+            raise ValueError("VicunaModelGenerator currently only supports GPU")
+        self._tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+
+    def _format_prompt(self, message: str, system_prompt_opt: str) -> str:
+        if system_prompt_opt is not None:
+            return f"""{system_prompt_opt}
+
+USER: {message}
+ASSISTANT:
+"""
+        else:
+            return f"""A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
+
+USER: {message}
+ASSISTANT:
+"""
+
+    def _generate(
+        self, prompts: list, temperature: float, max_tokens=256, system_prompt=None
+    ) -> BatchGenerateResult:
+        from transformers import pipeline
+
+        all_formatted_prompts = [
+            self._format_prompt(message=message, system_prompt_opt=system_prompt)
+            for message in prompts
+        ]
+
+        top_p = 0.95
+        repetition_penalty = 1.15
+        pipe = pipeline(
+            "text-generation",
+            model=self._model,
+            tokenizer=self._tokenizer,
+            max_new_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            return_full_text=False,
+        )
+        responses = pipe(all_formatted_prompts)
+        rows = []
+        for index, response in enumerate(responses):
+            response_content = response[0]["generated_text"]
+            row_generate_result = RowGenerateResult(
+                is_successful=True,
+                error_msg=None,
+                answer=response_content,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                model_name=self._model_name_or_path,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                prompts=all_formatted_prompts[index],
+            )
+            rows.append(row_generate_result)
+
+        return BatchGenerateResult(
+            num_rows=len(rows),
+            num_successful_rows=len(rows),
+            rows=rows,
+            is_successful=True,
+            error_msg=None,
+        )
+
+
+class DriverProxyModelGenerator(BaseModelGenerator):
+    def __init__(
+        self,
+        url: str,
+        pat_token: str,
+        prompt_formatter: PromptTemplate,
+        batch_size: int = 32,
+        concurrency: int = 1,
+    ) -> None:
+        """
+        Args:
+            prompt_formatter (PromptTemplate): the prompt format to format the input dataframe into prompts row by row according to the column names
+            model_name (str): the model name
+            batch_size (int, optional): Batch size that will be used to run tasks. Defaults to 1, which means it's sequential.
+
+        Recommendations:
+            - for A100 80GB, use batch_size 16 for llama-2-13b-chat
+        """
+        super().__init__(prompt_formatter, batch_size, concurrency)
+        self._url = url
+        self._pat_token = pat_token
+
+    def _format_prompt(self, message: str, system_prompt_opt: str) -> str:
+        if system_prompt_opt is not None:
+            texts = [f"[INST] <<SYS>>\n{system_prompt_opt}\n<</SYS>>\n\n"]
+            texts.append(f"{message.strip()} [/INST]")
+            return "".join(texts)
+        else:
+            texts = [f"[INST] \n\n"]
+            texts.append(f"{message.strip()} [/INST]")
+            return "".join(texts)
+
+    def _generate(
+        self, prompts: list, temperature: float, max_tokens=256, system_prompt=None
+    ) -> BatchGenerateResult:
+        top_p = 0.95
+
+        all_formatted_prompts = [
+            self._format_prompt(message=message, system_prompt_opt=system_prompt)
+            for message in prompts
+        ]
+
+        import requests
+        import json
+
+        headers = {
+            "Authentication": f"Bearer {self._pat_token}",
+            "Content-Type": "application/json",
+        }
+
+        data = {
+            "prompts": all_formatted_prompts,
+            "temperature": temperature,
+            "max_tokens": max_tokens,
+        }
+
+        response = requests.post(self._url, headers=headers, data=json.dumps(data))
+
+        # Extract the "outputs" as a JSON array from the response
+        outputs = response.json()["outputs"]
+        rows = []
+        for index, response_content in enumerate(outputs):
+            row_generate_result = RowGenerateResult(
+                is_successful=True,
+                error_msg=None,
+                answer=response_content,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                top_p=top_p,
+                prompts=all_formatted_prompts[index],
+            )
+            rows.append(row_generate_result)
+
+        return BatchGenerateResult(
+            num_rows=len(rows),
+            num_successful_rows=len(rows),
+            rows=rows,
+            is_successful=True,
+            error_msg=None,
+        )
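
A minimal usage sketch for the new DriverProxyModelGenerator. The PromptTemplate import path, its single-string constructor, the run_tasks keyword arguments, and the proxy URL/token values are assumptions inferred from this diff, not confirmed API:

import pandas as pd
# Import path for PromptTemplate is assumed; adjust to where it is defined in this repo.
from databricks.labs.doc_qa.llm_utils import PromptTemplate
from databricks.labs.doc_qa.model_generators.model_generator import DriverProxyModelGenerator

# Template variables ("context", "question") become the columns expected in the input dataframe.
prompt_formatter = PromptTemplate(
    "Answer the question using only the context.\nContext: {context}\nQuestion: {question}"
)
input_df = pd.DataFrame(
    {
        "context": ["DBFS is the Databricks File System."],
        "question": ["What is DBFS?"],
    }
)

generator = DriverProxyModelGenerator(
    url="https://<workspace-host>/driver-proxy-api/o/0/<cluster-id>/<port>/",  # placeholder proxy URL
    pat_token="<databricks-pat>",  # placeholder personal access token
    prompt_formatter=prompt_formatter,
    batch_size=16,
)
# Keyword names below mirror the parameters used inside _generate in the diff.
result = generator.run_tasks(input_df=input_df, temperature=0.0, max_tokens=256)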
