Skip to content

Commit

Permalink
Merge pull request #24 from CyberAgentAILab/feat/plot-color
Browse files Browse the repository at this point in the history
Add color parameter to plot util
  • Loading branch information
TomeHirata authored Feb 13, 2025
2 parents 451ec54 + 7d9a005 commit 04a97db
Show file tree
Hide file tree
Showing 11 changed files with 136 additions and 132 deletions.
Binary file modified docs/source/_static/dte_empirical.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/source/_static/dte_moment.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/source/_static/dte_simple.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/source/_static/dte_uniform.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/_static/pte_empirical.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed docs/source/_static/pte_simple.png
Binary file not shown.
Binary file modified docs/source/_static/qte.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/source/get_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ To compute PTE, we can use "predict_pte" method.
pte, lower_bound, upper_bound = estimator.predict_pte(target_treatment_arm=1, control_treatment_arm=0, width=1, locations=locations, variance_type="simple")
plot(locations, pte, lower_bound, upper_bound, chart_type="bar", title="PTE of adjusted estimator with simple confidence band")
.. image:: _static/pte_simple.png
.. image:: _static/pte_empirical.png
:alt: PTE of adjusted estimator with simple confidence band
:height: 300px
:width: 450px
Expand Down
10 changes: 7 additions & 3 deletions dte_adj/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ def plot(
means: np.ndarray,
lower_bounds: np.ndarray,
upper_bounds: np.ndarray,
chart_type="line",
chart_type: str = "line",
color: str = "green",
ax: Optional[axis.Axis] = None,
title: Optional[str] = None,
xlabel: Optional[str] = None,
Expand All @@ -23,6 +24,7 @@ def plot(
lower_bounds (np.Array): Lower bound for the distributional parameters.
upper_bounds (np.Array): Upper bound for the distributional parameters.
chart_type (str): Chart type of the plotting. Available values are line or bar.
color (str): The color of lines or bars.
ax (matplotlib.axes.Axes, optional): Target axes instance. If None, a new figure and axes will be created.
title (str, optional): Axes title.
xlabel (str, optional): X-axis title label.
Expand All @@ -35,12 +37,12 @@ def plot(
fig, ax = plt.subplots()

if chart_type == "line":
ax.plot(X, means, label="Values", color="blue")
ax.plot(X, means, label="Values", color=color)
ax.fill_between(
X,
lower_bounds,
upper_bounds,
color="gray",
color=color,
alpha=0.3,
label="Confidence Interval",
)
Expand All @@ -53,6 +55,8 @@ def plot(
np.maximum(upper_bounds - means, 0),
],
capsize=5,
color=color,
width=(X.max() - X.min()) / len(X),
)
else:
raise ValueError(f"Chart type {chart_type} is not supported")
Expand Down
254 changes: 127 additions & 127 deletions example/example.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_plot(self, mock_plt):
self.assertTrue(np.array_equal(x_fill, x_values_arg))
self.assertTrue(np.array_equal(lower_fill, lower_bands))
self.assertTrue(np.array_equal(upper_fill, upper_bands))
self.assertEqual(fill_between_kwargs["color"], "gray")
self.assertEqual(fill_between_kwargs["color"], "green")
self.assertAlmostEqual(fill_between_kwargs["alpha"], 0.3)
self.assertEqual(fill_between_kwargs["label"], "Confidence Interval")
mock_ax.set_title.assert_called_once_with("Test Title")
Expand Down

0 comments on commit 04a97db

Please sign in to comment.