Introduction to MLCausal

MLCausal Team

2026-04-08

Overview

MLCausal provides a tidy, end-to-end pipeline for causal inference in clustered (multilevel) observational data — students in schools, patients in hospitals, employees in firms.

The standard API uses five consistent argument names across every function:

Argument     Meaning
treatment    Name of the 0/1 treatment variable
outcome      Name of the outcome variable
covariates   Character vector of covariate names
cluster      Name of the cluster identifier
weights      Name of the weight variable

The main workflow:

simulate_ml_data()  →  ml_ps()  →  ml_weight() or ml_match()
  →  balance_ml()  →  estimate_att_ml()  →  sens_ml()

1. Simulate Clustered Data

library(MLCausal)

dat <- simulate_ml_data(n_clusters = 20, cluster_size = 25,
                        n_min = 10, seed = 42)
head(dat)
#>   school_id         x1 x2          x3 z        y
#> 1         1  0.1602218  1  1.19316091 1 3.601848
#> 2         1 -1.3408750  0  1.87422963 1 2.140405
#> 3         1 -0.5094501  0  1.19309386 1 2.490992
#> 4         1 -3.4929064  0  1.98230859 0 1.345384
#> 5         1  1.4987760  1 -0.80447377 1 2.539185
#> 6         1  0.6252433  0  0.07808106 0 1.972232
table(dat$z)
#> 
#>   0   1 
#> 300 225

The true ATT is approximately 0.5. It sits slightly above 0.5 because treated units are over-represented in clusters with larger treatment effects.
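
To see why over-representation of treated units in high-effect clusters pushes the ATT above the average effect, consider a small deterministic illustration. The three clusters, their effects (tau) and their treated shares (p_treat) are made-up numbers chosen for this sketch; they are not the values used inside simulate_ml_data().

```r
# Three equally sized clusters; all numbers are illustrative assumptions.
tau     <- c(0.3, 0.5, 0.7)   # cluster-specific treatment effects
p_treat <- c(0.2, 0.4, 0.6)   # share of treated units in each cluster

# ATE: every unit counts equally, so clusters are weighted equally.
ate <- mean(tau)

# ATT: only treated units count, so clusters are weighted by treated share.
att <- sum(p_treat * tau) / sum(p_treat)

c(ATE = ate, ATT = att)
```

Because the cluster with the largest effect also has the largest treated share, the ATT (about 0.567) exceeds the ATE (0.5) in this toy setup.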


2. Estimate Propensity Scores

ps <- ml_ps(
  data      = dat,
  treatment = "z",
  covariates = c("x1", "x2", "x3"),
  cluster   = "school_id",
  method    = "mundlak",
  estimand  = "ATT"
)
print(ps)
#> MLCausal propensity score model
#>   Method:   mundlak
#>   Estimand: ATT
#>   N:        525
#>   Clusters: 20
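
The name method = "mundlak" suggests the Mundlak device, in which the cluster means of the covariates enter the propensity model alongside the unit-level covariates. The sketch below shows that idea in base R on toy data. The toy data frame, the column names x1_cm and ps_hat, and the plain glm() fit are assumptions for illustration; ml_ps() may be implemented differently.

```r
set.seed(1)
n_clusters   <- 20
cluster_size <- 25

# Toy clustered data: x1 has a unit-level part and a cluster-level shift.
toy <- data.frame(
  school_id = rep(seq_len(n_clusters), each = cluster_size),
  x1        = rnorm(n_clusters * cluster_size)
)
toy$x1 <- toy$x1 + rep(rnorm(n_clusters), each = cluster_size)
toy$z  <- rbinom(nrow(toy), 1, plogis(0.5 * toy$x1))

# Mundlak device: add the cluster mean of the covariate as a regressor.
# ave() returns the group mean by default, repeated for each row.
toy$x1_cm <- ave(toy$x1, toy$school_id)

fit <- glm(z ~ x1 + x1_cm, family = binomial(), data = toy)
toy$ps_hat <- fitted(fit)   # estimated propensity scores

range(toy$ps_hat)
```

Including x1_cm lets the model separate within-cluster variation in x1 from between-cluster differences without estimating one intercept per cluster.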

3. Check Overlap

plot_overlap_ml(ps)

[Figure: propensity score overlap plot]


4. Build Inverse Probability Weights

dat_w <- ml_weight(ps, estimand = "ATT", stabilize = TRUE, trim = 10)
#> 1 weight(s) Winsorised to upper bound of 10.
summary(dat_w$weights)
#>      Min.   1st Qu.    Median      Mean   3rd Qu.      Max. 
#>  0.006576  0.245367  1.000000  0.767135  1.000000 10.000000
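
As a rough guide to what ATT weighting with a cap involves, the sketch below computes textbook ATT weights (1 for treated units, ps / (1 - ps) for controls) and Winsorises them at an upper bound. The helper att_weights() is hypothetical, and the sketch does not reproduce the stabilize = TRUE step, so its numbers will not match the ml_weight() output above.

```r
# Hypothetical helper; not part of MLCausal.
att_weights <- function(ps, z, cap = 10) {
  # Treated units keep weight 1; controls get the odds of treatment.
  w <- ifelse(z == 1, 1, ps / (1 - ps))
  list(
    weights  = pmin(w, cap),   # Winsorise: pull values above the cap down to it
    n_capped = sum(w > cap)    # number of weights that were pulled down
  )
}

ps <- c(0.70, 0.20, 0.50, 0.95)   # estimated propensity scores
z  <- c(1,    0,    0,    0)      # treatment indicator

res <- att_weights(ps, z)
res$weights
res$n_capped
```

The control with ps = 0.95 has raw weight 19 and is capped at 10, which is the kind of adjustment the "Winsorised to upper bound" message reports.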

5. Check Balance

bal <- balance_ml(
  data       = dat_w,
  treatment  = "z",
  covariates = c("x1", "x2", "x3"),
  cluster    = "school_id",
  weights    = "weights"
)
print(bal)
#> MLCausal balance diagnostics
#> 
#> Individual-level SMDs:
#>  variable      level     smd
#>        x1 individual -0.0789
#>        x2 individual -0.0004
#>        x3 individual -0.0927
#> 
#> Cluster-mean SMDs:
#>  variable        level    smd
#>        x1 cluster_mean 1.1713
#>        x2 cluster_mean 0.0193
#>        x3 cluster_mean 1.6173

Individual-level SMDs are returned as numeric values. Cluster-mean SMDs are returned as character strings: either a formatted number or, when the SMD is not estimable, a descriptive message.
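
The two levels of balance can be illustrated with a small base R sketch. The helper weighted_smd() is hypothetical, and its denominator (the square root of the average of the two unweighted group variances) is one common convention assumed here; balance_ml() may standardise differently.

```r
# Hypothetical helper; not part of MLCausal.
weighted_smd <- function(x, z, w = rep(1, length(x))) {
  m_t <- weighted.mean(x[z == 1], w[z == 1])   # weighted mean, treated
  m_c <- weighted.mean(x[z == 0], w[z == 0])   # weighted mean, control
  s_pooled <- sqrt((var(x[z == 1]) + var(x[z == 0])) / 2)
  (m_t - m_c) / s_pooled
}

x <- c(1, 2, 3, 4, 5, 6, 7, 8)
z <- c(1, 0, 1, 0, 1, 0, 1, 0)
g <- c(1, 1, 2, 2, 3, 3, 4, 4)   # cluster identifier

# Individual level: compare the covariate itself across groups.
smd_ind <- weighted_smd(x, z)

# Cluster-mean level: replace each value by its cluster mean first.
smd_cm <- weighted_smd(ave(x, g), z)

c(individual = smd_ind, cluster_mean = smd_cm)
```

In this toy data every cluster contributes one treated and one control unit, so the cluster-mean SMD is exactly zero even though the individual-level SMD is not.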


6. Estimate the ATT

est <- estimate_att_ml(
  data       = dat_w,
  outcome    = "y",
  treatment  = "z",
  cluster    = "school_id",
  covariates = c("x1", "x2", "x3"),
  weights    = "weights"
)
print(est)
#> MLCausal treatment effect estimate
#>   Estimate: 0.8196
#>   SE:       0.137
#>   p-value:  4.069e-09
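
For intuition about how the weights enter the estimate, the sketch below uses a weighted regression of the outcome on the treatment indicator alone. This is a simplification assumed for illustration: estimate_att_ml() also takes covariates and a cluster identifier, and how it uses them is not shown in this vignette.

```r
# Toy data; values are illustrative assumptions.
y <- c(3, 4, 1, 2, 2)        # outcome
z <- c(1, 1, 0, 0, 0)        # treatment indicator
w <- c(1, 1, 0.5, 2, 1)      # weights, e.g. from an ATT weighting step

# Weighted least squares with an intercept and the treatment indicator.
fit <- lm(y ~ z, weights = w)
wls_effect <- coef(fit)[["z"]]

# The same quantity as a difference in weighted means.
diff_means <- weighted.mean(y[z == 1], w[z == 1]) -
  weighted.mean(y[z == 0], w[z == 0])

c(wls = wls_effect, diff_in_means = diff_means)
```

With only an intercept and a binary treatment, the weighted regression coefficient equals the difference in weighted group means, so both entries agree.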

7. Sensitivity Analysis

sens <- sens_ml(estimate = est$estimate, se = est$se)
sens[sens$crosses_null, ][1, ]
#>    confounder_strength adjusted_estimate original_z adjusted_z crosses_null
#> 42                 4.1         0.2579827   5.983486   1.883486         TRUE
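
The printed row is consistent with a simple calculation: subtract the confounder strength from the original z-statistic, convert back to the estimate scale with the standard error, and flag rows whose adjusted z falls inside the 95% critical value. The sketch below reconstructs that arithmetic from the rounded estimate and SE printed earlier. The grid step of 0.1 and the two-sided 5% threshold are inferred from the output and are assumptions, not documented behaviour of sens_ml().

```r
estimate <- 0.8196   # rounded value printed by print(est)
se       <- 0.137    # rounded value printed by print(est)

strength <- seq(0, 5, by = 0.1)   # assumed grid of confounder strengths
z0       <- estimate / se         # original z-statistic

grid <- data.frame(
  confounder_strength = strength,
  adjusted_estimate   = (z0 - strength) * se,
  original_z          = z0,
  adjusted_z          = z0 - strength
)
# Assumed rule: the estimate "crosses the null" once |z| drops below 1.96.
grid$crosses_null <- abs(grid$adjusted_z) < qnorm(0.975)

first_cross <- which(grid$crosses_null)[1]
grid[first_cross, ]
```

The first flagged row is row 42 at strength 4.1, matching the output above; the other columns differ in the third decimal because the inputs here are rounded.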

Alternative: Dual-Balance Matching

The lambda parameter in ml_match() is the core innovation: it adds a cluster-mean balance penalty to the matching distance, so matches are chosen to improve balance at both the individual and cluster-mean levels simultaneously.

# lambda = 0  → standard PS matching
# lambda = 1  → equal weight on PS distance and cluster-mean balance (default)
# lambda > 1  → prioritise cluster-mean balance over PS proximity

matched <- ml_match(ps, ratio = 1, caliper = 0.5, lambda = 1)
#> 92 treated unit(s) could not be matched and were excluded from the matched sample.
print(matched)
#> MLCausal matched sample (dual-balance)
#>   Matched rows:    266
#>   Matched sets:    133
#>   Clusters used:   19
#>   Unmatched (trt): 92
#>   Lambda (balance weight): 1
#>   Caliper (logit-PS):      0.5
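
To make the role of lambda concrete, the sketch below builds one possible composite distance: the absolute difference in logit propensity scores plus lambda times the absolute difference in cluster means of a covariate. The function composite_distance() and this particular additive form are assumptions for illustration; the exact distance used by ml_match() is not specified in this vignette.

```r
# Hypothetical composite distance; not MLCausal's internal code.
composite_distance <- function(ps_t, ps_c, cm_t, cm_c, lambda = 1) {
  # Rows index treated units, columns index candidate controls.
  d_ps <- abs(outer(qlogis(ps_t), qlogis(ps_c), "-"))   # logit-PS distance
  d_cm <- abs(outer(cm_t, cm_c, "-"))                   # cluster-mean distance
  d_ps + lambda * d_cm
}

ps_t <- 0.60            # one treated unit
ps_c <- c(0.60, 0.55)   # two candidate controls
cm_t <- 0.0             # cluster mean of x1 in the treated unit's cluster
cm_c <- c(2.0, 0.1)     # cluster means of x1 in the controls' clusters

# lambda = 0: only the propensity score matters, so control 1 wins.
best_ps_only <- which.min(composite_distance(ps_t, ps_c, cm_t, cm_c, lambda = 0))

# lambda = 1: control 2 wins because its cluster is much more similar.
best_dual <- which.min(composite_distance(ps_t, ps_c, cm_t, cm_c, lambda = 1))

c(lambda_0 = best_ps_only, lambda_1 = best_dual)
```

Raising lambda trades a slightly worse propensity score match for a control drawn from a more comparable cluster.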

bal_m <- balance_ml(
  data       = matched$data_matched,
  treatment  = "z",
  covariates = c("x1", "x2", "x3"),
  cluster    = "school_id",
  weights    = "match_weight"
)
print(bal_m)
#> MLCausal balance diagnostics
#> 
#> Individual-level SMDs:
#>  variable      level    smd
#>        x1 individual 0.0059
#>        x2 individual 0.0301
#>        x3 individual 0.0417
#> 
#> Cluster-mean SMDs:
#>  variable        level
#>        x1 cluster_mean
#>        x2 cluster_mean
#>        x3 cluster_mean
#>                                                                                           smd
#>  Cluster-level SMD not estimable due to insufficient within-cluster variation after matching.
#>  Cluster-level SMD not estimable due to insufficient within-cluster variation after matching.
#>  Cluster-level SMD not estimable due to insufficient within-cluster variation after matching.

est_m <- estimate_att_ml(
  data       = matched$data_matched,
  outcome    = "y",
  treatment  = "z",
  cluster    = "school_id",
  covariates = c("x1", "x2", "x3"),
  weights    = "match_weight"
)
print(est_m)
#> MLCausal treatment effect estimate
#>   Estimate: 0.5177
#>   SE:       0.1203
#>   p-value:  2.372e-05

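Compare bal (weighting) and bal_m (dual-balance matching). With weighting, the cluster-mean SMDs for x1 and x3 remain large (1.17 and 1.62). In the matched sample shown here the cluster-mean SMDs are not estimable, so the two approaches cannot be compared at the cluster-mean level in this run. At the individual level, the largest absolute SMD falls from 0.093 (weighting) to 0.042 (matching), and the matched ATT estimate (0.518) is closer to the true value of about 0.5 than the weighted estimate (0.820). When cluster-mean SMDs are estimable, the matching approach should show lower values than weighting, which is the intended benefit of the composite distance.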


Summary

Step         Function            Key argument change
Simulate     simulate_ml_data()  n_min prevents tiny clusters
PS model     ml_ps()             treatment = (not treat)
Weight       ml_weight()         output column is weights
Match        ml_match()          lambda = for dual-balance
Balance      balance_ml()        treatment =; cluster SMD is string not NA
Estimate     estimate_att_ml()   treatment =, weights =
Sensitivity  sens_ml()           default q extended to 5