Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Phi3poc #2301

Open
wants to merge 26 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
f0c2b00
poc
JessicaXYWang Sep 12, 2024
603777a
poc
JessicaXYWang Oct 15, 2024
47ae241
Merge branch 'master' into phi3poc
JessicaXYWang Oct 15, 2024
23f8ca0
rename module
JessicaXYWang Oct 15, 2024
bb5b2b6
Merge branch 'phi3poc' of https://github.com/JessicaXYWang/SynapseML …
JessicaXYWang Oct 15, 2024
f235535
update dependency
JessicaXYWang Oct 17, 2024
f2ab308
Merge branch 'master' into phi3poc
JessicaXYWang Oct 17, 2024
3ee9168
add set device type
JessicaXYWang Oct 21, 2024
b30f168
add Downloader
JessicaXYWang Jan 2, 2025
d760733
remove import
JessicaXYWang Jan 2, 2025
6efa59c
Merge branch 'master' into phi3poc
JessicaXYWang Jan 2, 2025
c7397f3
update lm
JessicaXYWang Jan 10, 2025
e1105fd
Merge branch 'phi3poc' of https://github.com/JessicaXYWang/SynapseML …
JessicaXYWang Jan 10, 2025
e59a981
Merge branch 'master' into phi3poc
JessicaXYWang Jan 10, 2025
ff8ad7f
pyarrow version conflict
JessicaXYWang Jan 13, 2025
56e623d
Merge branch 'phi3poc' of https://github.com/JessicaXYWang/SynapseML …
JessicaXYWang Jan 13, 2025
efa6aa0
update transformers version
JessicaXYWang Jan 14, 2025
2f5338c
add dependency
JessicaXYWang Jan 14, 2025
ff89511
update transformers version
JessicaXYWang Jan 14, 2025
b3dc5da
add phi3 test
JessicaXYWang Jan 16, 2025
c0cd463
test missing transformers library
JessicaXYWang Jan 16, 2025
e3e331c
update databricks test
JessicaXYWang Jan 16, 2025
382a20e
update databricks test
JessicaXYWang Jan 16, 2025
0a0f80c
update db library
JessicaXYWang Jan 17, 2025
eac0293
update doc
JessicaXYWang Jan 23, 2025
7a3e315
format
JessicaXYWang Jan 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
373 changes: 373 additions & 0 deletions core/src/main/python/synapse/ml/llm/HuggingFaceCausallmTransform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,373 @@
from pyspark.ml import Transformer
from pyspark.ml.param.shared import (
HasInputCol,
HasOutputCol,
Param,
Params,
TypeConverters,
)
from pyspark.sql import Row
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType, StructType, StructField
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from transformers import AutoTokenizer, AutoModelForCausalLM
from pyspark import keyword_only
import re
import os
import subprocess
import shutil
import sys

class Peekable:
def __init__(self, iterable):
self._iterator = iter(iterable)
self._cache = []

def __iter__(self):
return self

def __next__(self):
if self._cache:
return self._cache.pop(0)
else:
return next(self._iterator)

def peek(self, n=1):
"""Peek at the next n elements without consuming them."""
while len(self._cache) < n:
try:
self._cache.append(next(self._iterator))
except StopIteration:
break
if n == 1:
return self._cache[0] if self._cache else None
else:
return self._cache[:n]


class ModelParam:
def __init__(self, **kwargs):
self.param = {}
self.param.update(kwargs)

def get_param(self):
return self.param


class ModelConfig:
def __init__(self, **kwargs):
self.config = {}
self.config.update(kwargs)

def get_config(self):
return self.config

def set_config(self, **kwargs):
self.config.update(kwargs)


def camel_to_snake(text):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there might already be one in library to use

return re.sub(r"(?<!^)(?=[A-Z])", "_", text).lower()


class HuggingFaceCausalLM(
Transformer, HasInputCol, HasOutputCol, DefaultParamsReadable, DefaultParamsWritable
):

modelName = Param(
Params._dummy(),
"modelName",
"model name",
typeConverter=TypeConverters.toString,
)
inputCol = Param(
Params._dummy(),
"inputCol",
"input column",
typeConverter=TypeConverters.toString,
)
outputCol = Param(
Params._dummy(),
"outputCol",
"output column",
typeConverter=TypeConverters.toString,
)
modelParam = Param(Params._dummy(), "modelParam", "Model Parameters")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe explain difference between model params and other params (you can just link to other docs if easier)

modelConfig = Param(Params._dummy(), "modelConfig", "Model configuration")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe explain difference between model config and other params (you can just link to other docs if easier)

useFabricLakehouse = Param(
Params._dummy(),
"useFabricLakehouse",
"Use FabricLakehouse",
Copy link
Collaborator

@mhamilton723 mhamilton723 Jan 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is for a local cache then you might be able to make the verbage generic like useLocalCache

typeConverter=TypeConverters.toBoolean,
)
lakehousePath = Param(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might be able to get rid of earlier param just check if this is None

Params._dummy(),
"lakehousePath",
"Fabric Lakehouse Path for Model",
typeConverter=TypeConverters.toString,
)
deviceMap = Param(
Params._dummy(),
"deviceMap",
"device map",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might need to explain a bit more about this param and what it takes

typeConverter=TypeConverters.toString,
)
torchDtype = Param(
Params._dummy(),
"torchDtype",
"torch dtype",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

likewise here

typeConverter=TypeConverters.toString,
)

@keyword_only
def __init__(
self,
modelName=None,
inputCol=None,
outputCol=None,
useFabricLakehouse=False,
lakehousePath=None,
deviceMap=None,
torchDtype=None,
):
super(HuggingFaceCausalLM, self).__init__()
self._setDefault(
modelName=modelName,
inputCol=inputCol,
outputCol=outputCol,
modelParam=ModelParam(),
modelConfig=ModelConfig(),
useFabricLakehouse=useFabricLakehouse,
lakehousePath=None,
deviceMap=None,
torchDtype=None,
)
kwargs = self._input_kwargs
self.setParams(**kwargs)

def load_model(self):
"""
Loads model and tokenizer either from Fabric Lakehouse or the HuggingFace Hub,
depending on the 'useFabricLakehouse' param.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you name it more generically place that name here

"""
model_name = self.getModelName()
model_config = self.getModelConfig().get_config()
device_map = self.getDeviceMap()
torch_dtype = self.getTorchDtype()

if device_map:
model_config["device_map"] = device_map
if torch_dtype:
model_config["torch_dtype"] = torch_dtype

if self.getUseFabricLakehouse():
local_path = (
self.getLakehousePath() or f"/lakehouse/default/Files/{model_name}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: hf_cache

)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

switch to just use cachePath and then in our docs well say this is a good place to store things

model = AutoModelForCausalLM.from_pretrained(
local_path, local_files_only=True, **model_config
)
tokenizer = AutoTokenizer.from_pretrained(local_path, local_files_only=True)
else:
model = AutoModelForCausalLM.from_pretrained(model_name, **model_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)

return model, tokenizer

@keyword_only
def setParams(self):
kwargs = self._input_kwargs
return self._set(**kwargs)

def setModelName(self, value):
return self._set(modelName=value)

def getModelName(self):
return self.getOrDefault(self.modelName)

def setInputCol(self, value):
return self._set(inputCol=value)

def getInputCol(self):
return self.getOrDefault(self.inputCol)

def setOutputCol(self, value):
return self._set(outputCol=value)

def getOutputCol(self):
return self.getOrDefault(self.outputCol)

def setModelParam(self, **kwargs):
param = ModelParam(**kwargs)
return self._set(modelParam=param)

def getModelParam(self):
return self.getOrDefault(self.modelParam)

def setModelConfig(self, **kwargs):
config = ModelConfig(**kwargs)
return self._set(modelConfig=config)

def getModelConfig(self):
return self.getOrDefault(self.modelConfig)

def setLakehousePath(self, value):
return self._set(lakehousePath=value)

def getLakehousePath(self):
return self.getOrDefault(self.lakehousePath)

def setUseFabricLakehouse(self, value: bool):
return self._set(useFabricLakehouse=value)

def getUseFabricLakehouse(self):
return self.getOrDefault(self.useFabricLakehouse)

def setDeviceMap(self, value):
return self._set(deviceMap=value)

def getDeviceMap(self):
return self.getOrDefault(self.deviceMap)

def setTorchDtype(self, value):
return self._set(torchDtype=value)

def getTorchDtype(self):
return self.getOrDefault(self.torchDtype)

def _predict_single_complete(self, prompt, model, tokenizer):
param = self.getModelParam().get_param()
inputs = tokenizer(prompt, return_tensors="pt").input_ids
outputs = model.generate(inputs, **param)
decoded_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
return decoded_output

def _predict_single_chat(self, prompt, model, tokenizer):
param = self.getModelParam().get_param()
chat = [{"role": "user", "content": prompt}]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the prompt is a list, then assume its of structure of "chat"

Copy link
Collaborator

@mhamilton723 mhamilton723 Jan 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
chat = [{"role": "user", "content": prompt}]
if isinstance(prompt, list):
chat = prompt
else:
chat = [{"role": "user", "content": prompt}]

formatted_chat = tokenizer.apply_chat_template(
chat, tokenize=False, add_generation_prompt=True
)
tokenized_chat = tokenizer(
formatted_chat, return_tensors="pt", add_special_tokens=False
)
inputs = {
key: tensor.to(model.device) for key, tensor in tokenized_chat.items()
}
merged_inputs = {**inputs, **param}
outputs = model.generate(**merged_inputs)
decoded_output = tokenizer.decode(
outputs[0][inputs["input_ids"].size(1) :], skip_special_tokens=True
)
return decoded_output

def _transform(self, dataset):
"""Transform method to apply the chat model."""

def _process_partition(iterator, task):
"""Process each partition of the data."""
peekable_iterator = Peekable(iterator)
try:
first_row = peekable_iterator.peek()
except StopIteration:
return None

model, tokenizer = self.load_model()

for row in peekable_iterator:
prompt = row[self.getInputCol()]
if task == "chat":
result = self._predict_single_chat(prompt, model, tokenizer)
elif task == "complete":
result = self._predict_single_complete(prompt, model, tokenizer)
row_dict = row.asDict()
row_dict[self.getOutputCol()] = result
yield Row(**row_dict)

input_schema = dataset.schema
output_schema = StructType(
input_schema.fields + [StructField(self.getOutputCol(), StringType(), True)]
)
result_rdd = dataset.rdd.mapPartitions(
lambda partition: _process_partition(partition, "chat")
)
result_df = result_rdd.toDF(output_schema)
return result_df

def complete(self, dataset):
input_schema = dataset.schema
output_schema = StructType(
input_schema.fields + [StructField(self.getOutputCol(), StringType(), True)]
)
result_rdd = dataset.rdd.mapPartitions(
lambda partition: self._process_partition(partition, "complete")
)
result_df = result_rdd.toDF(output_schema)
return result_df


class Downloader:
def __init__(
self,
base_cache_dir="./cache",
base_url="https://mmlspark.blob.core.windows.net/huggingface/",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets use

%sh
azcopy cp https://mmlspark.blob.core.windows.net/huggingface/blah /lakehouse/blah

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Youy can also put in the mardown cell a little explanation of this and how its just for a speedup otherwise it will download from the huggingface hub

):
self.base_cache_dir = base_cache_dir
self.base_url = base_url

def _ensure_directory_exists(self, directory_path):
if not os.path.exists(directory_path):
os.makedirs(directory_path)

def download_model_from_az(self, repo_id, local_path=None, overwrite="false"):
local_path = os.path.join(
local_path or self.base_cache_dir, repo_id.rsplit("/", 1)[0]
)

blob_url = f"{self.base_url}{repo_id}"

self._ensure_directory_exists(local_path)

command = [
"azcopy",
"copy",
blob_url,
local_path,
"--recursive=true",
f"--overwrite={overwrite}",
]

try:
with subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
) as proc:
for line in proc.stdout:
print(line, end="")
for line in proc.stderr:
print(line, end="", file=sys.stderr)
return_code = proc.wait()
if return_code != 0:
raise subprocess.CalledProcessError(return_code, command)
except subprocess.CalledProcessError as e:
raise IOError("Error during download ", e.stderr)

def copy_contents_to_lakehouse(
self, dst, src=None, lakehouse_path=None, dirs_exist_ok=False
):
if not any(os.scandir("/lakehouse/")):
raise FileNotFoundError("No lakehouse attached")

if lakehouse_path is None:
lakehouse_path = f"/lakehouse/default/Files/{dst}"
os.makedirs(lakehouse_path, exist_ok=True)

src = os.path.join(src or self.base_cache_dir, dst)

for item in os.listdir(src):
src_path = os.path.join(src, item)
dst_path = os.path.join(lakehouse_path, item)

if os.path.isdir(src_path):
shutil.copytree(src_path, dst_path, dirs_exist_ok=dirs_exist_ok)
else:
shutil.copy2(src_path, dst_path)
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ dependencies:
- onnxmltools==1.7.0
- matplotlib
- Pillow
- transformers==4.32.1
- transformers==4.45.2
- huggingface-hub>=0.8.1
- langchain==0.0.152
- openai==0.27.5
Expand Down
Loading