Cram Policy part 2

Cram Policy

In the article “Introduction & Cram Policy Part 1”, we introduced the Cram method, which enables simultaneous learning and evaluation of a binary policy. We outlined the primary parameters of the cram_policy() function and demonstrated its application using an example dataset. In this section, we provide a more detailed discussion on how to configure these parameters for various use cases, depending notably on the nature of the dataset and specific policy learning goals.

We begin by simulating a dataset consisting of covariates X, a binary treatment assignment D, and a continuous outcome Y, which we will use to demonstrate the cram_policy() function.

library(data.table)
# Function to generate sample data with heterogeneous treatment effects:
# - Positive effect group
# - Neutral effect group
# - Adverse effect group
generate_data <- function(n) {
  X <- data.table(
    binary = rbinom(n, 1, 0.5),                 # Binary variable
    discrete = sample(1:5, n, replace = TRUE),  # Discrete variable
    continuous = rnorm(n)                       # Continuous variable
  )

  # Binary treatment assignment (50% treated)
  D <- rbinom(n, 1, 0.5)

  # Define heterogeneous treatment effects based on X
  treatment_effect <- ifelse(
    X[, binary] == 1 & X[, discrete] <= 2,        # Group 1: Positive effect
    1,
    ifelse(X[, binary] == 0 & X[, discrete] >= 4, # Group 3: Adverse effect
           -1,
           0.1)                                   # Group 2: Neutral effect
  )

  # Outcome depends on treatment effect + noise
  Y <- D * (treatment_effect + rnorm(n, mean = 0, sd = 1)) +
    (1 - D) * rnorm(n)

  return(list(X = X, D = D, Y = Y))
}

# Generate a sample dataset
set.seed(123)
n <- 1000
data <- generate_data(n)
X <- data$X
D <- data$D
Y <- data$Y

Built-in Model

In this example, we demonstrate how to use the built-in modeling options provided by the cramR package. We walk through the key parameters that control how cram_policy() behaves, including the choice of model, learner type, baseline policy, and batching strategy.

These parameters allow flexibility in configuring the learning process depending on your use case and nature of your dataset.

# Options for batch:
# Either an integer specifying the number of batches or a vector/list of batch assignments for all individuals
batch <- 20

# Model type for estimating treatment effects
# Options for model_type: 'causal_forest', 's_learner', 'm_learner'
# Note: you can also set model_type to NULL and specify custom_fit and custom_predict to use your custom model
model_type <- "causal_forest"  

# Options for learner_type:
# if model_type == 'causal_forest', choose NULL
# if model_type == 's_learner' or 'm_learner', choose between 'ridge', 'fnn' and 'caret'
learner_type <- NULL  

# Baseline policy to compare against (list of 0/1 for each individual)
# Options for baseline_policy:
# A list representing the baseline policy assignment for each individual.
# If NULL, a default baseline policy of zeros is created.
# Examples of baseline policy: 
# - All-control baseline: as.list(rep(0, nrow(X))) or NULL
# - Randomized baseline: as.list(sample(c(0, 1), nrow(X), replace = TRUE))
baseline_policy <- as.list(rep(0, nrow(X)))  

# Whether to parallelize batch processing (i.e. the cram method learns T policies, with T the number of batches.
# They are learned in parallel when parallelize_batch is TRUE
# vs. learned sequentially using the efficient data.table structure when parallelize_batch is FALSE, recommended for light weight training).
# Defaults to FALSE.
parallelize_batch <- FALSE  
 

# Model-specific parameters (more details in the article "Quick Start")
# Examples: NULL defaults to the following:
# - causal_forest: list(num.trees = 100)
# - ridge: list(alpha = 1)
# - caret: list(formula = Y ~ ., caret_params = list(method = "lm", trControl = trainControl(method = "none")))
# - fnn (Feedforward Neural Network): see below
# input_shape <- if (model_type == "s_learner") ncol(X) + 1 else ncol(X)
# default_model_params <- list(
#       input_layer = list(units = 64, activation = 'relu', input_shape = input_shape),
#       layers = list(
#         list(units = 32, activation = 'relu')
#       ),
#       output_layer = list(units = 1, activation = 'linear'),
#       compile_args = list(optimizer = 'adam', loss = 'mse'),
#       fit_params = list(epochs = 5, batch_size = 32, verbose = 0)
#     )
model_params <- NULL  

How to set model_params?

The model_params argument allows you to customize hyperparameters for the model used in policy learning. If left as NULL, cram_policy() will fall back to sensible defaults depending on the model and learner type. In the following, we give default examples of model_params for each model_type and learner_type to illustrate how you can specify model_params. Generally speaking, model_params is a list containing all of the parameters used in the underlying model and its values:

  • For model_type = "causal_forest":
    When model_params <- NULL, the method defaults to grf::causal_forest() with model_params <- list(num.trees = 100).

  • For model_type = "s_learner" or "m_learner", it depends on the learner_type:

    • For learner_type = "ridge", when model_params <- NULL, the method defaults to glmnet::cv.glmnet() with model_params <- list(alpha = 1), corresponding to Ridge regression.

    • For learner_type = "fnn" (Feedforward Neural Network):
      A default Keras model is built with the following architecture:

    # Determine the input shape based on model_type
    # For s_learner, the treatment D is added to the covariates to constitute the training data
    # So the input shape is ncol(X) + 1
    input_shape <- if (model_type == "s_learner") ncol(X) + 1 else ncol(X)
    
    default_model_params <- list(
        input_layer = list(units = 64, activation = 'relu', input_shape = input_shape), 
        layers = list(
          list(units = 32, activation = 'relu')
        ),
        output_layer = list(units = 1, activation = 'linear'),
        compile_args = list(optimizer = 'adam', loss = 'mse'),
        fit_params = list(epochs = 5, batch_size = 32, verbose = 0)
      )
    default_model_params <- list(formula = Y ~ ., caret_params = list(method = "lm", trControl = trainControl(method = "none")))

    Please note that the list should contain an element formula where Y refers to the vector that you provided as input for Y, and an element named caret_params containing the parameters of your choice to pass to caret::train (see full list of parameters here: https://topepo.github.io/caret/model-training-and-tuning.html).

Case of categorical target Y

When using caret, note that if Y is categorical and you are using model_type = "s_learner", you need to choose a classification method outputting probabilities i.e. using the key word classProbs = TRUE in trainControl, see the following as an example with a Random Forest Classifier:

model_params <- list(formula = Y ~ ., caret_params = list(method = "rf", trControl = trainControl(method = "none", classProbs = TRUE)))

Also note that all data inputs needs to be of numeric types, hence for Y categorical, it should contain numeric values representing the class of each observation. No need to use the type factor for cram_policy().

Note on M-learner: If model_type = "m_learner", keep in mind that Y is transformed internally. M-learner requires a propensity model and transformed outcomes:

  • propensity is an argument of cram_policy, if NULL, it defaults to
propensity <- function(X) {rep(0.5, nrow(X))}
  • outcome_transform is an element of model_params accessed as follows:
outcome_transform <- model_params$m_learner_outcome_transform

If NULL, it defaults to

outcome_transform <- function(Y, D, prop_score) {Y * D / prop_score - Y * (1 - D) / (1 - prop_score)}`, where `prop_score` is `propensity(X)

Thus, the Y transformed might not be categorical even though Y is categorical i.e. you might want to use a regression model.

See the following reference for more details about the M-learner:

Jia, Z., Imai, K., & Li, M. L. (2024). The Cram Method for Efficient Simultaneous Learning and Evaluation, URL: https://www.hbs.edu/ris/Publication%20Files/2403.07031v1_a83462e0-145b-4675-99d5-9754aa65d786.pdf.

You can override any of these defaults by passing a custom list to model_params including any parameter name defined in the underlying package of the model that you chose, namely grf::causal_forest(), glmnet::cv.glmnet(), keras and caret::train()

# Significance level for confidence intervals (default = 95%)
alpha <- 0.05  

# Run the CRAM policy method
result <- cram_policy(
  X, D, Y,
  batch = batch,
  model_type = model_type,
  learner_type = learner_type,
  baseline_policy = baseline_policy,
  parallelize_batch = parallelize_batch,
  model_params = model_params,
  alpha = alpha
)

# Display the results
print(result)
#> $raw_results
#>                        Metric   Value
#> 1              Delta Estimate 0.23208
#> 2        Delta Standard Error 0.05862
#> 3              Delta CI Lower 0.11718
#> 4              Delta CI Upper 0.34697
#> 5       Policy Value Estimate 0.21751
#> 6 Policy Value Standard Error 0.05237
#> 7       Policy Value CI Lower 0.11486
#> 8       Policy Value CI Upper 0.32016
#> 9          Proportion Treated 0.60500
#> 
#> $interactive_table
#> 
#> $final_policy_model
#> GRF forest object of type causal_forest 
#> Number of trees: 100 
#> Number of training samples: 1000 
#> Variable importance: 
#>     1     2     3 
#> 0.437 0.350 0.213

Custom Model

To use your own model for policy learning, the trigger is to set model_type = NULL and to supply two user-defined functions:

1. custom_fit(X, Y, D, ...)

This function takes the training data: covariates X, outcomes Y, and binary treatment indicators D, and returns a fitted model object.
You may also define and use additional parameters (e.g., number of folds, regularization settings, etc.) within the function body.

  • X: a matrix or data frame of features
  • Y: a numeric outcome vector
  • D: a binary vector indicating treatment assignment (0 or 1)

Example: Custom X-learner with Ridge regression and 5-fold cross-validation

library(glmnet)
#> Loading required package: Matrix
#> Loaded glmnet 4.1-8

custom_fit <- function(X, Y, D, n_folds = 5) {
  treated_indices <- which(D == 1)
  control_indices <- which(D == 0)
  X_treated <- X[treated_indices, ]
  Y_treated <- Y[treated_indices]
  X_control <- X[control_indices, ]
  Y_control <- Y[control_indices]
  model_treated <- cv.glmnet(as.matrix(X_treated), Y_treated, alpha = 0, nfolds = n_folds)
  model_control <- cv.glmnet(as.matrix(X_control), Y_control, alpha = 0, nfolds = n_folds)
  tau_control <- Y_treated - predict(model_control, as.matrix(X_treated), s = "lambda.min")
  tau_treated <- predict(model_treated, as.matrix(X_control), s = "lambda.min") - Y_control
  X_combined <- rbind(X_treated, X_control)
  tau_combined <- c(tau_control, tau_treated)
  weights <- c(rep(1, length(tau_control)), rep(1, length(tau_treated)))
  final_model <- cv.glmnet(as.matrix(X_combined), tau_combined, alpha = 0, weights = weights, nfolds = n_folds)
  return(final_model)
}

2. custom_predict(model, X_new, D_new)

This function uses the fitted model to generate a binary treatment decision for each individual in X_new.

It should return a vector of 0s and 1s, indicating whether to assign treatment (1) or not (0).
You may also incorporate a custom threshold or post-processing logic within the function.

Example: Apply the decision rule — treat if the estimated CATE is greater than 0

custom_predict <- function(model, X_new, D_new) {
  cate <- predict(model, as.matrix(X_new), s = "lambda.min")

  # Apply decision rule: treat if CATE > 0
  as.integer(cate > 0)
}

3. Use cram_policy() with custom_fit() and custom_predict()

Once both custom_fit() and custom_predict() are defined, you can integrate them into the Cram framework by passing them to cram_policy() as shown below (do not forget to set model_type = NULL):

experiment_results <- cram_policy(
  X, D, Y,
  batch = 20,
  model_type = NULL,
  custom_fit = custom_fit,
  custom_predict = custom_predict,
  alpha = 0.05
)
print(experiment_results)
#> $raw_results
#>                        Metric   Value
#> 1              Delta Estimate 0.22542
#> 2        Delta Standard Error 0.06004
#> 3              Delta CI Lower 0.10774
#> 4              Delta CI Upper 0.34310
#> 5       Policy Value Estimate 0.21085
#> 6 Policy Value Standard Error 0.04280
#> 7       Policy Value CI Lower 0.12696
#> 8       Policy Value CI Upper 0.29475
#> 9          Proportion Treated 0.54500
#> 
#> $interactive_table
#> 
#> $final_policy_model
#> 
#> Call:  cv.glmnet(x = as.matrix(X_combined), y = tau_combined, weights = weights,      nfolds = n_folds, alpha = 0) 
#> 
#> Measure: Mean-Squared Error 
#> 
#>     Lambda Index Measure      SE Nonzero
#> min 0.0395   100  0.9219 0.02942       3
#> 1se 0.5868    71  0.9505 0.03367       3

References