How to debug torchtune #964
Hey @leo-young, thanks for the question. The easiest way to debug that I've found is to use the Python debugger.
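For example, here's a minimal sketch (nothing in it is torchtune-specific; the function and script name are just placeholders): drop a `breakpoint()` call into the code path you care about, or run the whole script under `pdb`:

```python
# Minimal sketch of using the built-in Python debugger.
# train_step is a stand-in for whatever code you actually want to inspect.
def train_step(batch):
    loss = sum(batch) / len(batch)  # placeholder computation
    breakpoint()  # opens a pdb prompt here: print `batch`/`loss`, step with `n`, continue with `c`
    return loss


if __name__ == "__main__":
    train_step([1.0, 2.0, 3.0])
    # Or run an existing script under pdb without editing it:
    #   python -m pdb -c continue my_script.py   (stops at breakpoints / unhandled exceptions)
```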
For RLHF, we do have a recipe for DPO. See https://github.com/pytorch/torchtune/blob/main/recipes/lora_dpo_single_device.py. PPO is actively being discussed on our Discord (cc @SalmanMohammadi and @kartikayk). Here is the issue discussing it: #812
Hi @leo-young,

```python
import argparse
from pathlib import Path  # unused in this wrapper; kept from the original paste

from torch.distributed.run import get_args_parser as get_torchrun_args_parser, run  # also unused here

from torchtune._cli.cp import Copy
from torchtune._cli.download import Download
from torchtune._cli.ls import List
from torchtune._cli.run import Run
from torchtune._cli.validate import Validate

MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
RECIPE = "full_finetune_single_device"
MODEL_CONFIG = "llama3/8B_full_single_device"
ACTIVATE_WANDB = True
MAX_STEPS_PER_EPOCH = 20
EPOCHS = 2
HF_CHECKPOINTER = True
# HF_CHECKPOINTER = False

# Build the same argument list you would normally pass on the command line,
# e.g. `tune run full_finetune_single_device --config llama3/8B_full_single_device ...`
run_command = [
    "run",
    RECIPE,
    "--config",
    MODEL_CONFIG,
]
run_command.append(f"model_name={MODEL_NAME}")
if ACTIVATE_WANDB:
    run_command.append("metric_logger._component_=torchtune.utils.metric_logging.WandBLogger")
    run_command.append("metric_logger.project=torchtune")
if MAX_STEPS_PER_EPOCH:
    run_command.append(f"max_steps_per_epoch={MAX_STEPS_PER_EPOCH}")
if EPOCHS:
    run_command.append(f"epochs={EPOCHS}")
if HF_CHECKPOINTER:
    run_command.append("checkpointer._component_=torchtune.utils.FullModelHFCheckpointer")


class TuneCLIParser:
    """Holds all information related to running the CLI"""

    def __init__(self):
        # Initialize the top-level parser
        self._parser = argparse.ArgumentParser(
            prog="tune",
            description="Welcome to the TorchTune CLI!",
            add_help=True,
        )
        # Default command is to print help
        self._parser.set_defaults(func=lambda args: self._parser.print_help())
        # Add subcommands
        subparsers = self._parser.add_subparsers(title="subcommands")
        Download.create(subparsers)
        List.create(subparsers)
        Copy.create(subparsers)
        Run.create(subparsers)
        Validate.create(subparsers)

    # def parse_args(self) -> argparse.Namespace:
    #     """Parse CLI arguments"""
    #     return self._parser.parse_args()

    def parse_args(self, args=None) -> argparse.Namespace:
        """Parse CLI arguments"""
        return self._parser.parse_args(args)

    def run(self, args: argparse.Namespace) -> None:
        """Execute CLI"""
        args.func(args)


def main():
    parser = TuneCLIParser()
    # Parse the hard-coded run_command instead of sys.argv, so the whole
    # `tune run ...` invocation happens inside this one Python process.
    args = parser.parse_args(run_command)
    parser.run(args)


if __name__ == "__main__":
    main()
```
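As a rough usage sketch (assuming the snippet above is saved as an importable file, e.g. `debug_tune.py`; that name is my own, not from torchtune), you can then start the recipe under `pdb` and step into it:

```python
# Hypothetical usage: import the wrapper and launch the recipe under pdb.
# Assumes the script above was saved as debug_tune.py (name is illustrative).
import pdb

import debug_tune

parser = debug_tune.TuneCLIParser()
args = parser.parse_args(debug_tune.run_command)
pdb.runcall(parser.run, args)  # starts paused in pdb; use `s`/`n` to step into the recipe
```

Running the same file directly from an IDE debugger should work the same way, since everything stays in one process.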
Could be that I'm using VS Code's visual debugger!
Closing this since I think it's resolved. But feel free to re-open if you run into other issues here!
Hi, thanks for the framework,