Skip to content

Commit

Permalink
Merge pull request #88 from neuro-ml/support-for-pydantic<3.0
Browse files Browse the repository at this point in the history
Support for pydantic<3.0
  • Loading branch information
arseniybelkov authored Feb 11, 2025
2 parents b5aa412 + cf53034 commit 7b3f2da
Show file tree
Hide file tree
Showing 13 changed files with 234 additions and 73 deletions.
41 changes: 36 additions & 5 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,22 @@ jobs:
strategy:
matrix:
python-version: [ '3.8', '3.9', '3.10', '3.11', '3.12' ]
torch-version: [ '1.13', '2.0.1' ]
torch-version: [ '1.13', '2.0.1', '2.5.0']
exclude:
- torch-version: '1.13'
python-version: '3.12'
- python-version: '3.8'
torch-version: '2.5.0'
- python-version: '3.9'
torch-version: '2.5.0'
- python-version: '3.10'
torch-version: '2.5.0'
- python-version: '3.11'
torch-version: '1.13'
- python-version: '3.11'
torch-version: '2.5.0'
- python-version: '3.12'
torch-version: '1.13'
- python-version: '3.12'
torch-version: '2.0.1'

steps:
- uses: actions/checkout@v3
Expand All @@ -23,16 +35,35 @@ jobs:
with:
python-version: ${{ matrix.python-version }}

- name: Build the package
- name: Check Python version # https://github.com/python/cpython/issues/95299
id: check-version
run: |
python_version=$(python --version | awk '{print $2}')
major=$(echo $python_version | cut -d'.' -f1)
minor=$(echo $python_version | cut -d'.' -f2)
if ([ "$major" -eq 3 ] && [ "$minor" -ge 12 ]); then
echo "setuptools_present=false" >> $GITHUB_ENV
else
echo "setuptools_present=true" >> $GITHUB_ENV
fi
- name: Build the package (python >= 3.12)
if: env.setuptools_present == 'false'
run: |
python -m pip install build
python -m build
- name: Build the package (python < 3.12)
if: env.setuptools_present == 'true'
run: |
python setup.py sdist
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -e .
python -m pip install -r tests/dev-requirements.txt
python -m pip install torch==${{ matrix.torch-version }}
python -m pip install -e .
cd tests
export MODULE_PARENT=$(python -c "import $MODULE_NAME, os; print(os.path.dirname($MODULE_NAME.__path__[0]))")
export MODULE_PARENT=${MODULE_PARENT%"/"}
Expand Down
2 changes: 1 addition & 1 deletion docs/cli/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ The command shown above will run SLURM job with 4 CPUs and 100G of RAM.

### Predefined run configs
You can predefine run configs to avoid reentering the same flags.
Create `~/.config/thunder/backends.yml` (you can run `thunder show` in your terminal,
Create `~/.config/thunder/backends.yml` (you can run `thunder backend list` in your terminal,
the required path will be shown in the title of the table) in your home directory.
Now you can specify config name and its parameters:
```yaml
Expand Down
9 changes: 8 additions & 1 deletion docs/examples/mnist.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,17 @@ module = ThunderModule(
architecture, nn.CrossEntropyLoss(), optimizer=torch.optim.Adam(architecture.parameters())
)

# Preparing metrics
# 'y' and 'x' are single label and
# model prediction for a single image,
# hence the 'np.argmax(x)' for extracting
# the predicted label.
group_accuracy = {lambda y, x: (y, np.argmax(x)): accuracy_score}

# Initialize a trainer
trainer = Trainer(
callbacks=[ModelCheckpoint(save_last=True),
MetricMonitor(group_metrics={lambda y, x: (np.argmax(y), x): accuracy_score})],
MetricMonitor(group_metrics=group_accuracy)],
accelerator="auto",
devices=1,
max_epochs=100,
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
lightning>=2.0.0,<3.0.0
lazycon>=0.6.3,<1.0.0
typer>=0.9.0,<1.0.0
pydantic<2.0.0
pydantic<3.0.0
click
torch
toolz
Expand Down
16 changes: 11 additions & 5 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ def test_build(temp_dir, mock_backend):

with cleanup(experiment):
result = invoke('build', config, experiment, '-u', 'c=3')
assert result.exit_code != 0
assert result.exit_code != 0, result.output
assert 'are missing from the config' in str(result.exception)

result = invoke('build', config, experiment, '-u', 'a=10')
assert result.exit_code == 0
assert result.exit_code == 0, result.output
assert Config.load(experiment / 'experiment.config').a == 10

with cleanup(experiment):
Expand Down Expand Up @@ -111,7 +111,7 @@ def test_build_overwrite(temp_dir):
config.write_text('b = 2')

result = invoke('build', config, experiment, "--overwrite")
assert result.exit_code == 0
assert result.exit_code == 0, result.output
assert not hasattr(read_config(experiment / "experiment.config"), "a")
assert read_config(experiment / "experiment.config").b == 2

Expand Down Expand Up @@ -187,6 +187,11 @@ def test_backend_add(temp_dir, mock_backend):
local = load_backend_configs()
assert "new_config" in local and "new_config_2" in local

invoke("backend", "add", "new_config_3", "backend=cli", "n_workers=8")
local = load_backend_configs()
assert "new_config" in local and "new_config_2" in local
assert "new_config_3" in local


def test_backend_list(temp_dir, mock_backend):
# language=yaml
Expand All @@ -208,10 +213,11 @@ def test_backend_list(temp_dir, mock_backend):


def test_backend_set(temp_dir, mock_backend):
assert invoke("backend", "add", "config", "backend=slurm", "ram=100G", "--force").exit_code == 0
result = invoke("backend", "add", "config", "backend=slurm", "ram=100G", "--force")
assert result.exit_code == 0, result.output
result = invoke("backend", "set", "config")

assert result.exit_code == 0
assert result.exit_code == 0, result.output
local = load_backend_configs()
assert local[local["meta"].default].config.ram == "100G"

Expand Down
29 changes: 19 additions & 10 deletions thunder/backend/interface.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from pathlib import Path
from typing import Dict, Optional, Sequence, Type

from pydantic import BaseModel, Extra, validator
from pydantic import BaseModel

from ..layout import Node
from ..pydantic_compat import PYDANTIC_MAJOR, NoExtra, field_validator, model_validate


class BackendConfig(BaseModel):
class Config:
extra = Extra.ignore
class BackendConfig(NoExtra):
"""Backend Parameters"""


class Backend:
Expand All @@ -19,24 +19,33 @@ def run(config: BackendConfig, experiment: Path, nodes: Optional[Sequence[Node]]
"""Start running the given `nodes` of an experiment located at the given path"""


class BackendEntryConfig(BaseModel):
class BackendEntryConfig(NoExtra):
backend: str
config: BackendConfig

@validator('config', pre=True)
@field_validator("config", mode="before")
def _val_config(cls, v, values):
val = backends[values['backend']]
return val.Config.parse_obj(v)
return parse_backend_config(v, values)

@property
def backend_cls(self):
return backends[self.backend]

class Config:
extra = Extra.ignore

if PYDANTIC_MAJOR == 2:
def parse_backend_config(v, values):
val = backends[values.data["backend"]]
return model_validate(val.Config, v)
else:
def parse_backend_config(v, values):
val = backends[values["backend"]]
return model_validate(val.Config, v)


class MetaEntry(BaseModel):
"""
Default backend set by `thunder backend set`
"""
default: str


Expand Down
36 changes: 18 additions & 18 deletions thunder/backend/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from typing import Optional, Sequence

from deli import save
from pydantic import validator
from pytimeparse.timeparse import timeparse
from typer import Option
from typing_extensions import Annotated

from ..layout import Node
from ..pydantic_compat import field_validator
from .interface import Backend, BackendConfig, backends


Expand All @@ -26,45 +26,45 @@

class Slurm(Backend):
class Config(BackendConfig):
ram: Annotated[str, Option(
..., '-r', '--ram', '--mem',
ram: Annotated[Optional[str], Option(
None, '-r', '--ram', '--mem',
help='The amount of RAM required per node. Default units are megabytes. '
'Different units can be specified using the suffix [K|M|G|T].'
)] = None
cpu: Annotated[int, Option(
..., '-c', '--cpu', '--cpus-per-task', show_default=False,
cpu: Annotated[Optional[int], Option(
None, ..., '-c', '--cpu', '--cpus-per-task', show_default=False,
help='Number of CPU cores to allocate. Default to 1'
)] = None
gpu: Annotated[int, Option(
..., '-g', '--gpu', '--gpus-per-node',
gpu: Annotated[Optional[int], Option(
None, '-g', '--gpu', '--gpus-per-node',
help='Number of GPUs to allocate'
)] = None
partition: Annotated[str, Option(
..., '-p', '--partition',
partition: Annotated[Optional[str], Option(
None, '-p', '--partition',
help='Request a specific partition for the resource allocation'
)] = None
nodelist: Annotated[str, Option(
...,
nodelist: Annotated[Optional[str], Option(
None,
help='Request a specific list of hosts. The list may be specified as a comma-separated '
'list of hosts, a range of hosts (host[1-5,7,...] for example).'
'list of hosts, a range of hosts (host[1-5,7,None] for example).'
)] = None
time: Annotated[str, Option(
..., '-t', '--time',
time: Annotated[Optional[str], Option(
None, '-t', '--time',
help='Set a limit on the total run time of the job allocation. When the time limit is reached, '
'each task in each job step is sent SIGTERM followed by SIGKILL.'
)] = None
limit: Annotated[int, Option(
...,
limit: Annotated[Optional[int], Option(
None,
help='Limit the number of jobs that are simultaneously running during the experiment',
)] = None

@validator('time')
@field_validator("time")
def val_time(cls, v):
if v is None:
return
return parse_duration(v)

@validator('limit')
@field_validator("limit")
def val_limit(cls, v):
assert v is None or v > 0, 'The jobs limit, if specified, must be positive'
return v
Expand Down
51 changes: 34 additions & 17 deletions thunder/cli/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typer.models import ParamMeta

from ..backend import BackendEntryConfig, MetaEntry, backends
from ..pydantic_compat import model_validate, resolve_pydantic_major
from .app import app


Expand Down Expand Up @@ -92,23 +93,39 @@ def populate(backend_name):
show_default=False,
),
)]
for field in entry.backend_cls.Config.__fields__.values():
annotation = field.outer_type_
# TODO: https://stackoverflow.com/a/68337036
if not hasattr(annotation, '__metadata__') or not hasattr(annotation, '__origin__'):
raise ValueError('Please use the `Annotated` syntax to annotate you backend config')

# TODO
default, = annotation.__metadata__
default = copy.deepcopy(default)
default.default = getattr(entry.config, field.name)
default.help = f'[{backend_name} backend] {default.help}'
backend_params.append(ParamMeta(
name=field.name, default=default, annotation=annotation.__origin__,
))
backend_params.extend(_collect_backend_params(entry, backend_name))
return backend_params


if resolve_pydantic_major() >= 2:
def _collect_backend_params(entry, backend_name):
"""
Config Annotation depends on pydantic version.
"""
for field_name, field in entry.backend_cls.Config.model_fields.items():
field_clone = copy.deepcopy(field)
field_clone.default = getattr(entry.config, field_name)
yield ParamMeta(
name=field_name, default=field_clone.default, annotation=field.annotation,
)
else:
def _collect_backend_params(entry, backend_name):
for field in entry.backend_cls.Config.__fields__.values():
annotation = field.outer_type_
# TODO: https://stackoverflow.com/a/68337036
if not hasattr(annotation, '__metadata__') or not hasattr(annotation, '__origin__'):
raise ValueError('Please use the `Annotated` syntax to annotate you backend config')

# TODO
default, = annotation.__metadata__
default = copy.deepcopy(default)
default.default = getattr(entry.config, field.name)
default.help = f'[{backend_name} backend] {default.help}'
yield ParamMeta(
name=field.name, default=default, annotation=annotation.__origin__,
)


def collect_backends() -> ChainMap:
"""
Collects backend for each config.
Expand Down Expand Up @@ -144,7 +161,7 @@ def collect_configs() -> Tuple[ChainMap, Union[MetaEntry, None]]:
def load_backend_configs() -> Dict[str, Union[BackendEntryConfig, MetaEntry]]:
path = BACKENDS_CONFIG_PATH
if not path.exists():
# print(path, flush=True)
# TODO: return Option[Dict]
return {}

with path.open('r') as file:
Expand All @@ -153,5 +170,5 @@ def load_backend_configs() -> Dict[str, Union[BackendEntryConfig, MetaEntry]]:
return {}
# FIXME
assert isinstance(local, dict), type(local)
return {k: BackendEntryConfig.parse_obj(v)
if k != "meta" else MetaEntry.parse_obj(v) for k, v in local.items()}
return {k: model_validate(BackendEntryConfig, v)
if k != "meta" else model_validate(MetaEntry, v) for k, v in local.items()}
Loading

0 comments on commit 7b3f2da

Please sign in to comment.