From 68ec27d2260107dc36b7dd2487ef6859fecd8b21 Mon Sep 17 00:00:00 2001 From: Ashwin Paranjape Date: Sun, 12 Jan 2025 16:44:22 -0800 Subject: [PATCH] Modify binary and continuous single IV case to accept nans --- .../instrumental_variable_estimator.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/dowhy/causal_estimators/instrumental_variable_estimator.py b/dowhy/causal_estimators/instrumental_variable_estimator.py index 6570f0a540..83902911d2 100755 --- a/dowhy/causal_estimators/instrumental_variable_estimator.py +++ b/dowhy/causal_estimators/instrumental_variable_estimator.py @@ -139,17 +139,18 @@ def estimate_effect( instrument_is_binary = num_unique_values <= 2 if instrument_is_binary: # Obtain estimate by Wald Estimator - y1_z = np.mean(data[self._target_estimand.outcome_variable[0]][instrument == 1]) - y0_z = np.mean(data[self._target_estimand.outcome_variable[0]][instrument == 0]) - x1_z = np.mean(data[self._target_estimand.treatment_variable[0]][instrument == 1]) - x0_z = np.mean(data[self._target_estimand.treatment_variable[0]][instrument == 0]) + y1_z = np.nanmean(data[self._target_estimand.outcome_variable[0]][instrument == 1]) + y0_z = np.nanmean(data[self._target_estimand.outcome_variable[0]][instrument == 0]) + x1_z = np.nanmean(data[self._target_estimand.treatment_variable[0]][instrument == 1]) + x0_z = np.nanmean(data[self._target_estimand.treatment_variable[0]][instrument == 0]) num = y1_z - y0_z deno = x1_z - x0_z iv_est = num / deno else: # Obtain estimate by 2SLS estimator: Cov(y,z) / Cov(x,z) - num_yz = np.cov(data[self._target_estimand.outcome_variable[0]], instrument)[0, 1] - deno_xz = np.cov(data[self._target_estimand.treatment_variable[0]], instrument)[0, 1] + masked_data = np.ma.masked_invalid(data[[self._target_estimand.outcome_variable[0], self._target_estimand.treatment_variable[0], self.estimating_instrument_names[0]]]) + num_yz = np.ma.cov(masked_data[:, 0], masked_data[: ,2])[0, 1] + deno_xz = np.ma.cov(masked_data[:, 1], masked_data[:, 2])[0, 1] iv_est = num_yz / deno_xz else: # More than 1 instrument. Use 2sls.