Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tidy] Convert actions to classes #363

Draft
wants to merge 27 commits into
base: dev/actions_v2
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
6cbf33d
First very raw version of class based action function with validation…
petar-qb Mar 4, 2024
df0c7cf
Minor improvements
petar-qb Mar 8, 2024
8fac8e5
Minor refactoring
petar-qb Mar 8, 2024
9851cb9
PoC - Validation and eager input arguments calculation
petar-qb Mar 12, 2024
87db5d7
Merge branch 'main' of https://github.com/mckinsey/vizro into tidy/ac…
petar-qb Mar 12, 2024
7b8a6c3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2024
c40ea85
Minor changes
petar-qb Mar 12, 2024
99b684d
Merge branch 'tidy/actions_to_class' of https://github.com/mckinsey/v…
petar-qb Mar 12, 2024
0c260e6
Minor change
petar-qb Mar 12, 2024
76fe66e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2024
6d88c50
Fixing a bug in model_manager._get_model_page_id + More refactoring
petar-qb Mar 28, 2024
400ac60
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2024
b823e6b
All actions moved to the class based approach
petar-qb Apr 8, 2024
31bd5a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2024
0c65e2c
Small TODO refactoring
petar-qb Apr 8, 2024
d0eac24
Solving conflicts
petar-qb Apr 8, 2024
1107486
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 8, 2024
caf5012
Small TODOs changes
petar-qb Apr 9, 2024
6b15541
Exposing vm.Page.actions argument
petar-qb Apr 15, 2024
4d2ba9a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 15, 2024
f1235c4
Examples and other changes added
petar-qb Apr 25, 2024
c8a2618
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 25, 2024
de20b21
Conflicts solved
petar-qb Apr 26, 2024
80238ed
Merge branch 'tidy/actions_to_class' of https://github.com/mckinsey/v…
petar-qb Apr 26, 2024
bedf430
Reverting docstrings for examples files
petar-qb Apr 26, 2024
3ed5f86
Add more validation + Clearning some TODO-AV2
petar-qb Apr 26, 2024
387ce00
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion vizro-core/examples/_dev/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import vizro.models as vm
import vizro.plotly.express as px
from vizro import Vizro
from vizro.actions import filter_interaction
from vizro.actions import export_data_class_action, filter_interaction
from vizro.tables import dash_ag_grid, dash_data_table

df_gapminder = px.data.gapminder().query("year == 2007")


dashboard = vm.Dashboard(
pages=[
vm.Page(
Expand All @@ -27,6 +28,10 @@
color="continent",
),
),
vm.Button(
text="Export data",
actions=[vm.Action(function=export_data_class_action(targets=["scatter"]))],
),
],
controls=[vm.Filter(column="continent")],
),
Expand Down
3 changes: 2 additions & 1 deletion vizro-core/src/vizro/actions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from vizro.actions._on_page_load_action import _on_page_load
from vizro.actions._parameter_action import _parameter
from vizro.actions.export_data_action import export_data
from vizro.actions.export_data_class_action import export_data_class_action
from vizro.actions.filter_interaction_action import filter_interaction

# Please keep alphabetically ordered
__all__ = ["export_data", "filter_interaction"]
__all__ = ["export_data", "export_data_class_action", "filter_interaction"]
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,20 @@ def _get_inputs_of_controls(page: Page, control_type: ControlType) -> List[State
]


def _get_inputs_of_filters(page: Page, action_function: Callable[[Any], Dict[str, Any]]) -> List[State]:
"""Gets list of `States` for selected `control_type` of triggered `Page`."""
filter_actions_on_page = _get_matching_actions_by_function(
page_id=ModelID(str(page.id)), action_function=action_function
)
inputs = []
# TODO-actions: Take the "actions_info" into account once it's implemented.
for action in filter_actions_on_page:
triggered_model = model_manager._get_action_trigger(action_id=ModelID(str(action.id)))
inputs.append(State(component_id=triggered_model.id, component_property=triggered_model._input_property))

return inputs


def _get_inputs_of_figure_interactions(
page: Page, action_function: Callable[[Any], Dict[str, Any]]
) -> List[Dict[str, State]]:
Expand All @@ -45,6 +59,7 @@ def _get_inputs_of_figure_interactions(
page_id=ModelID(str(page.id)), action_function=action_function
)
inputs = []
# TODO-actions: Take the "actions_info" into account once it's implemented.
for action in figure_interactions_on_page:
triggered_model = model_manager._get_action_trigger(action_id=ModelID(str(action.id)))
required_attributes = ["_filter_interaction_input", "_filter_interaction"]
Expand Down
110 changes: 110 additions & 0 deletions vizro-core/src/vizro/actions/export_data_class_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import importlib
from typing import Any, Dict, List, Literal

from dash import Output, State, ctx, dcc

from vizro.actions._filter_action import _filter
from vizro.actions.filter_interaction_action import filter_interaction
from vizro.managers import model_manager
from vizro.models.types import CapturedActionCallable


class ExportDataClassAction(CapturedActionCallable):
def __init__(self, *args, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's this __init__ for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I used it to save input args and kwargs so they can be validated/adjusted in the _post_init.
Now, I got rid of the constructor and self._arguments is used inside the _post_init.

self._args = args
self._kwargs = kwargs
# Fake initialization - to let other actions see that this one exists.
super().__init__(*args, **kwargs)

def _post_init(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be called directly in __init__ or is that too early?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's too early to call _post_init validation inside the action initialisation phase. To be able to validate action input arguments properly, all dashboard models have to be initialised. Some of the models are initialised in the _pre_build phase of other models which means that the actions _post_init has to be called within the build phase.

"""Post initialization is called in the vm.Action build phase, and it is used to validate and calculate the
properties of the CapturedActionCallable. With this, we can validate the properties and raise errors before
the action is built. Also, "input"/"output"/"components" properties and "pure_function" can use these validated
and the calculated arguments.
"""
self._page_id = model_manager._get_model_page_id(model_id=self._action_id)

# Validate and calculate "targets"
targets = self._kwargs.get("targets")
if targets:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels like an improvement on the old function version because we now validate targets upfront rather than at runtime. Is that right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, that's right! 😄

for target in targets:
if target not in model_manager:
raise ValueError(f"Component '{target}' does not exist on the page '{self._page_id}'.")
else:
targets = model_manager._get_page_model_ids_with_figure(page_id=self._page_id)
self._kwargs["targets"] = self.targets = targets
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I understand what's happening with the _kwargs stuff here, please could you explain?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now this line looks like this:
self._arguments["targets"] = self.targets = targets

  • targets - represents new validated and calculated action's targets
  • self._arguments["targets"] = targets - overwrites the pure_function input argument. Since the pure_function is a staticmethod (and I'm pretty sure it should remain), new calculated targets have to be propagated through the self._arguments.
  • self.targets = targets - Optionally, some of CapturedActionCallable attributes are also calculated in the _post_init phase. In this case, self.targets is created calculated so it can be easily reused inside the outputs and components` calculations.


# Validate and calculate "file_format"
file_format = self._kwargs.get("file_format", "csv")
if file_format not in ["csv", "xlsx"]:
raise ValueError(f'Unknown "file_format": {file_format}.' f' Known file formats: "csv", "xlsx".')
if file_format == "xlsx":
if importlib.util.find_spec("openpyxl") is None and importlib.util.find_spec("xlsxwriter") is None:
raise ModuleNotFoundError("You must install either openpyxl or xlsxwriter to export to xlsx format.")
self._kwargs["file_format"] = self.file_format = file_format

# Post initialization - to enable pure_function to use calculated input arguments like "targets".
super().__init__(*self._args, **self._kwargs)

@staticmethod
def pure_function(targets: List[str], file_format: Literal["csv", "xlsx"] = "csv", **inputs: Dict[str, Any]):
from vizro.actions._actions_utils import _get_filtered_data

data_frames = _get_filtered_data(
targets=targets,
ctds_filters=ctx.args_grouping["external"]["filters"],
ctds_filter_interaction=ctx.args_grouping["external"]["filter_interaction"],
)

outputs = {}
for target_id in targets:
if file_format == "csv":
writer = data_frames[target_id].to_csv
elif file_format == "xlsx":
writer = data_frames[target_id].to_excel

outputs[f"download_dataframe_{target_id}"] = dcc.send_data_frame(
writer=writer, filename=f"{target_id}.{file_format}", index=False
)

return outputs

@property
def inputs(self):
from vizro.actions._callback_mapping._callback_mapping_utils import (
_get_inputs_of_figure_interactions,
_get_inputs_of_filters,
)

page = model_manager[self._page_id]
return {
"filters": _get_inputs_of_filters(page=page, action_function=_filter.__wrapped__),
"filter_interaction": _get_inputs_of_figure_interactions(
page=page, action_function=filter_interaction.__wrapped__
),
# TODO-actions: Propagate theme_selector only if it exists on the page (could be overwritten by the user)
"theme_selector": State("theme_selector", "checked"),
}

@property
def outputs(self) -> Dict[str, Output]:
# TODO-actions: Take the "actions_info" into account once it's implemented.
return {
f"download_dataframe_{target}": Output(
component_id={"type": "download_dataframe", "action_id": self._action_id, "target_id": target},
component_property="data",
)
for target in self.targets
}

@property
def components(self):
# TODO-actions: Take the "actions_info" into account once it's implemented.
return [
dcc.Download(id={"type": "download_dataframe", "action_id": self._action_id, "target_id": target})
for target in self.targets
]


# Alias for ExportDataClassAction
export_data_class_action = ExportDataClassAction
36 changes: 21 additions & 15 deletions vizro-core/src/vizro/models/_action/_action.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib.util
import logging
from collections.abc import Collection, Mapping
from pprint import pformat
Expand All @@ -9,7 +8,7 @@
try:
from pydantic.v1 import Field, validator
except ImportError: # pragma: no cov
from pydantic import Field, validator
from pydantic import Field

import vizro.actions
from vizro.managers._model_manager import ModelID
Expand Down Expand Up @@ -49,18 +48,9 @@ class Action(VizroBaseModel):
# require, and make the code here look up the appropriate validation using the function as key
# This could then also involve other validations currently only carried out at run-time in pre-defined actions, such
# as e.g. checking if the correct arguments have been provided to the file_format in export_data.
@validator("function")
def validate_predefined_actions(cls, function):
if function._function.__name__ == "export_data":
file_format = function._arguments.get("file_format")
if file_format not in [None, "csv", "xlsx"]:
raise ValueError(f'Unknown "file_format": {file_format}.' f' Known file formats: "csv", "xlsx".')
if file_format == "xlsx":
if importlib.util.find_spec("openpyxl") is None and importlib.util.find_spec("xlsxwriter") is None:
raise ModuleNotFoundError(
"You must install either openpyxl or xlsxwriter to export to xlsx format."
)
return function
#
# @validator("function")
# def validate_predefined_actions(cls, function):

def _get_callback_mapping(self):
"""Builds callback inputs and outputs for the Action model callback, and returns action required components.
Expand All @@ -77,8 +67,14 @@ def _get_callback_mapping(self):
from vizro.actions._callback_mapping._get_action_callback_mapping import _get_action_callback_mapping

callback_inputs: Union[List[State], Dict[str, State]]
# TODO-actions: Refactor the following lines to:
# `callback_inputs = self.function.inputs + [State(*input.split(".")) for input in self.inputs]`
# After refactoring that's mentioned above, test overwriting of the predefined action.
# (by adding a new inputs/outputs to the overwritten action and check if it's working as expected)
if self.inputs:
callback_inputs = [State(*input.split(".")) for input in self.inputs]
elif hasattr(self.function, "inputs") and self.function.inputs:
callback_inputs = self.function.inputs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we're going to have a few switches here while we have both the "old" function actions and the new class ones.

Let's have a consistent way of doing this everywhere to make it clearer. I think just isinstance(function, CapturedActionCallable) would work? No need to check if self.function.inputs is Falsey or not I think either.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we're going to have a few switches here while we have both the "old" function actions and the new class ones.

You're right, and after all actions become implemented as CapturedActionCallable, then the following line will be removed:
else: callback_inputs = _get_action_callback_mapping(action_id=ModelID(str(self.id)), argument="inputs")

else:
callback_inputs = _get_action_callback_mapping(action_id=ModelID(str(self.id)), argument="inputs")

Expand All @@ -91,10 +87,15 @@ def _get_callback_mapping(self):
# single element list (e.g. ["text"]).
if len(callback_outputs) == 1:
callback_outputs = callback_outputs[0]
elif hasattr(self.function, "outputs") and self.function.outputs:
callback_outputs = self.function.outputs
else:
callback_outputs = _get_action_callback_mapping(action_id=ModelID(str(self.id)), argument="outputs")

action_components = _get_action_callback_mapping(action_id=ModelID(str(self.id)), argument="components")
if hasattr(self.function, "components") and self.function.components:
action_components = self.function.components
else:
action_components = _get_action_callback_mapping(action_id=ModelID(str(self.id)), argument="components")

return callback_inputs, callback_outputs, action_components

Expand Down Expand Up @@ -152,6 +153,11 @@ def build(self):
List of required components (e.g. dcc.Download) for the Action model added to the `Dashboard` container.

"""
# Consider sending the entire action object
self.function._action_id = self.id
if hasattr(self.function, "_post_init"):
self.function._post_init()

external_callback_inputs, external_callback_outputs, action_components = self._get_callback_mapping()
callback_inputs = {
"external": external_callback_inputs,
Expand Down
24 changes: 24 additions & 0 deletions vizro-core/src/vizro/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import functools
import inspect
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, List, Literal, Protocol, Union, runtime_checkable

try:
Expand Down Expand Up @@ -217,6 +218,29 @@ def _parse_json(
raise ValueError(f"_target_={function_name} must be wrapped in the @capture decorator.")


class CapturedActionCallable(CapturedCallable, metaclass=ABCMeta):
def __init__(self, *args, **kwargs):
super().__init__(self.pure_function, *args, **kwargs)

@staticmethod
@abstractmethod
# TODO-actions: Rename to "function"
def pure_function(*args, **kwargs):
"""This is the function that will be called when the action is triggered."""

@property
def inputs(self):
return []

@property
def outputs(self):
return []

@property
def components(self):
return []


class capture:
"""Captures a function call to create a [`CapturedCallable`][vizro.models.types.CapturedCallable].

Expand Down
Loading