Skip to content

Commit 8cf5c9a

Browse files
authored
fix: Support Pydantic 2.0 (#68)
* fix: Support Pydantic 2.0 * fix: Linting * fix: Remove typo
1 parent b68689a commit 8cf5c9a

File tree

5 files changed

+55
-58
lines changed

5 files changed

+55
-58
lines changed

env.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ dependencies:
2424
- h5py
2525
- pyarrow
2626
- matplotlib
27-
- pydantic
27+
- pydantic >=2.0.0
2828

2929
# Chemistry
3030
- datamol >=0.8.0

molfeat/store/modelcard.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
1-
from typing import Optional
2-
from typing import List
3-
from typing import Union
4-
51
from datetime import datetime
6-
from pydantic.typing import Literal
7-
from pydantic import BaseModel
8-
from pydantic import Field
2+
from typing import List, Literal, Optional, Union
3+
94
import datamol as dm
5+
from pydantic import BaseModel, ConfigDict, Field
106

117

128
def get_model_init(card):
@@ -53,6 +49,12 @@ def get_model_init(card):
5349

5450

5551
class ModelInfo(BaseModel):
52+
model_config = ConfigDict(
53+
protected_namespaces=(
54+
"protected_",
55+
) # Prevents warning from usage of model_ prefix in fields
56+
)
57+
5658
name: str
5759
inputs: str = "smiles"
5860
type: Literal["pretrained", "hand-crafted", "hashed", "count"]
@@ -62,12 +64,12 @@ class ModelInfo(BaseModel):
6264
description: str
6365
representation: Literal["graph", "line-notation", "vector", "tensor", "other"]
6466
require_3D: Optional[bool] = False
65-
tags: Optional[List[str]]
67+
tags: Optional[List[str]] = []
6668
authors: Optional[List[str]]
67-
reference: Optional[str]
69+
reference: Optional[str] = None
6870
created_at: datetime = Field(default_factory=datetime.now)
69-
sha256sum: Optional[str]
70-
model_usage: Optional[str]
71+
sha256sum: Optional[str] = None
72+
model_usage: Optional[str] = None
7173

7274
def path(self, root_path: str):
7375
"""Generate the folder path where to save this model
@@ -86,9 +88,9 @@ def match(self, new_card: Union["ModelInfo", dict], match_only: Optional[List[st
8688
match_only: list of minimum attribute that should match between the two model information
8789
"""
8890

89-
self_content = self.dict().copy()
91+
self_content = self.model_dump().copy()
9092
if not isinstance(new_card, dict):
91-
new_card = new_card.dict()
93+
new_card = new_card.model_dump()
9294
new_content = new_card.copy()
9395
# we always remove the datetime field
9496
self_content.pop("created_at", None)

molfeat/store/modelstore.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,20 @@
1-
from typing import Optional
2-
from typing import Any
3-
from typing import Union
4-
from typing import Callable
5-
6-
import yaml
7-
import joblib
8-
import pathlib
91
import os
10-
import fsspec
2+
import pathlib
113
import tempfile
12-
import platformdirs
13-
import filelock
4+
from typing import Any, Callable, Optional, Union
5+
146
import datamol as dm
7+
import filelock
8+
import fsspec
9+
import joblib
10+
import platformdirs
11+
import yaml
1512
from dotenv import load_dotenv
1613
from loguru import logger
1714

1815
from molfeat.store.modelcard import ModelInfo
1916
from molfeat.utils import commons
2017

21-
2218
load_dotenv()
2319

2420

molfeat/trans/base.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from pathlib import Path
12
from typing import Mapping
23
from typing import Union
34
from typing import List
@@ -575,11 +576,11 @@ def to_state_json(self) -> str:
575576
def to_state_yaml(self) -> str:
576577
return yaml.dump(self.to_state_dict(), Dumper=yaml.SafeDumper)
577578

578-
def to_state_json_file(self, filepath: str):
579+
def to_state_json_file(self, filepath: Union[str, Path]):
579580
with fsspec.open(filepath, "w") as f:
580581
f.write(self.to_state_json()) # type: ignore
581582

582-
def to_state_yaml_file(self, filepath: str):
583+
def to_state_yaml_file(self, filepath: Union[str, Path]):
583584
with fsspec.open(filepath, "w") as f:
584585
f.write(self.to_state_yaml()) # type: ignore
585586

@@ -674,7 +675,7 @@ def from_state_yaml(
674675

675676
@staticmethod
676677
def from_state_json_file(
677-
filepath: str,
678+
filepath: Union[str, Path],
678679
override_args: Optional[dict] = None,
679680
) -> "MoleculeTransformer":
680681
with fsspec.open(filepath, "r") as f:
@@ -683,7 +684,7 @@ def from_state_json_file(
683684

684685
@staticmethod
685686
def from_state_yaml_file(
686-
filepath: str,
687+
filepath: Union[str, Path],
687688
override_args: Optional[dict] = None,
688689
) -> "MoleculeTransformer":
689690
with fsspec.open(filepath, "r") as f:

tests/test_state.py

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,33 @@
1-
import pytest
2-
3-
import numpy as np
41
import datamol as dm
2+
import numpy as np
3+
import pytest
54

65
from molfeat._version import __version__ as MOLFEAT_VERSION
7-
from molfeat.trans.fp import FPVecTransformer
8-
from molfeat.trans.fp import FPVecFilteredTransformer
9-
from molfeat.trans.base import MoleculeTransformer
10-
from molfeat.trans.base import PrecomputedMolTransformer
11-
from molfeat.trans.graph import AdjGraphTransformer
12-
from molfeat.trans.graph import DGLGraphTransformer
13-
from molfeat.trans.graph import TopoDistGraphTransformer
14-
from molfeat.trans.graph import PYGGraphTransformer
15-
from molfeat.trans.pretrained import PretrainedDGLTransformer
16-
from molfeat.trans.pretrained import GraphormerTransformer
17-
from molfeat.trans.pretrained import PretrainedHFTransformer
18-
19-
from molfeat.calc.atom import AtomCalculator
6+
from molfeat.calc import (
7+
CATS,
8+
FPCalculator,
9+
Pharmacophore2D,
10+
RDKitDescriptors2D,
11+
ScaffoldKeyCalculator,
12+
)
13+
from molfeat.calc._atom_bond_features import atom_chiral_tag_one_hot, atom_one_hot
14+
from molfeat.calc.atom import AtomCalculator, AtomMaterialCalculator
2015
from molfeat.calc.bond import BondCalculator
21-
from molfeat.calc.atom import AtomMaterialCalculator
22-
from molfeat.calc import FPCalculator
23-
from molfeat.calc import ScaffoldKeyCalculator
24-
from molfeat.calc import RDKitDescriptors2D
25-
from molfeat.calc import CATS
26-
from molfeat.calc import Pharmacophore2D
27-
from molfeat.calc._atom_bond_features import atom_chiral_tag_one_hot
28-
from molfeat.calc._atom_bond_features import atom_one_hot
29-
from molfeat.trans.graph import MolTreeDecompositionTransformer
30-
31-
from molfeat.utils.cache import MolToKey
32-
from molfeat.utils.cache import FileCache
16+
from molfeat.trans.base import MoleculeTransformer, PrecomputedMolTransformer
17+
from molfeat.trans.fp import FPVecFilteredTransformer, FPVecTransformer
18+
from molfeat.trans.graph import (
19+
AdjGraphTransformer,
20+
DGLGraphTransformer,
21+
MolTreeDecompositionTransformer,
22+
PYGGraphTransformer,
23+
TopoDistGraphTransformer,
24+
)
25+
from molfeat.trans.pretrained import (
26+
GraphormerTransformer,
27+
PretrainedDGLTransformer,
28+
PretrainedHFTransformer,
29+
)
30+
from molfeat.utils.cache import FileCache, MolToKey
3331
from molfeat.utils.state import compare_state
3432

3533

0 commit comments

Comments
 (0)