Hi, I was wondering if there were any efforts toward great.py natively supporting Distributed Data Parallel (DDP)? Currently I am working around it by editing my own trainer file and saving the model via torch.save.
Below is how I invoke it:
torchrun --nproc_per_node=8 ddptest.py
import os
import pandas as pd
from be_great import GReaT
import torch.distributed as dist
import torch
from collections import OrderedDict


def main():
    # Set the CUDA device for this process
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    dataFile = "/edit/for/your/own/repo.csv"
    data = pd.read_csv(dataFile)

    great = GReaT("gpt2-xl",
                  batch_size=8,
                  epochs=50,
                  fp16=True)

    # Move the model to the appropriate GPU
    great.model.to(local_rank)

    # Wrap the model for distributed training
    great.model = torch.nn.parallel.DistributedDataParallel(
        great.model, device_ids=[local_rank], output_device=local_rank
    )

    trainer = great.fit(data, data.columns.to_list())

    # Save the model only from the rank 0 process
    if dist.get_rank() == 0:
        # Create a new state dict with corrected key names
        state_dict = great.model.state_dict()
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # remove the `module.` prefix added by DDP
            new_state_dict[name] = v
        # Save the model with the modified state dict
        torch.save(new_state_dict, "/edit/for/your/own/model.pt")


if __name__ == "__main__":
    # Initialize the distributed process group
    dist.init_process_group(backend="nccl")
    main()
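For anyone who wants to reuse the checkpoint, here is a minimal loading sketch (not part of the original script). It assumes the placeholder model.pt path from above and simply loads the weights back into a fresh GReaT instance via PyTorch's standard load_state_dict:

import torch
from be_great import GReaT

# Re-create the wrapper with the same base model used for training
great = GReaT("gpt2-xl", batch_size=8, epochs=50, fp16=True)

# Load the DDP-stripped state dict saved by the rank 0 process
# (placeholder path, same as in the training script)
state_dict = torch.load("/edit/for/your/own/model.pt", map_location="cpu")
great.model.load_state_dict(state_dict)

Note that torch.save above stores only the network weights; any table metadata GReaT gathers during fit (e.g. the column names) is not part of the state dict and would need to be restored separately before sampling. Also, instead of stripping the module. prefix by hand, you can save great.model.module.state_dict() directly, since DistributedDataParallel exposes the wrapped model via its .module attribute.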
Again, thank you so much for this awesome framework.
So far we do not have plans to add native Distributed Data Parallel support. However, it would be great to have, so any contributions are very welcome.
Also, thank you for providing a simple workaround script; it will definitely be useful for others!