Skip to content

Commit 7ee0623

Browse files
SSYernarfacebook-github-bot
authored andcommitted
YAML config support for pipeline benchmarking (pytorch#3180)
Summary: Pull Request resolved: pytorch#3180 Added a support for YAML file configuration of the pipeline benchmarking. This feature makes easier to reproduce complex configurations without the need to CLI arguments passing. Example `.yaml ` file should look like: ``` RunOptions: world_size: 2 PipelineConfig: pipeline: "sparse" ``` Also, configs can be listed in a 'flat' way as well: ``` world_size: 2 pipeline: "sparse" ``` To run, add the `--yaml_config` flag with the `.yaml` file path. Additional flags can overwrite the `yaml` file configs as well if desired. Differential Revision: D78127340
1 parent 5574def commit 7ee0623

File tree

1 file changed

+34
-5
lines changed

1 file changed

+34
-5
lines changed

torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import click
4141

4242
import torch
43+
import yaml
4344
from torch import multiprocessing as mp
4445
from torch.autograd.profiler import record_function
4546
from torchrec.distributed import DistributedModelParallel
@@ -477,6 +478,13 @@ def wrapper() -> Any:
477478
sig = inspect.signature(func)
478479
parser = argparse.ArgumentParser(func.__doc__)
479480

481+
parser.add_argument(
482+
"--yaml_config",
483+
type=str,
484+
default=None,
485+
help="YAML config file for benchmarking",
486+
)
487+
480488
# Add loglevel argument with current logger level as default
481489
parser.add_argument(
482490
"--loglevel",
@@ -485,6 +493,21 @@ def wrapper() -> Any:
485493
help="Set the logging level (e.g. info, debug, warning, error)",
486494
)
487495

496+
pre_args, _ = parser.parse_known_args()
497+
498+
yaml_defaults: Dict[str, Any] = {}
499+
if pre_args.yaml_config:
500+
try:
501+
with open(pre_args.yaml_config, "r") as f:
502+
yaml_defaults = yaml.safe_load(f) or {}
503+
logger.info(
504+
f"Loaded YAML config from {pre_args.yaml_config}: {yaml_defaults}"
505+
)
506+
except Exception as e:
507+
logger.warning(
508+
f"Failed to load YAML config because {e}. Proceeding without it."
509+
)
510+
488511
seen_args = set() # track all --<name> we've added
489512

490513
for _name, param in sig.parameters.items():
@@ -509,11 +532,17 @@ def wrapper() -> Any:
509532
ftype = non_none[0]
510533
origin = get_origin(ftype)
511534

512-
# Handle default_factory value
513-
default_value = (
514-
f.default_factory() # pyre-ignore [29]
515-
if f.default_factory is not MISSING
516-
else f.default
535+
# Handle default_factory value and allow YAML config to override it
536+
default_value = yaml_defaults.get(
537+
arg_name, # flat lookup
538+
yaml_defaults.get(cls.__name__, {}).get( # hierarchy lookup
539+
arg_name,
540+
(
541+
f.default_factory() # pyre-ignore [29]
542+
if f.default_factory is not MISSING
543+
else f.default
544+
),
545+
),
517546
)
518547

519548
arg_kwargs = {

0 commit comments

Comments
 (0)