Skip to content

Commit

Permalink
Provide recalc_graphs for exact edge effects are graph shrinking.
Browse files Browse the repository at this point in the history
  • Loading branch information
karlb committed Oct 12, 2022
1 parent 7520a10 commit d7854b3
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions causing/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,45 @@ def remove_node_keep_edges(graph, node):
graph.add_edge(a, b, effect=new_edge_effect)

graph.remove_node(node)


def recalc_graphs(graphs, model, xdat) -> Iterable[networkx.DiGraph]:
"""Recalculate node and edge effects in graph.
Do this after modifying the graphs (typically with `remove_node_keep_edges`)
to calculate exact effects.
`graphs` must be in the format generated by `annotated_graphs` and in the
same order as individuals within `xdat`.
"""
yhat = model.compute(xdat)
yhat_mean = np.mean(yhat, axis=1)
xdat_mean = np.mean(xdat, axis=1)

for i, approx_graph in enumerate(graphs):
individual_xdat = xdat[:, i : i + 1]
removed_nodes = set(model.graph.nodes) - set(approx_graph.nodes)

# Calc effects on shrunken model
individual_model = model.shrink(removed_nodes)
effects = individual_model.calc_effects(
individual_xdat,
xdat_mean=xdat_mean,
yhat_mean=yhat_mean[[yvar not in removed_nodes for yvar in model.yvars]],
)

# Get graph for shrunken model
[g] = annotated_graphs(
individual_model,
effects,
node_labels={
n: data["label"]
for n, data in approx_graph.nodes(data=True)
if "label" in data
},
)
for xvar in set(model.xvars) & removed_nodes & g.nodes():
g.remove_node(xvar)

# Preserve graph attributes
g.graph = approx_graph.graph
yield g

0 comments on commit d7854b3

Please sign in to comment.