Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for PydanticV2 #212

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,7 @@ docs/_build/
.tox/

# Coverage
cov_data/
cov_data/

#Custom
test.py
10 changes: 7 additions & 3 deletions rocketry/_base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, ClassVar
from pydantic.dataclasses import dataclass, Field
from pydantic import BaseModel

if TYPE_CHECKING:
from rocketry import Session

class RedBase:
class RedBase():
"""Baseclass for all Rocketry classes"""
Jypear marked this conversation as resolved.
Show resolved Hide resolved
session: 'Session' = None

# Commented this out for now as it was causing issues with the new pydantic implementation
session: 'Session'
14 changes: 12 additions & 2 deletions rocketry/_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,19 @@ def _setup_defaults():
_FuncTaskCondWrapper
)
for cls_task in cls_tasks:
cls_task.update_forward_refs(Session=Session, BaseCondition=BaseCondition)
#cls_task.update_forward_refs(Session=Session, BaseCondition=BaseCondition)
cls_task.model_rebuild(
force=True,
_types_namespace={"Session": Session, "BaseCondition": BaseCondition},
_parent_namespace_depth=4
)

Config.update_forward_refs(BaseCondition=BaseCondition)
# Config.update_forward_refs(BaseCondition=BaseCondition)
Config.model_rebuild(
force=True,
_types_namespace={"Session": Session, "BaseCondition": BaseCondition},
_parent_namespace_depth=4
)
#Session.update_forward_refs(
# Task=Task, Parameters=Parameters, Scheduler=Scheduler
#)
4 changes: 2 additions & 2 deletions rocketry/conditions/meta.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

import copy
from typing import Callable, Optional, Pattern, Union
from typing import Callable, ClassVar, Optional, Pattern, Union

from pydantic import Field
from rocketry.args import Session
Expand All @@ -13,7 +13,7 @@
class _FuncTaskCondWrapper(FuncTask):

# For some reason, the order of cls attrs broke here so we need to reorder then:
session: _Session
session: ClassVar[_Session]
name: Optional[str] = Field(description="Name of the task. Must be unique")

def _handle_return(self, value):
Expand Down
88 changes: 50 additions & 38 deletions rocketry/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
import threading
from queue import Empty
from typing import TYPE_CHECKING, Any, Callable, ClassVar, List, Dict, Type, Union, Tuple, Optional
from typing_extensions import Annotated
try:
from typing import Literal
except ImportError: # pragma: no cover
from typing_extensions import Literal

from pydantic import BaseModel, Field, PrivateAttr, validator
from pydantic import BaseModel, Field, PrivateAttr, ConfigDict, field_validator, field_serializer

from rocketry._base import RedBase
from rocketry.core.condition import BaseCondition, AlwaysFalse, All
Expand Down Expand Up @@ -94,7 +95,7 @@ def is_async(self) -> bool:
def is_thread(self) -> bool:
return isinstance(self.task, threading.Thread)

class Task(RedBase, BaseModel):
class Task(BaseModel, RedBase):
"""Base class for Tasks.

A task can be a function, command or other procedure that
Expand Down Expand Up @@ -192,42 +193,37 @@ class Task(RedBase, BaseModel):
... return ...

"""
class Config:
arbitrary_types_allowed= True
underscore_attrs_are_private = True
validate_assignment = True
json_encoders = {
Parameters: lambda v: v.to_json(),
'BaseCondition': lambda v: str(v),
FunctionType: lambda v: v.__name__,
'Session': lambda v: id(v),
}

model_config = ConfigDict(
arbitrary_types_allowed= True,
validate_assignment = True,
validate_default=True,
extra='allow',
)

session: 'Session' = Field()
session: 'Session' = Field(validate_default=False, default=None)

# Class
permanent: bool = False # Whether the task is not meant to finish (Ie. RestAPI)
_actions: ClassVar[Tuple] = ("run", "fail", "success", "inaction", "terminate", None, "crash")
fmt_log_message: str = r"Task '{task}' status: '{action}'"

daemon: Optional[bool]
daemon: Optional[bool] = False
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I recall correctly, if this is not set (should be None), this is fetched from the session configs thus setting default here doesn't allow us to see that the user didn't want to specify the daemon attr for this particular task (thus should be gotten from the config).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did have issues with this field as part of the migration as well. If this field is not set, then during the __init__ for Task it would raise a validation error for daemon when calling super().__init__(**kwargs).

Exception has occurred: ValidationError
1 validation error for FuncTask
daemon
  Field required [type=missing, input_value={'func': <function do_thi...s', 'name': 'do_things'}, input_type=dict]
    For further information visit https://errors.pydantic.dev/2.0.3/v/missing

Noting as well that this validation error is coming from FuncTask rather than task so I'm not sure if this is some inheritance weirdness again.

Its strange with the field clearly being stated as an Optional[bool] in the class. Would that functionality still work as intended when setting the field with a default value of None instead of False? That still seems to allow the code to run. Or would that still cause the same issues with not inheriting that attribute from Config?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see what you mean on this now

    def run_as_process(self, params:Parameters, direct_params:Parameters, task_run:TaskRun, daemon=None, log_queue: multiprocessing.Queue=None):
        """Create a new process and run the task on that."""

        session = self.session

        params = params.pre_materialize(task=self, session=session)
        direct_params = direct_params.pre_materialize(task=self, session=session)

        # Daemon resolution: task.daemon >> scheduler.tasks_as_daemon
        log_queue = session.scheduler._log_queue if log_queue is None else log_queue

        daemon = self.daemon if self.daemon is not None else session.config.tasks_as_daemon

Hopefully with a default value of None this should work now rather than False.

batches: List[Parameters] = Field(
default_factory=list,
description="Run batches (parameters). If not empty, run is triggered regardless of starting condition"
)

# Instance
name: Optional[str] = Field(description="Name of the task. Must be unique")
description: Optional[str] = Field(description="Description of the task for documentation")
logger_name: Optional[str] = Field(description="Logger name to be used in logging the task records")
execution: Optional[Literal['main', 'async', 'thread', 'process']]
name: Optional[str] = Field(description="Name of the task. Must be unique", default=None)
description: Optional[str] = Field(description="Description of the task for documentation", default=None)
logger_name: Optional[str] = Field(description="Logger name to be used in logging the task records", default="rocketry.task")
execution: Optional[Literal['main', 'async', 'thread', 'process']] = None
priority: int = 0
disabled: bool = False
force_run: bool = False
force_termination: bool = False
status: Optional[Literal['run', 'fail', 'success', 'terminate', 'inaction', 'crash']] = Field(description="Latest status of the task")
timeout: Optional[datetime.timedelta]
status: Optional[Literal['run', 'fail', 'success', 'terminate', 'inaction', 'crash']] = Field(description="Latest status of the task", default=None)
timeout: Optional[datetime.timedelta] = None

parameters: Parameters = Parameters()

Expand All @@ -237,7 +233,7 @@ class Config:
multilaunch: Optional[bool] = None
on_startup: bool = False
on_shutdown: bool = False
func_run_id: Callable = None
func_run_id: Union[Callable, None] = None

_last_run: Optional[float]
_last_success: Optional[float]
Expand All @@ -252,29 +248,29 @@ class Config:

_mark_running = False

@validator('start_cond', pre=True)
@field_validator('start_cond', mode="before")
def parse_start_cond(cls, value, values):
from rocketry.parse.condition import parse_condition
session = values['session']
session = values.data['session']
if isinstance(value, str):
value = parse_condition(value, session=session)
elif value is None:
value = AlwaysFalse()
return copy(value)

@validator('end_cond', pre=True)
@field_validator('end_cond', mode="before")
def parse_end_cond(cls, value, values):
from rocketry.parse.condition import parse_condition
session = values['session']
session = values.data['session']
if isinstance(value, str):
value = parse_condition(value, session=session)
elif value is None:
value = AlwaysFalse()
return copy(value)

@validator('logger_name', pre=True, always=True)
@field_validator('logger_name', mode="before")
def parse_logger_name(cls, value, values):
session = values['session']
session = values.data['session']

if isinstance(value, str):
logger_name = value
Expand All @@ -287,7 +283,7 @@ def parse_logger_name(cls, value, values):
raise ValueError(f"Logger name must start with '{basename}' as session finds loggers with names")
return logger_name

@validator('timeout', pre=True, always=True)
@field_validator('timeout', mode="before")
def parse_timeout(cls, value, values):
if value == "never":
return datetime.timedelta.max
Expand All @@ -296,6 +292,22 @@ def parse_timeout(cls, value, values):
if value is not None:
return to_timedelta(value)
return value

@field_serializer("parameters", when_used="json")
def ser_parameters(self, parameters):
return parameters.to_json()

@field_serializer("start_cond", when_used="json")
def ser_start_cond(self, start_cond):
return str(start_cond)

@field_serializer("end_cond", when_used="json")
def ser_end_cond(self, end_cond):
return str(end_cond)

@field_serializer("session", when_used="json", check_fields=False)
def ser_session(self, session):
return id(session)

@property
def logger(self):
Expand Down Expand Up @@ -339,9 +351,9 @@ def _get_name(self, name=None, **kwargs):
return self.get_default_name(**kwargs)
return name

@validator('name', pre=True)
@field_validator('name', mode="before")
def parse_name(cls, value, values):
session = values['session']
session = values.data['session']
on_exists = session.config.task_pre_exist
name_exists = value in session
if name_exists:
Expand All @@ -359,9 +371,9 @@ def parse_name(cls, value, values):
return name
return value

@validator('name', pre=False)
@field_validator('name', mode="after")
def validate_name(cls, value, values):
session = values['session']
session = values.data['session']
on_exists = session.config.task_pre_exist
name_exists = value in session

Expand All @@ -371,17 +383,17 @@ def validate_name(cls, value, values):
raise ValueError(f"Task name '{value}' already exists. Please pick another")
return value

@validator('parameters', pre=True)
@field_validator('parameters', mode="before")
def parse_parameters(cls, value):
if isinstance(value, Parameters):
return value
return Parameters(value)

@validator('force_run', pre=False)
@field_validator('force_run', mode="after")
def parse_force_run(cls, value, values):
if value:
warnings.warn("Attribute 'force_run' is deprecated. Please use method set_running() instead", DeprecationWarning)
values['batches'].append(Parameters())
values.data['batches'].append(Parameters())
return value

def __hash__(self):
Expand Down Expand Up @@ -1294,8 +1306,8 @@ def __getstate__(self):
#state['__dict__'] = state['__dict__'].copy()

# remove unpicklable
state['__private_attribute_values__'] = state['__private_attribute_values__'].copy()
priv_attrs = state['__private_attribute_values__']
state['__pydantic_private__'] = state['__pydantic_private__'].copy()
priv_attrs = state['__pydantic_private__']
Jypear marked this conversation as resolved.
Show resolved Hide resolved
priv_attrs['_lock'] = None
priv_attrs['_process'] = None
priv_attrs['_thread'] = None
Expand Down
25 changes: 14 additions & 11 deletions rocketry/log/log_record.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
from typing import Optional
from pydantic import BaseModel, Field, validator
from pydantic import field_validator, BaseModel, Field

from rocketry.pybox.time import to_datetime, to_timedelta

Expand Down Expand Up @@ -38,36 +38,39 @@ class LogRecord(MinimalRecord):

class TaskLogRecord(MinimalRecord):

start: Optional[datetime.datetime]
end: Optional[datetime.datetime]
runtime: Optional[datetime.timedelta]
start: Optional[datetime.datetime] = None
end: Optional[datetime.datetime] = None
runtime: Optional[datetime.timedelta] = None

message: str
exc_text: Optional[str]
exc_text: Optional[str] = None

@validator("start", pre=True)
@field_validator("start", mode="before")
@classmethod
def format_start(cls, value):
if value is not None:
value = to_datetime(value)
return value

@validator("end", pre=True)
@field_validator("end", mode="before")
@classmethod
def format_end(cls, value):
if value is not None:
value = to_datetime(value)
return value

@validator("runtime", pre=True)
@field_validator("runtime", mode="before")
@classmethod
def format_runtime(cls, value):
if value is not None:
value = to_timedelta(value)
return value

class MinimalRunRecord(MinimalRecord):
run_id: Optional[str]
run_id: Optional[str] = None

class RunRecord(LogRecord):
run_id: Optional[str]
run_id: Optional[str] = None

class TaskRunRecord(TaskLogRecord):
run_id: Optional[str]
run_id: Optional[str] = None
Loading
Loading