How to debug torchtune #964

Closed
leo-young opened this issue May 12, 2024 · 5 comments

Comments

@leo-young

Hi, thanks for the framework!

  1. I can't figure out how to debug torchtune. I want to set a breakpoint in a recipe, but how can I do that?
  2. Do you plan to add RLHF methods like DPO or PPO?

@RdoubleA
Contributor

Hey @leo-young, thanks for the question.

The easiest way I've found to debug is to use the Python debugger, pdb:

```python
import pdb

# Set a breakpoint in your code
pdb.set_trace()
```
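A minimal variant of the same approach, assuming Python 3.7+: the built-in `breakpoint()` calls `pdb.set_trace()` by default, and guarding it behind an environment variable lets you leave it in a recipe without stopping every run. `TUNE_DEBUG`, `maybe_breakpoint`, and `train_step` are hypothetical names used only for illustration; they are not part of torchtune.

```python
import os
from typing import Sequence


def maybe_breakpoint() -> None:
    """Drop into the debugger only when the hypothetical TUNE_DEBUG=1 env var is set."""
    if os.environ.get("TUNE_DEBUG") == "1":
        # Built-in since Python 3.7; by default this calls pdb.set_trace().
        breakpoint()


def train_step(batch: Sequence[float]) -> float:
    """Toy stand-in for a recipe's training step (hypothetical, not torchtune code)."""
    loss = sum(batch) / len(batch)
    maybe_breakpoint()  # inspect `batch` and `loss` here, then type `c` to continue
    return loss


if __name__ == "__main__":
    print(train_step([1.0, 2.0, 3.0]))
```

Running with `TUNE_DEBUG=1` set stops at the breakpoint; without it the function runs straight through. Independently of this, setting `PYTHONBREAKPOINT=0` disables every `breakpoint()` call in the process.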

For RLHF, we do have a recipe for DPO. See https://github.com/pytorch/torchtune/blob/main/recipes/lora_dpo_single_device.py. PPO is actively being discussed on our Discord (cc @SalmanMohammadi and @kartikayk), and this is the issue tracking it: #812

@optimass

Hi @leo-young,
To debug by launching the Python code directly instead of going through the CLI (which won't hit breakpoints), you can use the following script I wrote (it will need some changes for your setup):

```python
import argparse

# Subcommand classes from torchtune's internal CLI package
from torchtune._cli.cp import Copy
from torchtune._cli.download import Download
from torchtune._cli.ls import List
from torchtune._cli.run import Run
from torchtune._cli.validate import Validate


# --- Settings to edit for your own run ---
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
RECIPE = "full_finetune_single_device"
MODEL_CONFIG = "llama3/8B_full_single_device"

ACTIVATE_WANDB = True
MAX_STEPS_PER_EPOCH = 20
EPOCHS = 2

HF_CHECKPOINTER = True
# HF_CHECKPOINTER = False

# Equivalent of: tune run <RECIPE> --config <MODEL_CONFIG> key=value ...
run_command = [
    "run",
    RECIPE,
    "--config",
    MODEL_CONFIG,
]

run_command.append(f"model_name={MODEL_NAME}")

if ACTIVATE_WANDB:
    # Log metrics to Weights & Biases instead of the config's default logger
    run_command.append("metric_logger._component_=torchtune.utils.metric_logging.WandBLogger")
    run_command.append("metric_logger.project=torchtune")

# Keep the run short while debugging
if MAX_STEPS_PER_EPOCH:
    run_command.append(f"max_steps_per_epoch={MAX_STEPS_PER_EPOCH}")
if EPOCHS:
    run_command.append(f"epochs={EPOCHS}")

if HF_CHECKPOINTER:
    run_command.append("checkpointer._component_=torchtune.utils.FullModelHFCheckpointer")


class TuneCLIParser:
    """Holds all information related to running the CLI"""

    def __init__(self):
        # Initialize the top-level parser
        self._parser = argparse.ArgumentParser(
            prog="tune",
            description="Welcome to the TorchTune CLI!",
            add_help=True,
        )
        # Default command is to print help
        self._parser.set_defaults(func=lambda args: self._parser.print_help())

        # Add subcommands
        subparsers = self._parser.add_subparsers(title="subcommands")
        Download.create(subparsers)
        List.create(subparsers)
        Copy.create(subparsers)
        Run.create(subparsers)
        Validate.create(subparsers)

    def parse_args(self, args=None) -> argparse.Namespace:
        """Parse CLI arguments; args=None falls back to sys.argv"""
        return self._parser.parse_args(args)

    def run(self, args: argparse.Namespace) -> None:
        """Execute CLI"""
        args.func(args)


def main():
    parser = TuneCLIParser()
    # Parse the hand-built argument list instead of sys.argv, so this file
    # can be launched directly (e.g. from a debugger)
    args = parser.parse_args(run_command)
    parser.run(args)


if __name__ == "__main__":
    main()
```
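If you end up toggling many overrides, the list-building part of the script can be pulled into a small pure function, so the exact argument list is easy to print or unit-test before it reaches the parser. This is a sketch; `build_run_command` is a hypothetical helper, and it skips `None` values rather than all falsy ones (the script's `if MAX_STEPS_PER_EPOCH:` style checks would also skip `0`).

```python
from typing import Dict, List, Optional


def build_run_command(
    recipe: str,
    config: str,
    overrides: Optional[Dict[str, object]] = None,
) -> List[str]:
    """Build the argument list for `tune run` (hypothetical helper).

    Produces the subcommand, the recipe, `--config <config>`, then one
    `key=value` override per entry. Entries whose value is None are skipped,
    so a setting can be switched off without an extra `if`.
    """
    command = ["run", recipe, "--config", config]
    for key, value in (overrides or {}).items():
        if value is not None:
            command.append(f"{key}={value}")
    return command


if __name__ == "__main__":
    print(build_run_command(
        "full_finetune_single_device",
        "llama3/8B_full_single_device",
        {"max_steps_per_epoch": 20, "epochs": 2, "metric_logger.project": None},
    ))
```

Its return value can take the place of the hand-built `run_command` in `main()`, e.g. `parser.parse_args(build_run_command(RECIPE, MODEL_CONFIG, {...}))`.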

@ebsmothers
Contributor

@optimass I'm a bit surprised to hear you don't hit breakpoints when running via the CLI. Personally I have no issues doing this with the method described by @RdoubleA. Are you sure it's not because you're running a distributed recipe (there are other issues in that case that are unrelated to the CLI)?
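For distributed recipes launched through torchrun, every worker process runs the recipe, so an unconditional breakpoint pauses all ranks at once. A common workaround, sketched here as an assumption rather than anything torchtune provides, is to break on a single rank only, using the `RANK` environment variable that torchrun sets for each worker. `should_break_on_this_rank`, `debug_on_rank`, and `TUNE_DEBUG_RANK` are hypothetical names. Depending on how the launcher wires up stdin, interactive pdb in a worker subprocess may still not receive terminal input.

```python
import os
from typing import Optional


def should_break_on_this_rank(target_rank: Optional[int] = None) -> bool:
    """Return True if this process should stop at a breakpoint.

    The rank to debug comes from the hypothetical TUNE_DEBUG_RANK env var
    unless passed explicitly; if neither is given, never break. torchrun sets
    RANK for each worker; a plain single-process run has no RANK, so treat it as 0.
    """
    if target_rank is None:
        env_value = os.environ.get("TUNE_DEBUG_RANK")
        if env_value is None:
            return False
        target_rank = int(env_value)
    current_rank = int(os.environ.get("RANK", "0"))
    return current_rank == target_rank


def debug_on_rank(target_rank: Optional[int] = None) -> None:
    """Drop into pdb on the selected rank only; other ranks keep running."""
    if should_break_on_this_rank(target_rank):
        breakpoint()
```

The other ranks continue until their next collective operation, where they wait for the paused rank, so a very long pause can eventually run into the process group's timeout.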

@optimass

Could be that I'm using VS Code's visual debugger!
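If the goal is to keep launching through `tune run` while still using VS Code's visual debugger, one option (not discussed in this thread, so treat it as an assumption) is to have the recipe start a `debugpy` listener and then attach to it from VS Code with a Python "Remote Attach" configuration pointing at the same host and port. `maybe_wait_for_vscode` and the `TUNE_DEBUGPY` variable are hypothetical; `debugpy.listen` and `debugpy.wait_for_client` come from the `debugpy` package, which must be installed in the recipe's environment.

```python
import os


def maybe_wait_for_vscode(port: int = 5678) -> bool:
    """Optionally block until VS Code attaches via debugpy.

    Enabled only when the hypothetical TUNE_DEBUGPY=1 env var is set, so the
    recipe runs normally otherwise. Returns True once a debugger has attached.
    """
    if os.environ.get("TUNE_DEBUGPY") != "1":
        return False
    import debugpy  # imported lazily: only needed when debugging

    debugpy.listen(("localhost", port))
    print(f"Waiting for debugger to attach on port {port}...")
    debugpy.wait_for_client()
    return True
```

Call it near the top of the recipe's entry point; with `TUNE_DEBUGPY=1` the process pauses until VS Code attaches, after which breakpoints set in the editor should be hit.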

@ebsmothers
Contributor

Closing this since I think it's resolved. But feel free to re-open if you run into other issues here!
