Introduction & Cram Policy part 1

Introduction: What is the Cram Method?

The Cram method is a powerful approach for simultaneously learning and evaluating decision rules, such as individualized treatment rules (ITRs), from data. Common applications include healthcare (who to treat), pricing and advertising (who to target or how much to charge), and policy (who to support).

Unlike traditional approaches such as sample splitting, which reserves part of the data for evaluation only, or cross-validation, which evaluates only surrogate models trained on data subsets, Cram uses all available data efficiently for both learning and evaluation.

A key distinction from cross-validation is that Cram evaluates the final learned model, rather than averaging performance across multiple models trained on different data splits.

Cram:

The Cram Workflow

Below is the core idea of the Cram method visualized:

[Figure: schematic of the Cram method's batch-by-batch learning and evaluation]

This procedure ensures each update is backed by performance testing, enabling both learning and evaluation in one pass over the data.
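To make the one-pass idea concrete, here is a schematic sketch in R. It is an illustration only, not the cramR implementation: fit_policy() and evaluate_delta() are hypothetical helpers, the learner is a simple linear interaction model, and the evaluation step uses a basic inverse-probability-weighted contrast with a known treatment probability of 0.5. The data are split into T batches; at step t a policy is fit on the first t batches, and the improvement it brings over the previous policy is evaluated on batches not yet used for training.

# Schematic sketch of the Cram loop (illustration only; not the cramR internals)
set.seed(1)
n_sketch <- 100
T_batches <- 5
dat <- data.frame(x = rnorm(n_sketch), d = rbinom(n_sketch, 1, 0.5), y = rnorm(n_sketch))
batch_id <- sample(rep(1:T_batches, length.out = n_sketch))

# Hypothetical learner: treat when the estimated effect of d given x is positive
fit_policy <- function(df) {
  fit <- lm(y ~ x * d, data = df)
  function(x_new) as.numeric(coef(fit)["d"] + coef(fit)["x:d"] * x_new > 0)
}

# Hypothetical evaluator: IPW contrast of two policies on held-out data
# (treatment probability known to be 0.5 in this toy example)
evaluate_delta <- function(pi_new, pi_old, df) {
  w_new <- (df$d == pi_new(df$x)) / 0.5
  w_old <- (df$d == pi_old(df$x)) / 0.5
  mean(w_new * df$y - w_old * df$y)
}

pi_prev <- function(x_new) rep(0, length(x_new))  # baseline policy: treat nobody
delta_hat <- 0
for (t in 1:(T_batches - 1)) {
  train <- dat[batch_id <= t, ]   # cumulative batches 1..t used for learning
  held  <- dat[batch_id > t, ]    # batches t+1..T, not yet used for learning
  pi_t <- fit_policy(train)
  delta_hat <- delta_hat + evaluate_delta(pi_t, pi_prev, held)  # evaluate the update
  pi_prev <- pi_t
}
delta_hat  # schematic estimate of the learned policy's gain over the baseline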

[Figure: schematic of how Cram estimates the policy value difference relative to a baseline policy]

Note: this schematic shows how Cram estimates the difference in policy value (see the definition of policy value above) relative to a baseline policy. In healthcare, for example, a baseline policy could be to treat nobody (all zeros) or to treat individuals at random (assign 1 for treatment and 0 for no treatment at random). The policy value difference "Delta" tells you how much better (or worse) the policy learned by Cram is relative to the baseline policy. Cram can also estimate the policy value of the learned policy directly, without specifying a baseline policy; this is not presented in this introduction, but it is available in the outputs of the main Cram functions (see the results table below).
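For reference, here is a minimal formal sketch of the two quantities mentioned above (notation ours, not taken from this article). Writing X for covariates, Y(d) for the potential outcome under treatment d, pi-hat for the learned policy, and pi_0 for the baseline policy:

V(\pi) = \mathbb{E}\big[ Y(\pi(X)) \big]              % policy value: expected outcome when treatment follows \pi
\Delta(\hat{\pi}; \pi_0) = V(\hat{\pi}) - V(\pi_0)    % "Delta": gain of the learned policy over the baseline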

The cram_policy() Function

The cram_policy() function in cramR implements the Cram framework for binary treatment policy learning.

🔑 Key Features of cram_policy()

Example: Running Cram Policy on simulated data

library(data.table)
# Function to generate sample data with heterogeneous treatment effects:
# - Positive effect group
# - Neutral effect group
# - Adverse effect group
generate_data <- function(n) {
  X <- data.table(
    binary = rbinom(n, 1, 0.5),                 # Binary variable
    discrete = sample(1:5, n, replace = TRUE),  # Discrete variable
    continuous = rnorm(n)                       # Continuous variable
  )

  # Binary treatment assignment (50% treated)
  D <- rbinom(n, 1, 0.5)

  # Define heterogeneous treatment effects based on X
  treatment_effect <- ifelse(
    X[, binary] == 1 & X[, discrete] <= 2,        # Group 1: Positive effect
    1,
    ifelse(X[, binary] == 0 & X[, discrete] >= 4, # Group 3: Adverse effect
           -1,
           0.1)                                   # Group 2: Neutral effect
  )

  # Outcome depends on treatment effect + noise
  Y <- D * (treatment_effect + rnorm(n, mean = 0, sd = 1)) +
    (1 - D) * rnorm(n)

  return(list(X = X, D = D, Y = Y))
}

# Generate a sample dataset
set.seed(123)
n <- 1000
data <- generate_data(n)
X <- data$X
D <- data$D
Y <- data$Y
# Options for batch:
# Either an integer specifying the number of batches or a vector/list of batch assignments for all individuals
batch <- 20
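# (Hypothetical alternative, for illustration) instead of an integer, you can pass an
# explicit batch assignment for each of the n individuals, e.g. 20 batches of equal size:
# batch <- sample(rep(1:20, length.out = n))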

# Model type for estimating treatment effects
# Options for model_type: 'causal_forest', 's_learner', 'm_learner'
# Note: you can also set model_type to NULL and specify custom_fit and custom_predict to use your custom model
model_type <- "causal_forest"  

# Options for learner_type:
# if model_type == 'causal_forest', choose NULL
# if model_type == 's_learner' or 'm_learner', choose between 'ridge', 'fnn' and 'caret'
learner_type <- NULL  

# Baseline policy to compare against (list of 0/1 for each individual)
# Options for baseline_policy:
# A list representing the baseline policy assignment for each individual.
# If NULL, a default baseline policy of zeros is created.
# Examples of baseline policy: 
# - All-control baseline: as.list(rep(0, nrow(X))) or NULL
# - Randomized baseline: as.list(sample(c(0, 1), nrow(X), replace = TRUE))
baseline_policy <- as.list(rep(0, nrow(X)))  

# Whether to parallelize batch processing. The Cram method learns T policies, where T is the
# number of batches: they are learned in parallel when parallelize_batch is TRUE, and
# sequentially (using an efficient data.table structure) when parallelize_batch is FALSE,
# which is recommended for lightweight training. Defaults to FALSE.
parallelize_batch <- FALSE  
 

# Model-specific parameters (more details in the article "Cram Policy part 2")
# Examples: NULL defaults to the following:
# - causal_forest: list(num.trees = 100)
# - ridge: list(alpha = 1)
# - caret: list(formula = Y ~ ., caret_params = list(method = "lm", trControl = trainControl(method = "none")))
# - fnn (Feedforward Neural Network): see below
# input_shape <- if (model_type == "s_learner") ncol(X) + 1 else ncol(X)
# default_model_params <- list(
#       input_layer = list(units = 64, activation = 'relu', input_shape = input_shape),
#       layers = list(
#         list(units = 32, activation = 'relu')
#       ),
#       output_layer = list(units = 1, activation = 'linear'),
#       compile_args = list(optimizer = 'adam', loss = 'mse'),
#       fit_params = list(epochs = 5, batch_size = 32, verbose = 0)
#     )
model_params <- NULL  


# Significance level for confidence intervals (alpha = 0.05 gives 95% confidence intervals)
alpha <- 0.05  

# Run the Cram policy method
result <- cram_policy(
  X, D, Y,
  batch = batch,
  model_type = model_type,
  learner_type = learner_type,
  baseline_policy = baseline_policy,
  parallelize_batch = parallelize_batch,
  model_params = model_params,
  alpha = alpha
)

# Display the results
print(result)
#> $raw_results
#>                        Metric   Value
#> 1              Delta Estimate 0.23208
#> 2        Delta Standard Error 0.05862
#> 3              Delta CI Lower 0.11718
#> 4              Delta CI Upper 0.34697
#> 5       Policy Value Estimate 0.21751
#> 6 Policy Value Standard Error 0.05237
#> 7       Policy Value CI Lower 0.11486
#> 8       Policy Value CI Upper 0.32016
#> 9          Proportion Treated 0.60500
#> 
#> $interactive_table
#> 
#> $final_policy_model
#> GRF forest object of type causal_forest 
#> Number of trees: 100 
#> Number of training samples: 1000 
#> Variable importance: 
#>     1     2     3 
#> 0.437 0.350 0.213

Interpreting Results

result$raw_results
#>                        Metric   Value
#> 1              Delta Estimate 0.23208
#> 2        Delta Standard Error 0.05862
#> 3              Delta CI Lower 0.11718
#> 4              Delta CI Upper 0.34697
#> 5       Policy Value Estimate 0.21751
#> 6 Policy Value Standard Error 0.05237
#> 7       Policy Value CI Lower 0.11486
#> 8       Policy Value CI Upper 0.32016
#> 9          Proportion Treated 0.60500
result$interactive_table
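If you need these numbers programmatically, you can pull them out of raw_results. A minimal sketch, assuming raw_results keeps the Metric and Value columns shown above and that the reported confidence intervals are standard normal-approximation intervals (which matches the printed values):

# Extract the Delta estimate and its standard error from the results table
res <- result$raw_results
delta_hat <- res$Value[res$Metric == "Delta Estimate"]
delta_se  <- res$Value[res$Metric == "Delta Standard Error"]

# Recompute a 95% normal-approximation confidence interval (alpha = 0.05 was set above)
z <- qnorm(1 - alpha / 2)
c(lower = delta_hat - z * delta_se, upper = delta_hat + z * delta_se)
# close to the reported CI bounds of 0.11718 and 0.34697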

The output of cram_policy() includes raw_results (the table of estimates shown above), interactive_table (an interactive version of the same table), and final_policy_model (the fitted policy model object, inspected below):

class(result$final_policy_model)
#> [1] "causal_forest" "grf"
summary(result$final_policy_model)
#>                          Length Class  Mode   
#> _ci_group_size              1   -none- numeric
#> _num_variables              1   -none- numeric
#> _num_trees                  1   -none- numeric
#> _root_nodes               100   -none- list   
#> _child_nodes              100   -none- list   
#> _leaf_samples             100   -none- list   
#> _split_vars               100   -none- list   
#> _split_values             100   -none- list   
#> _drawn_samples            100   -none- list   
#> _send_missing_left        100   -none- list   
#> _pv_values                100   -none- list   
#> _pv_num_types               1   -none- numeric
#> predictions              1000   -none- numeric
#> variance.estimates          0   -none- numeric
#> debiased.error           1000   -none- numeric
#> excess.error             1000   -none- numeric
#> seed                        1   -none- numeric
#> num.threads                 1   -none- numeric
#> ci.group.size               1   -none- numeric
#> X.orig                   3000   -none- numeric
#> Y.orig                   1000   -none- numeric
#> W.orig                   1000   -none- numeric
#> Y.hat                    1000   -none- numeric
#> W.hat                    1000   -none- numeric
#> clusters                    0   -none- numeric
#> equalize.cluster.weights    1   -none- logical
#> tunable.params              7   -none- list   
#> has.missing.values          1   -none- logical

You can inspect or apply the learned model to new data.
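For example, since the final model in this run is a grf causal forest, you can score new individuals and turn the estimated treatment effects into treatment decisions. A minimal sketch; the treat-if-positive-effect threshold below is our illustration and not necessarily the exact rule used internally by cram_policy():

# Generate new covariates with the same columns as the training data
new_X <- generate_data(200)$X

# Predict conditional average treatment effects (CATEs) with the fitted causal forest
tau_hat <- predict(result$final_policy_model, newdata = as.matrix(new_X))$predictions

# Illustrative decision rule: treat individuals with a positive estimated effect
policy_assignment <- as.integer(tau_hat > 0)
table(policy_assignment)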

Visual Summary and Notes

[Figure: visual summary of how batch-level evaluations accumulate into the overall Cram estimate]

This visualization summarizes how multiple evaluations across iterations contribute to the full Cram estimate.

Notes:

📐 Comparing Evaluation Strategies: Sample-Splitting, Cross-Validation, and Cram

In this section, we compare classical strategies for model evaluation—namely sample-splitting and cross-validation—with the Cram method. While all three approaches may ultimately train a model using the full dataset, they differ fundamentally in how they estimate the generalization performance of that model.

🔍 Sample-Splitting

Sample-splitting divides the data into a training set and a held-out test set (e.g., 80/20). The model is trained on the training portion, and its performance is assessed on the held-out set. This procedure evaluates a model that was not trained on all the data, which may understate the performance of the final model trained on the full dataset. Thus, the evaluation corresponds to the performance of the partially trained model rather than that of the final model.

This raises two issues:

🔄 Cross-Validation

We consider k-fold cross-validation, which partitions the data into k equal-sized folds. For each fold, a model is trained on the remaining k-1 folds and evaluated on the held-out fold. This process ensures that each observation is used for both training and evaluation, but in different models. The final performance estimate is the average of fold-specific evaluation metrics and serves as a proxy for the expected performance of the model that would be trained on the full dataset.

Thus, while cross-validation uses all data for evaluation, it only evaluates models trained on partial data. Crucially, the final model trained on the full dataset is never evaluated directly; its performance is approximated by averaging over surrogate models trained on subsets.
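To make the contrast concrete, here is a generic base-R sketch (ordinary regression scored by mean squared error, purely for illustration): both the sample-split score and the cross-validated score come from surrogate models, while the model refit on all the data is never scored directly.

set.seed(42)
df <- data.frame(x = rnorm(200))
df$y <- 2 * df$x + rnorm(200)

# Sample-splitting: evaluate one model trained on 80% of the data
idx <- sample(nrow(df), 0.8 * nrow(df))
fit_split <- lm(y ~ x, data = df[idx, ])
mse_split <- mean((df$y[-idx] - predict(fit_split, df[-idx, ]))^2)

# 5-fold cross-validation: average the scores of 5 surrogate models
folds <- sample(rep(1:5, length.out = nrow(df)))
mse_cv <- mean(sapply(1:5, function(k) {
  fit_k <- lm(y ~ x, data = df[folds != k, ])
  mean((df$y[folds == k] - predict(fit_k, df[folds == k, ]))^2)
}))

# The final model is refit on all the data, but neither score above evaluates it directly
fit_final <- lm(y ~ x, data = df)
c(mse_split = mse_split, mse_cv = mse_cv)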

📈 Cram

Cram departs from these approaches by directly targeting the performance of the final model trained on the entire dataset, casting evaluation as a statistical estimation problem: it estimates the population-level performance (e.g., expected outcome or loss) that the model would achieve if deployed.

Specifically:
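As a rough sketch of the underlying idea (our notation, following the Cram methodology rather than this article): the data are split into T batches, a sequence of policies pi_1, ..., pi_T is learned on cumulative batches, and the target quantity telescopes into increments:

\Delta(\pi_T; \pi_0) = V(\pi_T) - V(\pi_0) = \sum_{t=1}^{T} \left[ V(\pi_t) - V(\pi_{t-1}) \right]

Each increment involves a policy that was trained only on earlier batches, so it can be evaluated on later batches the policy has not yet seen; aggregating the increment estimates yields an estimate, with standard errors, for the final policy trained on all the data. Roughly speaking, this is how Cram evaluates the final model while still using every observation for learning.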

🧠 Key Distinction Summary: What Is Being Evaluated?

All methods aim to estimate the generalization performance of the model trained on the full dataset (denoted as final model in the table below for readability). However, they differ in the models trained during evaluation, the data used for evaluation, and how the final performance estimate is constructed.

| Method | Evaluation Models Trained | Evaluation Data Used | Evaluation Mechanism |
|---|---|---|---|
| Sample-Splitting | One model trained on a subset | Held-out subset | Empirical performance on the test set |
| Cross-Validation | k models trained on different subsets | Entire data (across folds) | Average of fold-specific evaluation metrics |
| Cram | Sequence of models trained on cumulative batches | Entire data | Statistical estimation of the final model's* generalization performance; provides confidence intervals and inference |

*final model: the model trained on the full dataset
