Skip to content

Commit

Permalink
[Tidy] Remove component to data mapping from data manager (#451)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
antonymilne and pre-commit-ci[bot] committed May 14, 2024
1 parent 64d8602 commit 5d5d8a8
Show file tree
Hide file tree
Showing 13 changed files with 103 additions and 87 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨
- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Removed
- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Added
- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Changed
- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Deprecated
- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Fixed
- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Security
- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
4 changes: 2 additions & 2 deletions vizro-core/src/vizro/actions/_actions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ def _get_filtered_data(
) -> Dict[ModelID, pd.DataFrame]:
filtered_data = {}
for target in targets:
data_frame = data_manager._get_component_data(target)

data_source_name = model_manager[target]["data_frame"]
data_frame = data_manager[data_source_name].load()
data_frame = _apply_filters(data_frame=data_frame, ctds_filters=ctds_filters, target=target)
data_frame = _apply_filter_interaction(
data_frame=data_frame, ctds_filter_interaction=ctds_filter_interaction, target=target
Expand Down
33 changes: 1 addition & 32 deletions vizro-core/src/vizro/managers/_data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import logging
import os
import warnings
from typing import Any, Callable, Dict, Optional, Protocol, Union

Expand All @@ -15,10 +14,8 @@

logger = logging.getLogger(__name__)

# Really ComponentID and DataSourceName should be NewType and not just aliases but then for a user's code to type check
# Really DataSourceName should be a NewType and not just an alias, but then for a user's code to type check
# correctly they would need to cast all strings to this type.
# TODO: remove these type aliases once have moved component to data mapping to models
ComponentID = str
DataSourceName = str
pd_DataFrameCallable = Callable[[], pd.DataFrame]

Expand Down Expand Up @@ -186,7 +183,6 @@ class DataManager:

def __init__(self):
self.__data: Dict[DataSourceName, Union[_DynamicData, _StaticData]] = {}
self.__component_to_data: Dict[ComponentID, DataSourceName] = {}
self._frozen_state = False
self.cache = Cache(config={"CACHE_TYPE": "NullCache"})
# In future, possibly we will accept just a config dict. Would need to work out whether to handle merging with
Expand Down Expand Up @@ -227,33 +223,6 @@ def __getitem__(self, name: DataSourceName) -> Union[_DynamicData, _StaticData]:
except KeyError as exc:
raise KeyError(f"Data source {name} does not exist.") from exc

@_state_modifier
def _add_component(self, component_id: ComponentID, name: DataSourceName):
"""Adds a mapping from `component_id` to `name`."""
# TODO: once we have removed self.__component_to_data, we shouldn't need this function any more.
# Maybe always update the CapturedCallable's data_frame to the data source name string then.
if name not in self.__data:
raise KeyError(f"Data source {name} does not exist.")
if component_id in self.__component_to_data:
raise ValueError(
f"Component with id={component_id} already exists and is mapped to data "
f"{self.__component_to_data[component_id]}. Components must uniquely map to a data source across the "
f"whole dashboard. If you are working from a Jupyter Notebook, please either restart the kernel, or "
f"use 'from vizro import Vizro; Vizro._reset()`."
)
self.__component_to_data[component_id] = name

def _get_component_data(self, component_id: ComponentID) -> pd.DataFrame:
# TODO: once have removed self.__component_to_data, we shouldn't need this function any more. Calling
# functions would just do data_manager[name].load().
"""Returns the original data for `component_id`."""
if component_id not in self.__component_to_data:
raise KeyError(f"Component {component_id} does not exist. You need to call add_component first.")
name = self.__component_to_data[component_id]

logger.debug("Loading data %s on process %s", name, os.getpid())
return self[name].load()

def _clear(self):
# We do not actually call self.cache.clear() because (a) it would only work when self._cache_has_app is True,
# which is not the case when e.g. Vizro._reset is called, and (b) because we do not want to accidentally
Expand Down
42 changes: 22 additions & 20 deletions vizro-core/src/vizro/models/_components/_components_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,30 @@ def _callable_mode_validator_factory(mode: str):
return validator("figure", allow_reuse=True)(check_callable_mode)


def _process_callable_data_frame(captured_callable, values):
def _process_callable_data_frame(captured_callable):
# Possibly all this validator's functionality should move into CapturedCallable (or a subclass of it) in the
# future. This would mean that data is added to the data manager outside the context of a dashboard though,
# which might not be desirable.
data_frame = captured_callable["data_frame"]

if isinstance(data_frame, str):
# Named data source, which could be dynamic or static. This means px.scatter("iris") from the Python API and
# specification of "data_frame": "iris" through JSON. In these cases, data already exists in the data manager
# and just needs to be linked to the component.
data_source_name = data_frame
else:
# Unnamed data source, which must be a pd.DataFrame and hence static data. This means px.scatter(pd.DataFrame())
# and is only possible from the Python API. Extract dataframe from the captured function and put it into the
# data manager.
# Unlike with model_manager, it doesn't matter if the random seed is different across workers here. So long as
# we always fetch static data from the data manager by going through the appropriate Figure component, the right
# data source name will be fetched. It also doesn't matter if multiple Figures with the same underlying data
# each have their own entry in the data manager, since the underlying pd.DataFrame will still be the same and
# not copied into each one, so no memory is wasted.
logger.debug("Adding data to data manager for Figure with id %s", values["id"])
data_source_name = str(uuid.uuid4())
data_manager[data_source_name] = data_frame

data_manager._add_component(values["id"], data_source_name)
# No need to keep the data in the captured function any more so remove it to save memory.
del captured_callable["data_frame"]
# specification of "data_frame": "iris" through JSON. In these cases, data already exists in the data manager.
return captured_callable

# Unnamed data source, which must be a pd.DataFrame and hence static data. This means px.scatter(pd.DataFrame())
# and is only possible from the Python API. Extract dataframe from the captured function and put it into the
# data manager.
# Unlike with model_manager, it doesn't matter if the random seed is different across workers here. So long as
# we always fetch static data from the data manager by going through the appropriate Figure component, the right
# data source name will be fetched. It also doesn't matter if multiple Figures with the same underlying data
# each have their own entry in the data manager, since the underlying pd.DataFrame will still be the same and
# not copied into each one, so no memory is wasted.
# Replace the "data_frame" argument in the captured callable with the data_source_name for consistency with
# dynamic data and to save memory. This way we always access data via the same interface regardless of whether it's
# static or dynamic.
data_source_name = str(uuid.uuid4())
data_manager[data_source_name] = data_frame
captured_callable["data_frame"] = data_source_name

return captured_callable
2 changes: 1 addition & 1 deletion vizro-core/src/vizro/models/_components/ag_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class AgGrid(VizroBaseModel):

# Convenience wrapper/syntactic sugar.
def __call__(self, **kwargs):
kwargs.setdefault("data_frame", data_manager._get_component_data(self.id))
kwargs.setdefault("data_frame", data_manager[self["data_frame"]].load())
figure = self.figure(**kwargs)
figure.id = self._input_component_id
return figure
Expand Down
5 changes: 4 additions & 1 deletion vizro-core/src/vizro/models/_components/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ class Graph(VizroBaseModel):

# Convenience wrapper/syntactic sugar.
def __call__(self, **kwargs):
kwargs.setdefault("data_frame", data_manager._get_component_data(str(self.id)))
# This default value is not actually used anywhere at the moment since __call__ is always used with data_frame
# specified. It's here to match Table and AgGrid and because we might want to use __call__ more in future.
# If the functionality of _process_callable_data_frame moves to CapturedCallable then this would move there too.
kwargs.setdefault("data_frame", data_manager[self["data_frame"]].load())
fig = self.figure(**kwargs)

# Remove top margin if title is provided
Expand Down
2 changes: 1 addition & 1 deletion vizro-core/src/vizro/models/_components/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class Table(VizroBaseModel):

# Convenience wrapper/syntactic sugar.
def __call__(self, **kwargs):
kwargs.setdefault("data_frame", data_manager._get_component_data(self.id))
kwargs.setdefault("data_frame", data_manager[self["data_frame"]].load())
figure = self.figure(**kwargs)
figure.id = self._input_component_id
return figure
Expand Down
16 changes: 12 additions & 4 deletions vizro-core/src/vizro/models/_controls/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,20 @@ def _set_targets(self):
for component_id in model_manager._get_page_model_ids_with_figure(
page_id=model_manager._get_model_page_id(model_id=ModelID(str(self.id)))
):
data_frame = data_manager._get_component_data(component_id)
# TODO: consider making a helper method in data_manager or elsewhere to avoid duplicating this operation so much
# across Filter, and/or consider storing the result to avoid repeating it.
# Need to think about this in connection with how to update filters on the fly and duplicated calls
# issue outlined in https://github.com/mckinsey/vizro/pull/398#discussion_r1559120849.
data_source_name = model_manager[component_id]["data_frame"]
data_frame = data_manager[data_source_name].load()
if self.column in data_frame.columns:
self.targets.append(component_id)
if not self.targets:
raise ValueError(f"Selected column {self.column} not found in any dataframe on this page.")

def _set_column_type(self):
data_frame = data_manager._get_component_data(self.targets[0])
data_source_name = model_manager[self.targets[0]]["data_frame"]
data_frame = data_manager[data_source_name].load()

if is_numeric_dtype(data_frame[self.column]):
self._column_type = "numerical"
Expand All @@ -146,7 +152,8 @@ def _set_numerical_and_temporal_selectors_values(self):
min_values = []
max_values = []
for target_id in self.targets:
data_frame = data_manager._get_component_data(target_id)
data_source_name = model_manager[target_id]["data_frame"]
data_frame = data_manager[data_source_name].load()
min_values.append(data_frame[self.column].min())
max_values.append(data_frame[self.column].max())

Expand All @@ -173,7 +180,8 @@ def _set_categorical_selectors_options(self):
if isinstance(self.selector, SELECTORS["categorical"]) and not self.selector.options:
options = set()
for target_id in self.targets:
data_frame = data_manager._get_component_data(target_id)
data_source_name = model_manager[target_id]["data_frame"]
data_frame = data_manager[data_source_name].load()
options |= set(data_frame[self.column])

self.selector.options = sorted(options)
Expand Down
6 changes: 3 additions & 3 deletions vizro-core/src/vizro/models/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ def __getitem__(self, arg_name: str):
"""Gets the value of a bound argument."""
return self.__bound_arguments[arg_name]

def __delitem__(self, arg_name: str):
"""Deletes a bound argument."""
del self.__bound_arguments[arg_name]
def __setitem__(self, arg_name: str, value):
"""Sets the value of a bound argument."""
self.__bound_arguments[arg_name] = value

@property
def _arguments(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,11 @@ class TestProcessAgGridDataFrame:
def test_process_figure_data_frame_str_df(self, dash_ag_grid_with_str_dataframe, gapminder):
data_manager["gapminder"] = gapminder
ag_grid = vm.AgGrid(id="ag_grid", figure=dash_ag_grid_with_str_dataframe)
assert data_manager._get_component_data("ag_grid").equals(gapminder)
with pytest.raises(KeyError, match="'data_frame'"):
ag_grid["data_frame"]
assert data_manager[ag_grid["data_frame"]].load().equals(gapminder)

def test_process_figure_data_frame_df(self, standard_ag_grid, gapminder):
ag_grid = vm.AgGrid(id="ag_grid", figure=standard_ag_grid)
assert data_manager._get_component_data("ag_grid").equals(gapminder)
with pytest.raises(KeyError, match="'data_frame'"):
ag_grid["data_frame"]
assert data_manager[ag_grid["data_frame"]].load().equals(gapminder)


class TestPreBuildAgGrid:
Expand Down
8 changes: 2 additions & 6 deletions vizro-core/tests/unit/vizro/models/_components/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,11 @@ class TestProcessGraphDataFrame:
def test_process_figure_data_frame_str_df(self, standard_px_chart_with_str_dataframe, gapminder):
data_manager["gapminder"] = gapminder
graph = vm.Graph(id="graph", figure=standard_px_chart_with_str_dataframe)
assert data_manager._get_component_data("graph").equals(gapminder)
with pytest.raises(KeyError, match="'data_frame'"):
graph["data_frame"]
assert data_manager[graph["data_frame"]].load().equals(gapminder)

def test_process_figure_data_frame_df(self, standard_px_chart, gapminder):
graph = vm.Graph(id="graph", figure=standard_px_chart)
assert data_manager._get_component_data("graph").equals(gapminder)
with pytest.raises(KeyError, match="'data_frame'"):
graph["data_frame"]
assert data_manager[graph["data_frame"]].load().equals(gapminder)


class TestBuild:
Expand Down
8 changes: 2 additions & 6 deletions vizro-core/tests/unit/vizro/models/_components/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,11 @@ class TestProcessTableDataFrame:
def test_process_figure_data_frame_str_df(self, dash_table_with_str_dataframe, gapminder):
data_manager["gapminder"] = gapminder
table = vm.Table(id="table", figure=dash_table_with_str_dataframe)
assert data_manager._get_component_data("table").equals(gapminder)
with pytest.raises(KeyError, match="'data_frame'"):
table["data_frame"]
assert data_manager[table["data_frame"]].load().equals(gapminder)

def test_process_figure_data_frame_df(self, standard_dash_table, gapminder):
table = vm.Table(id="table", figure=standard_dash_table)
assert data_manager._get_component_data("table").equals(gapminder)
with pytest.raises(KeyError, match="'data_frame'"):
table["data_frame"]
assert data_manager[table["data_frame"]].load().equals(gapminder)


class TestPreBuildTable:
Expand Down
8 changes: 3 additions & 5 deletions vizro-core/tests/unit/vizro/models/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,9 @@ def test_getitem_unknown_args(self, captured_callable):
with pytest.raises(KeyError):
captured_callable["c"]

def test_delitem(self, captured_callable):
del captured_callable["a"]

with pytest.raises(KeyError):
captured_callable["a"]
def test_setitem(self, captured_callable):
captured_callable["a"] = 2
assert captured_callable["a"] == 2


@pytest.mark.parametrize(
Expand Down

0 comments on commit 5d5d8a8

Please sign in to comment.