Skip to content

Update typing #295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ dependencies = [
"scikit-learn",
"statsmodels",
"scipy",
"typing-extensions",
]


# see https://peps.python.org/pep-0735/ and https://docs.astral.sh/uv/concepts/dependencies/#dependency-groups
[dependency-groups]
tests = [
Expand Down
8 changes: 6 additions & 2 deletions src/y0/algorithm/estimation/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,20 @@ def get_primal_ipw_point_estimate(
treatment: Variable,
treatment_value: int | float,
outcome: Variable,
) -> float:
) -> float | Any:
"""Estimate the counterfactual mean E[Y(t)] with the Primal IPW estimator on p-fixable graphs."""
# TODO: This function currently returns type Any to conform with mypy requirements.
# That's not the best reason to make a return type less restrictive.
# Consider warning the user if the type of the return value cannot be
# coerced to a float. See Issue #294. -callahanr
beta_primal = get_beta_primal(
data=data,
graph=graph,
treatment=treatment,
treatment_value=treatment_value,
outcome=outcome,
)
return cast(float, np.mean(beta_primal).item())
return np.mean(beta_primal).item()


def get_beta_primal(
Expand Down
4 changes: 3 additions & 1 deletion src/y0/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1564,7 +1564,9 @@ def _iter_variables(self) -> Iterable[Variable]:
yield from self.domain


Q = QFactor
# We need to declare a type for this alias to avoid false MyPy errors.
# See https://github.com/python/mypy/issues/7568
Q: type[QFactor] = QFactor

AA = Variable("AA")
A, B, C, D, E, F, G, M, R, S, T, U, W, X, Y, Z = map(Variable, "ABCDEFGMRSTUWXYZ")
Expand Down
2 changes: 1 addition & 1 deletion tests/data/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from tqdm import tqdm, trange

from y0.algorithm.estimation import estimate_ace
from y0.examples import examples
from y0.examples import examples # type: ignore[attr-defined]

warnings.simplefilter(action="ignore", category=FutureWarning)

Expand Down
12 changes: 6 additions & 6 deletions tests/test_algorithm/test_cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
value_of_self_intervention,
)
from y0.dsl import A, B, D, Event, W, X, Y, Z
from y0.examples import (
from y0.examples import ( # type: ignore[attr-defined]
figure_9a,
figure_9b,
figure_9c,
Expand All @@ -47,11 +47,11 @@ class TestCounterfactualGraph(cases.GraphTestCase):
def test_world(self):
"""Test that a world contains an intervention."""
with self.assertRaises(TypeError):
input_world1: World = World([-x])
input_world1: World = World([-x]) # type: ignore[annotation-unchecked]
3 in input_world1 # noqa

with self.assertRaises(TypeError):
input_world1: World = World([3])
input_world1: World = World([3]) # type: ignore[annotation-unchecked]
3 in input_world1 # noqa

input_world2 = World([-x])
Expand Down Expand Up @@ -94,7 +94,7 @@ def test_has_same_function(self):

def test_nodes_attain_same_value(self):
"""Test that two variables attain the same value."""
event: Event = {D: -d}
event: Event = {D: -d} # type: ignore[annotation-unchecked]
self.assertTrue(nodes_attain_same_value(figure_11a.graph, event, D, D @ -d))
self.assertTrue(nodes_attain_same_value(figure_11a.graph, event, D @ -d, D))
self.assertTrue(
Expand Down Expand Up @@ -143,7 +143,7 @@ def test_has_same_confounders(self):
def test_parents_attain_same_values(self):
"""Test that the parents of two nodes attain the same value."""
graph = figure_9b.graph
event: Event = {Y @ -x: -y, D: -d, Z @ -d: -z, X: +x}
event: Event = {Y @ -x: -y, D: -d, Z @ -d: -z, X: +x} # type: ignore[annotation-unchecked]
self.assertTrue(parents_attain_same_values(figure_11a.graph, event, Z, Z @ -d))
self.assertTrue(parents_attain_same_values(figure_11a.graph, event, Z, Z @ -x))
self.assertTrue(parents_attain_same_values(figure_11a.graph, event, Z @ -d, Z @ -x))
Expand Down Expand Up @@ -544,7 +544,7 @@ def test_get_directed_edges(self):

def test_is_pw_equivalent(self):
"""Test that two nodes in a parallel world graph are the same (lemma 24)."""
event: Event = {Y @ -x: -y, D: -d, Z @ -d: -z, X: +x}
event: Event = {Y @ -x: -y, D: -d, Z @ -d: -z, X: +x} # type: ignore[annotation-unchecked]
self.assertTrue(is_pw_equivalent(figure_9b.graph, event, D @ -X, D))
self.assertTrue(is_pw_equivalent(figure_9b.graph, event, X @ -D, X))
self.assertTrue(is_pw_equivalent(figure_11a.graph, event, Z, Z @ -X))
Expand Down
2 changes: 1 addition & 1 deletion tests/test_algorithm/test_conditional_independencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
get_conditional_independencies,
)
from y0.dsl import AA, B, C, D, E, F, G, Variable, X, Y
from y0.examples import (
from y0.examples import ( # type: ignore[attr-defined]
Example,
d_separation_example,
examples,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_algorithm/test_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from y0.algorithm.estimation import ananke_average_causal_effect, df_covers_graph
from y0.algorithm.estimation.estimators import get_primal_ipw_ace, get_state_space_map
from y0.dsl import Variable
from y0.examples import examples, frontdoor, napkin, napkin_example
from y0.examples import examples, frontdoor, napkin, napkin_example # type: ignore[attr-defined]
from y0.graph import ANANKE_REQUIRED, is_p_fixable

TOLERANCE = 0.1
Expand Down
2 changes: 1 addition & 1 deletion tests/test_algorithm/test_falsification.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from y0.algorithm.conditional_independencies import get_conditional_independencies
from y0.algorithm.falsification import get_falsifications, get_graph_falsifications
from y0.examples import asia_example, frontdoor_example
from y0.examples import asia_example, frontdoor_example # type: ignore[attr-defined]
from y0.struct import get_conditional_independence_tests


Expand Down
2 changes: 1 addition & 1 deletion tests/test_algorithm/test_id_alg.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
Z,
get_outcomes_and_treatments,
)
from y0.examples import (
from y0.examples import ( # type: ignore[attr-defined]
figure_6a,
line_1_example,
line_2_example,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_algorithm/test_id_star.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
idc_star,
)
from y0.dsl import D, One, P, Sum, W, X, Y, Z, Zero
from y0.examples import (
from y0.examples import ( # type: ignore[attr-defined]
figure_9a,
figure_9c,
figure_9d,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_algorithm/test_original_id_star.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
idc_star,
)
from y0.dsl import D, One, P, Sum, W, X, Y, Z, Zero
from y0.examples import (
from y0.examples import ( # type: ignore[attr-defined]
figure_9a,
figure_9c,
figure_9d,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_algorithm/test_simplify_latent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from y0.algorithm.taheri_design import taheri_design_dag
from y0.dsl import U1, U2, U3, Y1, Y2, Y3, U, Variable, W
from y0.examples import igf_example
from y0.examples import igf_example # type: ignore[attr-defined]
from y0.graph import set_latent

X1, X2, X3 = map(Variable, ["X1", "X2", "X3"])
Expand Down
12 changes: 7 additions & 5 deletions tests/test_algorithm/test_tian_pearl_identify.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,13 +675,15 @@ class TestComputeCFactorMarginalizingOverTopologicalSuccessors(cases.GraphTestCa
expected_result_2 = Fraction(expected_result_2_num, expected_result_2_den)

# Same thing, but with population probabilities
# @cthoyt mypy throws an error for each appearance of PP[] below:
# error: Value of type "type[PopulationProbabilityBuilderType]" is not indexable [index]
result_piece_pp = Product.safe(
[
PP[Pi1](W1),
PP[Pi1](W3 | W1),
PP[Pi1](W2 | (W3, W1)),
PP[Pi1](X | (W1, W3, W2, W4)),
PP[Pi1](Y | (W1, W3, W2, W4, X)),
PP[Pi1](W1), # type: ignore[index]
PP[Pi1](W3 | W1), # type: ignore[index]
PP[Pi1](W2 | (W3, W1)), # type: ignore[index]
PP[Pi1](X | (W1, W3, W2, W4)), # type: ignore[index]
PP[Pi1](Y | (W1, W3, W2, W4, X)), # type: ignore[index]
]
)
expected_result_1_part_1_pp = Fraction(
Expand Down
4 changes: 3 additions & 1 deletion tests/test_algorithm/test_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@
Z,
Zero,
)
from y0.examples import tikka_trso_figure_8_graph as tikka_trso_figure_8
from y0.examples import ( # type: ignore[attr-defined]
tikka_trso_figure_8_graph as tikka_trso_figure_8,
)
from y0.graph import NxMixedGraph
from y0.mutate import canonicalize, fraction_expand

Expand Down
8 changes: 6 additions & 2 deletions tests/test_causaleffect.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pyparsing

from y0.dsl import P, Q, Sum, Variable
from y0.examples import examples, verma_1
from y0.examples import examples, verma_1 # type: ignore[attr-defined]

try:
from rpy2 import robjects
Expand Down Expand Up @@ -38,7 +38,11 @@ def setUpClass(cls) -> None:
except Exception as e:
raise unittest.SkipTest(f"R packages not properly installed.\n\n{e}") from None

cls.assertIsNotNone(robjects, msg="make sure this was imported correctly.")
cls.assertIsNotNone(
unittest.TestCase(), robjects, msg="make sure this was imported correctly."
)
# callahanr: @cthoyt Please take a look at this, adding the first parameter is a fix to correct a mypy error:
# error: Missing positional argument "obj" in call to "assertIsNotNone" of "TestCase" [call-arg]

def test_verma_constraint(self):
"""Test getting the single Verma constraint from the Figure 1A graph."""
Expand Down
8 changes: 7 additions & 1 deletion tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
from pgmpy.models import DiscreteBayesianNetwork

from y0.dsl import V1, V2, V3, V4, A, B, C, D, M, Variable, X, Y, Z
from y0.examples import SARS_SMALL_GRAPH, Example, examples, napkin, verma_1
from y0.examples import ( # type: ignore[attr-defined]
SARS_SMALL_GRAPH,
Example,
examples,
napkin,
verma_1,
)
from y0.graph import (
ANANKE_AVAILABLE,
ANANKE_REQUIRED,
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ dependency_groups =
typing
commands =
mypy --ignore-missing-imports --strict src/
mypy --ignore-missing-imports src/ tests/test_dsl.py tests/test_algorithm/test_counterfactual_transportability.py
mypy --ignore-missing-imports src/ tests/

[testenv:docs-lint]
skip_install = true
Expand Down
Loading