Tutorial for R Users

In this introductory tutorial, we go through the key steps of a CausalEGM workflow with the R package RcausalEGM. Two Markdown vignettes are also provided in the R source code.

First, install CausalEGM. Please refer to the install page to make sure the RcausalEGM R package is installed.

The CRAN package repository provides a detailed reference manual.

Note

The RcausalEGM R package was developed with the reticulate package, which enables building R APIs on top of Python APIs. Installing RcausalEGM will automatically install reticulate as a dependency.

Run py_config() after loading reticulate to get the Python configuration information. Detailed configuration instructions can be found on the reticulate website.

[1]:
library(reticulate)
#py_config()
[2]:
library(RcausalEGM)

Binary treatment

We first give a step-by-step tutorial for implementing CausalEGM under binary treatment settings.

Data generation

Let’s first generate a simulation dataset with binary treatment. Users can also supply their own data, which must include treatment (x), outcome (y), and covariates (v).

[3]:
n <- 500
p <- 20
v <- matrix(rnorm(n * p), n, p)
x <- rbinom(n, 1, 0.3 + 0.2 * (v[, 1] > 0))
y <- pmax(v[, 1], 0) * x + v[, 2] + pmin(v[, 3], 0) + rnorm(n)
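Because the outcome model in this simulation is known, the true treatment effects can be written down for later comparison. The treatment enters the outcome only through pmax(v[, 1], 0) * x, so the true individual effect is pmax(v[, 1], 0), and the true ATE is E[max(Z, 0)] = 1/sqrt(2*pi), about 0.399, for a standard normal Z. The following is a minimal sketch that checks this numerically; the seed and the names true_ite, true_ate_sample, and true_ate_theory are illustrative assumptions and not part of the RcausalEGM API.

```r
set.seed(123)  # arbitrary seed, chosen only for reproducibility of this sketch
n <- 500
p <- 20
v <- matrix(rnorm(n * p), n, p)

# True individual effect implied by the outcome model: y(1) - y(0) = pmax(v1, 0)
true_ite <- pmax(v[, 1], 0)

# Sample average of the true effects versus the population value
true_ate_sample <- mean(true_ite)
true_ate_theory <- 1 / sqrt(2 * pi)

print(c(sample = true_ate_sample, theory = true_ate_theory))
```

The sample value fluctuates around the theoretical one because only 500 units are drawn.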

Let’s take a look at the simulation data.

[4]:
par(cex=0.7, mai=c(0.1,0.1,0.2,0.1))
par(fig=c(0.1,0.5,0.5,1.0))
slices <- c(sum(x==1), sum(x==0))
lbls <- c(paste("T group:",round(sum(x==1)*100/length(x), 2), "%", sep=""), paste("C group:",round(sum(x==0)*100/length(x), 2), "%", sep=""))
pie(slices, labels = lbls, main="Treatment Variables")
par(fig=c(0.5,1,0.5,1), new=TRUE)
hist(y, breaks=12, col="red",xlab="y values")
par(fig=c(0.1,1.0,0.1,0.4), new=TRUE)
boxplot(v, main="Distribution of covariates", xlab="Covariate index", ylab="Values")
_images/tutorial_r_9_0.png

Model training

Start training a CausalEGM model. Run help(causalegm) for detailed usage of the core API, causalegm.

Note that the parameters for treatment (x), outcome (y), and covariates (v) are required. Users can also specify z_dims as an integer vector with four elements denoting the latent dimensions.

Important

  1. causalegm is the main R API for the RcausalEGM R package.

  2. causalegm API includes CausalEGM model initialization and model training.

  3. The parameters for causalegm are consistent with the Python APIs.

[5]:
#help(causalegm)
model <- causalegm(x=x,y=y,v=v,
                   z_dims = c(3,3,6,6),
                   n_iter=200)
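Since x, y, and v are required and z_dims is expected to have four elements, it can help to verify the inputs before starting a training run. The sketch below uses a hypothetical helper, check_causal_inputs, which is not part of RcausalEGM; whether causalegm performs similar checks internally is not stated here, so treat this as an optional precaution.

```r
# Hypothetical helper (not part of RcausalEGM): basic shape checks on the inputs
check_causal_inputs <- function(x, y, v, z_dims) {
  stopifnot(
    is.numeric(x), is.numeric(y), is.matrix(v),
    length(x) == nrow(v),   # one treatment value per row of covariates
    length(y) == nrow(v),   # one outcome value per row of covariates
    length(z_dims) == 4,    # four latent dimensions, as in the tutorial
    all(z_dims > 0)
  )
  invisible(TRUE)
}

# Same simulation design as in the data generation step
set.seed(1)
n <- 500
p <- 20
v <- matrix(rnorm(n * p), n, p)
x <- rbinom(n, 1, 0.3 + 0.2 * (v[, 1] > 0))
y <- pmax(v[, 1], 0) * x + v[, 2] + pmin(v[, 3], 0) + rnorm(n)

check_causal_inputs(x, y, v, z_dims = c(3, 3, 6, 6))
```

If any check fails, stopifnot raises an error naming the violated condition.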

Note

If the above command stops with the warning "Please install the CausalEGM package using the function: install_causalegm()", CausalEGM is not installed in Python. Run install_causalegm() in R, or pip install CausalEGM on the command line, and confirm that py_module_available('CausalEGM') returns TRUE.
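The check described in this note can be scripted so that installation only runs when needed. This is a minimal sketch that assumes reticulate and RcausalEGM are already installed in R and that install_causalegm() can be called without arguments, as in the message above. Depending on the setup, a fresh R session may be needed before a newly installed module is visible.

```r
library(reticulate)
library(RcausalEGM)

# Install the Python package only if reticulate cannot find it.
if (!py_module_available("CausalEGM")) {
  install_causalegm()
}

# Report whether the module is visible in the current session.
print(py_module_available("CausalEGM"))
```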

Treatment Effect Estimation

After training, the individual treatment effect (ITE) estimates are saved as a .txt file in the directory given by the output_dir parameter of causalegm, provided save_res is set to TRUE.

Alternatively, several key estimates, including the average treatment effect and the individual treatment effects, can be obtained directly from the trained model.

[6]:
ATE <- model$ATE
paste("The average treatment effect (ATE) is", round(ATE, 3))
ITE <- model$best_causal_pre
boxplot(ITE, main="ITE distribution", ylab="Values")
'The average treatment effect (ATE) is -0.051'
_images/tutorial_r_15_1.png
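In a simulation where the true effects are known, the ITE estimates can be scored against them. The sketch below defines a hypothetical helper, ite_rmse, and, so that it runs without a trained model, uses a noisy stand-in ite_hat in place of model$best_causal_pre. The numbers it prints describe only the stand-in and say nothing about the accuracy of CausalEGM.

```r
# Hypothetical helper: root mean squared error between estimated and true ITEs
ite_rmse <- function(ite_hat, ite_true) {
  stopifnot(length(ite_hat) == length(ite_true))
  sqrt(mean((as.numeric(ite_hat) - ite_true)^2))
}

set.seed(42)  # arbitrary seed
n <- 500
v1 <- rnorm(n)

# True ITE under the tutorial's binary-treatment simulation
ite_true <- pmax(v1, 0)

# Stand-in for model$best_causal_pre (assumption: truth plus small noise)
ite_hat <- ite_true + rnorm(n, sd = 0.1)

rmse <- ite_rmse(ite_hat, ite_true)
ate_error <- abs(mean(ite_hat) - mean(ite_true))
print(c(rmse = rmse, ate_error = ate_error))
```

With a real model, ite_hat would be replaced by the estimates taken from the trained object.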

Predicting Counterfactual Outcome

Besides ATE and ITE estimation, we also provide APIs for predicting the counterfactual outcome directly.

[7]:
x_cf <- 1-x
y_cf <- get_est(model, v, x_cf)
boxplot(y_cf, main="Counterfactual outcome", ylab="Values")
_images/tutorial_r_17_0.png
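For each unit, the counterfactual outcome is the outcome under the treatment it did not receive. Because the outcome function of this simulation is known, a noise-free reference for y_cf can be computed directly. The sketch below is illustrative only: the names mu, y_cf_true, and ite_true are assumptions and not part of the package.

```r
set.seed(7)  # arbitrary seed
n <- 500
p <- 20
v <- matrix(rnorm(n * p), n, p)
x <- rbinom(n, 1, 0.3 + 0.2 * (v[, 1] > 0))

# Noise-free outcome function of the binary-treatment simulation
mu <- function(x, v) pmax(v[, 1], 0) * x + v[, 2] + pmin(v[, 3], 0)

# Flip every treatment assignment and evaluate the expected outcome
x_cf <- 1 - x
y_cf_true <- mu(x_cf, v)

# Treated-minus-control contrast for every unit
ite_true <- mu(rep(1, n), v) - mu(rep(0, n), v)

print(summary(y_cf_true))
```

Model-based predictions of y_cf can be compared with y_cf_true up to the outcome noise, which has unit variance in this design.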

Estimating Conditional Average Treatment Effect (CATE)

We demonstrate how to estimate CATE on external data after training the model.

[8]:
n_test <- 100
v_test <- matrix(rnorm(n_test * p), n_test, p)
CATE <- get_est(model, v_test)
boxplot(CATE, main="CATE", ylab="Values")
_images/tutorial_r_19_0.png

Continuous treatment

We next give a step-by-step tutorial for implementing CausalEGM under continuous treatment settings. We use the Hirano and Imbens simulation dataset as an example.

[9]:
n <- 20000
p <- 200
v <- matrix(rexp(n * p), n, p)
rate <- v[,1] + v[,2]
scale <- 1 / rate  # mean of the exponential treatment distribution (not used below)
x <- rexp(n = n, rate = rate)
y <- rnorm(n = n, mean = x + (v[, 1] + v[, 3]) * exp(-x * (v[, 1] + v[, 3])), sd = 1)

Let’s take a look at the simulation data.

[10]:
par(cex=0.7, mai=c(0.1,0.1,0.2,0.1))
par(fig=c(0.1,0.5,0.5,1.0))
hist(x, breaks="FD", xlim=c(0,7), col="blue",xlab="x values")
par(fig=c(0.5,1,0.5,1), new=TRUE)
hist(y, breaks="FD", xlim=c(0,7), col="red",xlab="y values")
par(fig=c(0.1,1.0,0.1,0.4), new=TRUE)
boxplot(v[,1:20],main="First 20 covariates", xlab="Covariate index", ylab="v values")

_images/tutorial_r_23_0.png

Model training

Start training a CausalEGM model. Run help(causalegm) for detailed usage of the core API, causalegm.

Note that the parameters x, y, and v are required. Users can also specify z_dims as an integer vector with four elements.

(x_min, x_max) is the interval of treatment values on which we evaluate the causal effect.

The binary_treatment argument is set to FALSE because we are dealing with a continuous treatment.

[11]:
#help(causalegm)
model <- causalegm(x=x,y=y,v=v,
                   z_dims = c(1,1,1,7),
                   n_iter=20000,
                   binary_treatment = FALSE,
                   use_v_gan = FALSE,
                   x_min=0,
                   x_max=3)

Treatment Effect Estimation

The average causal effect, evaluated at 200 evenly spaced points in the interval (x_min, x_max), can be obtained directly from the trained model.

[12]:
ACE <- model$best_causal_pre
boxplot(ACE, main="ACE distribution", ylab="Values")
_images/tutorial_r_27_0.png

Plot of the average dose-response function (ADRF)

We now plot the true ADRF together with the ADRF estimated by our model.

[13]:
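grid_val <- seq(0, 3, length.out = 200)
true_effect <- grid_val + 2 * (1 + grid_val)^(-3)
# Plot both curves against the treatment grid so the x-axis shows treatment values
plot(grid_val, true_effect, type = "l", lwd = 2, col = 2,
     xlab = "Treatment", ylab = "Average Dose Response")
lines(grid_val, as.numeric(ACE), lwd = 2, col = 3)
legend(x = "topleft",
       legend = c("True ADRF", "Estimated ADRF"),
       lty = c(1, 1),
       col = c(2, 3),
       lwd = 2)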
_images/tutorial_r_29_0.png
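The closed form x + 2 * (1 + x)^(-3) used for the true ADRF follows from the simulation design. The sum v[, 1] + v[, 3] of two independent Exp(1) covariates, call it S, has a Gamma distribution with shape 2 and rate 1, and for such S the expectation E[S * exp(-x * S)] equals 2 / (1 + x)^3. The sketch below checks this by Monte Carlo at a few treatment values; the sample size, seed, and variable names are illustrative assumptions.

```r
set.seed(1)  # arbitrary seed
m <- 200000  # Monte Carlo sample size (assumption, large enough for a stable check)

# S = v1 + v3, the sum of two independent Exp(1) draws
s <- rexp(m) + rexp(m)

# Treatment values at which to compare the two expressions
grid_val <- c(0, 0.5, 1, 2, 3)

# Monte Carlo average of the outcome mean x + S * exp(-x * S)
mc_adrf <- sapply(grid_val, function(x0) x0 + mean(s * exp(-x0 * s)))

# Closed-form ADRF used in the tutorial
closed_form <- grid_val + 2 / (1 + grid_val)^3

print(round(cbind(grid_val, mc_adrf, closed_form), 3))
```

The two columns should agree up to Monte Carlo error, which shrinks as m grows.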
[14]:
sessionInfo()
R version 3.6.3 (2020-02-29)
Platform: x86_64-apple-darwin15.6.0 (64-bit)
Running under: macOS  10.16

Matrix products: default
BLAS:   /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRblas.0.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base

other attached packages:
[1] RcausalEGM_0.3.3

loaded via a namespace (and not attached):
 [1] Rcpp_1.0.7      uuid_1.0-3      lattice_0.20-38 here_1.0.1
 [5] rlang_1.1.0     fastmap_1.1.0   fansi_0.5.0     tools_3.6.3
 [9] grid_3.6.3      png_0.1-7       utf8_1.2.2      cli_3.3.0
[13] withr_2.4.3     htmltools_0.5.2 ellipsis_0.3.2  rprojroot_2.0.2
[17] digest_0.6.29   lifecycle_1.0.1 crayon_1.4.2    Matrix_1.3-4
[21] IRdisplay_1.1   repr_1.1.6      base64enc_0.1-3 vctrs_0.4.1
[25] IRkernel_1.3.2  evaluate_0.15   pbdZMQ_0.3-9    compiler_3.6.3
[29] pillar_1.6.4    reticulate_1.28 jsonlite_1.7.2