Skip to content

Commit

Permalink
[Setup] Fix dependency (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
chhzh123 authored Feb 3, 2023
1 parent 2fe133a commit 3001c8a
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 14 deletions.
7 changes: 5 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,16 @@ def setup():

setuptools.setup(
name="slapo",
description="Slapo: A Scahedule LAnguage for Progressive Optimization.",
description="Slapo: A Schedule LAnguage for Progressive Optimization.",
version=get_version(),
author="Slapo Community",
long_description=long_description,
long_description_content_type="text/markdown",
setup_requires=[],
install_requires=[],
install_requires=[
"packaging",
"psutil",
],
packages=setuptools.find_packages(),
url="https://github.com/awslabs/slapo",
python_requires=">=3.7",
Expand Down
9 changes: 4 additions & 5 deletions slapo/model_dialect/deepspeed/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from deepspeed.runtime.pipe.topology import (
PipeModelDataParallelTopology,
PipelineParallelGrid,
)

from ..registry import register_model_dialect
from ...logger import get_logger, INFO

Expand All @@ -17,6 +12,10 @@
def init_ds_engine(model, **kwargs):
"""Initialize the DeepSpeed engine."""
import deepspeed
from deepspeed.runtime.pipe.topology import (
PipeModelDataParallelTopology,
PipelineParallelGrid,
)

if "config" not in kwargs:
raise ValueError("DeepSpeed config not provided.")
Expand Down
14 changes: 7 additions & 7 deletions slapo/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@
from torch.fx._symbolic_trace import HAS_VARSTUFF, PH, _assert_is_none, _patch_function
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.fx.node import base_types
from transformers.utils.fx import (
_IS_IN_DEBUG_MODE,
_MANUAL_META_OVERRIDES,
Proxy,
_proxies_to_metas,
)

from .logger import get_logger

Expand Down Expand Up @@ -281,7 +275,13 @@ def trace(model: nn.Module, **kwargs: dict[str, Any]):
logger.debug("Tracer: %s Model: %s", tracer_cls_name, model.__class__.__name__)
if isinstance(tracer_cls_name, str):
if tracer_cls_name == "huggingface":
from transformers.utils.fx import HFTracer
from transformers.utils.fx import (
HFTracer,
_IS_IN_DEBUG_MODE,
_MANUAL_META_OVERRIDES,
Proxy,
_proxies_to_metas,
)

assert (
"concrete_args" in kwargs
Expand Down

0 comments on commit 3001c8a

Please sign in to comment.