Skip to content

Commit

Permalink
fixed incorrect summation variables
Browse files Browse the repository at this point in the history
  • Loading branch information
santikka committed Nov 27, 2023
1 parent 56d8367 commit 83e85ba
Show file tree
Hide file tree
Showing 13 changed files with 81 additions and 72 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: cfid
Type: Package
Title: Identification of Counterfactual Queries in Causal Models
Version: 0.1.6
Version: 0.1.7
Authors@R: c(
person(given = "Santtu",
family = "Tikka",
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# cfid 0.1.7

* Fixed some formulas having incorrect variables indicated as summation variables.

# cfid 0.1.6

* Summation variables are now properly distinguished from query variables in the output formulas of `identifiable()`.
Expand Down
26 changes: 19 additions & 7 deletions R/algorithms.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,36 @@ id_star <- function(g, gamma) {
if (n_comp > 1L) {
# Line 6
c_factors <- vector(mode = "list", length = n_comp)
free_vars <- vector(mode = "list", length = n_comp)
form_terms <- vector(mode = "list", length = n_comp)
nonid_factors <- rep(TRUE, n_comp)
prob_zero <- FALSE
for (i in seq_len(n_comp)) {
s_var <- vars(comp[[i]])
s_sub <- setdiff(v_g, s_var)
for (j in seq_along(comp[[i]])) {
n_terms <- length(comp[[i]])
sub_new <- vector(mode = "list", length = n_terms)
for (j in seq_len(n_terms)) {
gamma_val <- val(comp[[i]][[j]], gamma_prime)
comp[[i]][[j]]$obs <- ifelse_(is.null(gamma_val), 0L, gamma_val)
sub_var <- names(comp[[i]][[j]]$sub)
s_sub_j <- setdiff(s_sub, sub_var)
s_len <- length(s_sub_j)
if (s_len > 0) {
sub_new <- set_names(integer(s_len), s_sub_j)
sub_new[[j]] <- set_names(integer(s_len), s_sub_j)
obs_ix <- which(gamma_obs_var %in% s_sub_j)
if (length(obs_ix) > 0) {
s_val <- unlist(evs(gamma_obs)[obs_ix])
sub_new[names(s_val)] <- s_val
sub_new[[j]][names(s_val)] <- s_val
}
comp[[i]][[j]]$sub <- c(comp[[i]][[j]]$sub, sub_new)
comp[[i]][[j]]$sub <- c(comp[[i]][[j]]$sub, sub_new[[j]])
}
}
sumset <- setdiff(v_g, gamma_var)
sub_new_reduce <- names(
Reduce(function(x, y) intersect(names(x), names(y)), sub_new)
)
free_vars[[i]] <- intersect(sumset, union(sub_new_reduce, s_var))
s_conj <- try(
do.call(counterfactual_conjunction, comp[[i]]), silent = TRUE
)
Expand All @@ -72,13 +82,15 @@ id_star <- function(g, gamma) {
c_factors[[i]]$formula$val == 0L) {
return(list(id = TRUE, formula = probability(val = 0L)))
}
if (c_factors[[i]]$id) {
form_terms[[i]] <- c_factors[[i]]$formula
attr(form_terms[[i]], "free_vars") <- free_vars[[i]]
nonid_factors[i] <- FALSE
}
}
nonid_factors <- !vapply(c_factors, "[[", logical(1L), "id")
if (any(nonid_factors)) {
return(list(id = FALSE, formula = NULL))
}
sumset <- setdiff(v_g, gamma_var)
form_terms <- lapply(c_factors, "[[", "formula")
if (length(sumset) > 0L) {
form_out <- functional(
sumset = lapply(sumset, function(x) cf(var = x, obs = 0L)),
Expand Down
9 changes: 7 additions & 2 deletions R/cf_variable.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,13 @@ is.counterfactual_variable <- function(x) {
#' @rdname counterfactuals
#' @param x A `counterfactual_variable` or a `counterfactual_conjunction`
#' object.
#' @param use_primes A `logical` value indicating whether primes should be
#' used to differentiate between value assignments
#' @param use_primes A `logical` value. If `TRUE` (the default), any value
#' assignment of a counterfactual variable with `obs` will be formatted with
#' as many primes in the superscript as the value of `obs`, e.g.,
#' `obs = 0` outputs `"y"`, `obs = 1` outputs `"y'"`,
#' `obs = 2` outputs `"y''"` and so forth. The alternative when `FALSE` is
#' to simply denote the `obs` value via superscript directly as
#' `"y^{(obs)}"`, where obs is evaluated.
#' @export
format.counterfactual_variable <- function(x, use_primes = TRUE, ...) {
super_var <- character(0L)
Expand Down
29 changes: 11 additions & 18 deletions R/functional.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,23 +79,12 @@ is.functional <- function(x) {

#' @rdname functional
#' @param x A `functional` object.
#' @param use_primes A `logical` value. If `TRUE` (the default), any value
#' assignment of a counterfactual variable with `obs` will be formatted with
#' as many primes in the superscript as the value of `obs`, e.g.,
#' `obs = 0` outputs `"y"`, `obs = 1` outputs `"y'"`,
#' `obs = 2` outputs `"y''"` and so forth. The alternative when `FALSE` is
#' to simply denote the `obs` value via superscript directly as
#' `"y^{(obs)}"`, where obs is evaluated.
#' @param use_do A `logical` value. If `TRUE`, the explicit do-operation is
#' used to denote interventional probabilities (e.g., \eqn{P(y|do(x))}).
#' If `FALSE` (the default), the subscript notation is used instead
#' (e.g., \eqn{P_x(y)}).
#' @param ... Additional arguments passed to `format`.
#' @return A `character` representation of the `functional` object
#' in LaTeX syntax.
#'
#' @export
format.functional <- function(x, use_primes = TRUE, use_do = FALSE, ...) {
format.functional <- function(x, ...) {
terms <- ""
sumset <- ""
fraction <- ""
Expand All @@ -104,12 +93,12 @@ format.functional <- function(x, use_primes = TRUE, use_do = FALSE, ...) {
if (length(x$sumset) > 0) {
sumset <- paste0(
"\\sum_{",
comma_sep(vapply(x$sumset, format, character(1L), use_primes)),
comma_sep(vapply(x$sumset, format, character(1L), ...)),
"} "
)
}
if (!is.null(x$terms)) {
terms <- vapply(x$terms, format, character(1L), use_primes, use_do)
terms <- vapply(x$terms, format, character(1L), ...)
sums <- vapply(
x$terms,
function(y) { is.functional(y) && length(y$sumset) > 0 },
Expand All @@ -121,13 +110,13 @@ format.functional <- function(x, use_primes = TRUE, use_do = FALSE, ...) {
terms <- collapse(terms)
} else if (!is.null(x$numerator)) {
if (length(x$denominator$val) > 0L && x$denominator$val == 1L) {
fraction <- format(x$numerator, use_primes, use_do)
fraction <- format(x$numerator, ...)
} else {
fraction <- paste0(
"\\frac{",
format(x$numerator, use_primes, use_do),
format(x$numerator, ...),
"}{",
format(x$denominator, use_primes, use_do),
format(x$denominator, ...),
"}"
)
}
Expand Down Expand Up @@ -171,8 +160,12 @@ assign_values <- function(x, bound, v, termwise = FALSE) {
if (termwise) {
v_term <- unlist(c(evs(x$var), evs(x$cond), evs(x$do)))
v[names(v_term)] <- v_term
bind <- bound > 0 & v_names %in% attr(x, "free_vars")
attr(x, "free_vars") <- NULL
} else {
bind <- bound > 0
}
v[bound > 0] <- -bound[bound > 0]
v[bind] <- -bound[bind]
var <- vars(x$var)
cond <- vars(x$cond)
do <- vars(x$do)
Expand Down
9 changes: 5 additions & 4 deletions R/identifiable.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@
#' * `formula`\cr An object of class `functional` giving the identifying
#' functional of the query in LaTeX syntax via `format` or `print`,
#' if identifiable. This expression is given in terms of the
#' available `data`. For tautological statements, the resulting
#' available `data`. Variables bound by summation are distinguished by a
#' superscript asterisk. For tautological statements, the resulting
#' probability is 1, and for inconsistent statements, the resulting
#' probability is 0. For formatting options, see
#' [cfid::format.functional()] and [cfid::format.probability()].
Expand Down Expand Up @@ -139,9 +140,6 @@ identifiable <- function(g, gamma, delta = NULL,
functional(terms = list(out$formula)),
out$formula
)
if (out$id && data != "interventions") {
out <- identify_terms(out$formula, data, g)
}
if (out$id) {
n_obs <- sum(!attr(g, "latent"))
v <- set_names(integer(n_obs), attr(g, "labels")[!attr(g, "latent")])
Expand All @@ -152,6 +150,9 @@ identifiable <- function(g, gamma, delta = NULL,
bound[query_vars] <- bound[query_vars] + 1L
out$formula <- assign_values(out$formula, bound, v, termwise = TRUE)
}
if (out$id && data != "interventions") {
out <- identify_terms(out$formula, data, g)
}
out$undefined <- ifelse_(is.null(out$undefined), FALSE, out$undefined)
out$counterfactual <- TRUE
out$gamma <- gamma
Expand Down
15 changes: 4 additions & 11 deletions R/probability.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,6 @@ is.probability <- function(x) {
#' @method format probability
#' @rdname probability
#' @param x A `probability` object.
#' @param use_primes A `logical` value. If `TRUE` (the default), any value
#' assignment of a counterfactual variable with `obs` will be formatted with
#' as many primes in the superscript as the value of `obs`, e.g.,
#' `obs = 0` outputs `"y"`, `obs = 1` outputs `"y'"`,
#' `obs = 2` outputs `"y''"` and so forth. The alternative when `FALSE` is
#' to simply denote the `obs` value via superscript directly as
#' `"y^{(obs)}"`, where obs is evaluated.
#' @param use_do A `logical` value. If `TRUE`, the explicit do-operation is
#' used to denote interventional probabilities (e.g., \eqn{P(y|do(x))}).
#' If `FALSE` (the default), the subscript notation is used instead
Expand Down Expand Up @@ -98,7 +91,7 @@ is.probability <- function(x) {
#' format(f, use_primes = FALSE, use_do = TRUE)
#'
#' @export
format.probability <- function(x, use_primes = TRUE, use_do = FALSE, ...) {
format.probability <- function(x, use_do = FALSE, ...) {
if (length(x$val) > 0L) {
return(as.character(x$val))
}
Expand All @@ -109,7 +102,7 @@ format.probability <- function(x, use_primes = TRUE, use_do = FALSE, ...) {
any_do <- length(x$do) > 0L
any_cond <- length(x$cond) > 0L
if (any_do) {
form_do <- comma_sep(vapply(x$do, format, character(1L), use_primes))
form_do <- comma_sep(vapply(x$do, format, character(1L), ...))
if (!use_do) {
sub <- paste0("_{", form_do, "}")
} else {
Expand All @@ -118,13 +111,13 @@ format.probability <- function(x, use_primes = TRUE, use_do = FALSE, ...) {
}
if (any_cond) {
cond <- paste0(
comma_sep(vapply(x$cond, format, character(1L), use_primes))
comma_sep(vapply(x$cond, format, character(1L), ...))
)
}
if ((any_do && use_do) || any_cond) {
rhs <- paste0("|", do, cond)
}
var <- paste0(comma_sep(vapply(x$var, format, character(1L), use_primes)))
var <- paste0(comma_sep(vapply(x$var, format, character(1L), ...)))
paste0("P", sub, "(", var, rhs, ")")
}

Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ coverage](https://codecov.io/gh/santikka/cfid/branch/main/graph/badge.svg)](http
version](http://www.r-pkg.org/badges/version/cfid)](https://CRAN.R-project.org/package=cfid)
<!-- badges::end -->

#> Warning: package 'cfid' was built under R version 4.3.2

## Overview

This package facilitates the identification of counterfactual queries in
Expand Down
10 changes: 1 addition & 9 deletions man/Probability.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 7 additions & 2 deletions man/counterfactuals.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 1 addition & 14 deletions man/functional.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/identifiable.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 17 additions & 3 deletions tests/testthat/test-probability.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ id5 <- identifiable(g2, v5, v7)
id6 <- identifiable(g2, v5, v7, data = "obs")
id7 <- identifiable(g1, conj(v1, v2, v3))

g3 <- dag("X -> Z -> Y")
v8 <- cf("Z", 0, c("X" = 0))
v9 <- cf("Y", 1)
id8 <- identifiable(g3, v8, v9)
id9 <- identifiable(g3, v8, v9, data = "obs")

# Format ------------------------------------------------------------------

test_that("probability format works", {
Expand Down Expand Up @@ -67,15 +73,23 @@ test_that("probability format works", {
)
expect_identical(
format(id5$formula),
"\\frac{\\sum_{x^*} P_{x^*}(y)P(x^*)P_{x^*,y}(z)}{\\sum_{x^*,y^*} P(x^*)P_{x^*,y^*}(z)P_{x^*}(y^*)}"
"\\frac{\\sum_{x^*} P_{x}(y)P(x^*)P_{x^*,y}(z)}{\\sum_{x^*,y^*} P(x^*)P_{x^*,y^*}(z)P_{x^*}(y^*)}"
)
expect_identical(
format(id6$formula),
"\\frac{\\sum_{x^*} P(y|x^*)P(x^*)P(z|x^*,y)}{\\sum_{x^*,y^*} P(x^*)P(z|x^*,y^*)P(y^*|x^*)}"
"\\frac{\\sum_{x^*} P(y|x)P(x^*)P(z|x^*,y)}{\\sum_{x^*,y^*} P(x^*)P(z|x^*,y^*)P(y^*|x^*)}"
)
expect_identical(
format(id7$formula),
"\\sum_{w,d^*} P_{x}(w)P_{w,z}(y,x')P_{d^*}(z)P(d^*)"
"\\sum_{w,d^*} P_{x}(w)P_{w,z}(y,x')P_{d}(z)P(d^*)"
)
expect_identical(
format(id8$formula),
"\\frac{\\sum_{x^*} P_{x}(z)P(x^*)P_{z}(y')}{\\sum_{x^*,z^*} P(x^*)P_{x^*}(z^*)P_{z^*}(y')}"
)
expect_identical(
format(id9$formula),
"\\frac{\\sum_{x^*} P(z|x)P(x^*)P(y'|x,z)}{\\sum_{x^*,z^*} P(x^*)P(z^*|x^*)P(y'|x,z^*)}"
)
})

Expand Down

0 comments on commit 83e85ba

Please sign in to comment.