Cram Bandit

What is cram_bandit()?

The cram_bandit() function implements the Cram methodology for on-policy statistical evaluation of contextual bandit algorithms. Unlike traditional off-policy approaches, Cram uses the same adaptively collected data for both learning and evaluation, yielding policy value estimates that are consistent, asymptotically normal, and more statistically efficient.

Introduction: Bandits, Policies, and Cram

In many machine learning settings, decisions must be made sequentially under uncertainty — for instance, recommending content, personalizing treatments, or allocating resources. These problems are often modeled as contextual bandits, where an agent:

  1. Observes context (features of the situation)
  2. Chooses an action (e.g., recommend an article)
  3. Observes a reward (e.g., whether the targeted user clicks on the article)
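
To make these three steps concrete, here is a minimal simulated interaction loop. The context distribution, the uniformly random arm choice, and the reward model are all illustrative assumptions, not anything prescribed by cram_bandit().

# A minimal sketch of the bandit interaction loop (all quantities simulated;
# none of this is part of cram_bandit() itself)
set.seed(1)
K <- 3                                   # number of arms

for (i in 1:5) {
  x <- rnorm(2)                          # 1. observe a context (two toy features)
  a <- sample(1:K, 1)                    # 2. choose an arm (here: uniformly at random)
  r <- rnorm(1, mean = 0.2 * a + x[1])   # 3. observe a reward for that arm only
  cat(sprintf("round %d: arm %d, reward %.2f\n", i, a, r))
}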

A policy is a function that maps context to a probability distribution over actions, with the goal of maximizing expected cumulative reward over time. Learning an optimal policy and evaluating its performance using the same data is difficult due to the adaptive nature of the data collection.
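
For instance, a simple policy might score each arm with a linear function of the context and pass the scores through a softmax. The function softmax_policy and the coefficient matrix theta below are made up purely for illustration; they are not part of the package.

# Illustrative policy: a softmax over linear scores of the context
softmax_policy <- function(x, theta) {
  scores <- as.vector(theta %*% x)   # one linear score per arm
  exp(scores) / sum(exp(scores))     # probabilities over arms, summing to 1
}

theta <- matrix(c(0.5, -0.2, 0.1, 0.3, -0.4, 0.2), nrow = 3)  # 3 arms, 2 features
softmax_policy(c(1, -0.5), theta)    # returns a length-3 probability vector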

This challenge becomes evident when compared with supervised learning: in supervised learning, the outcome or label \(y\) is observed for every input \(x\), allowing direct minimization of prediction error. In contrast, in a bandit setting, the outcome (reward) is only observed for the single action chosen by the agent. The agent must therefore select an action in order to reveal the reward associated with it, making data collection and learning inherently intertwined.
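
The contrast can be made concrete with a toy illustration: imagine a table of rewards for every round and every arm. Supervised learning would see the whole table, whereas the bandit only ever sees the entry for the arm it picked. The full_rewards matrix and chosen_arms vector below are simulated purely for illustration.

# Hypothetical "full" reward table for 4 rounds and 3 arms; in a bandit
# setting this table is never observed in its entirety
set.seed(2)
full_rewards <- matrix(rnorm(4 * 3), nrow = 4, ncol = 3)

chosen_arms <- c(2, 1, 3, 2)                        # arm picked in each round
observed <- full_rewards[cbind(1:4, chosen_arms)]   # only these 4 entries are revealed
observed                                            # the remaining 8 rewards stay unknown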

The Cram method addresses this challenge: it is a general statistical framework for evaluating the final learned policy of a contextual multi-armed bandit algorithm using the very dataset generated by that algorithm. Notably, Cram leverages this setting to return an estimate of how well the learned policy would perform if deployed on the entire population (the policy value), along with a confidence interval at the desired significance level.
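
Concretely, because the estimator is asymptotically normal, the reported interval takes the standard Wald form \(\hat{V} \pm z_{1-\alpha/2} \, \widehat{\text{se}}(\hat{V})\), where \(\hat{V}\) is the policy value estimate, \(\widehat{\text{se}}(\hat{V})\) its estimated standard error, and \(z_{1-\alpha/2}\) the standard normal quantile (for \(\alpha = 0.05\), \(z_{0.975} \approx 1.96\)); this is the interval shown in the example output below.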

Understanding the inputs

Many contextual bandit algorithms update their policies every few rounds instead of at every step — this is known as the batched setting. For example, if the batch size is \(B = 5\), the algorithm collects 5 new samples before updating its policy. This results in a sequence of policies \(\hat{\pi}_1, \hat{\pi}_2, ..., \hat{\pi}_T\), where \(T\) is the number of batches.
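
A small runnable sketch of this batched loop is shown below. The policy representation (a plain probability vector over arms) and the softmax-over-average-reward update are toy stand-ins for whatever bandit algorithm is actually being evaluated.

# Toy batched bandit run: the policy is a probability vector over arms,
# updated only after every batch of B observations
set.seed(123)
K <- 3; B <- 5; T_batches <- 10

policy   <- rep(1 / K, K)                 # start from the uniform policy
policies <- matrix(NA, T_batches, K)      # one row per batch: pi_hat_1, ..., pi_hat_T
arm      <- reward <- numeric(T_batches * B)

for (t in 1:T_batches) {
  policies[t, ] <- policy                 # policy in force during batch t
  idx <- ((t - 1) * B + 1):(t * B)
  arm[idx]    <- sample(1:K, B, replace = TRUE, prob = policy)
  reward[idx] <- rnorm(B, mean = 0.2 * arm[idx])   # toy reward model
  # Toy update: softmax over each arm's average observed reward so far
  means <- sapply(1:K, function(k) {
    obs <- reward[1:(t * B)][arm[1:(t * B)] == k]
    if (length(obs) == 0) 0 else mean(obs)
  })
  policy <- exp(means) / sum(exp(means))
}

After the loop, policies holds the sequence \(\hat{\pi}_1, ..., \hat{\pi}_T\), while arm and reward hold the \(T \times B\) observations discussed next.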

In total, we observe \(T \times B\) data points, each consisting of:

  1. the context observed at that time step
  2. the arm selected by the policy in force at that time
  3. the reward received for the selected arm

Cram supports the batched setting of bandit algorithms to allow for flexible use. Note that one can still set \(B = 1\) if performing policy updates after each observation.

Thus, Cram bandit takes as inputs:

  1. pi: a 3D array of policy probabilities, with one entry per data point (first dimension), per learned policy \(\hat{\pi}_t\) (second dimension), and per arm (third dimension), as illustrated in the example below
  2. arm: a vector giving the arm actually selected at each time step
  3. reward: a vector of the corresponding observed rewards
  4. batch: the batch size \(B\)
  5. alpha: the significance level used for the confidence interval

Cram bandit returns:

  1. raw_results: a data frame containing the policy value estimate, its standard error, and the lower and upper bounds of the confidence interval
  2. interactive_table: the same results rendered as an interactive table for display


Example: Use cram_bandit() on simulated data with a batch size of 1

# Set random seed for reproducibility
set.seed(42)

# Define parameters
T <- 100  # Number of timesteps
K <- 4    # Number of arms

# Simulate a 3D array `pi` of shape (T, T, K)
# - First dimension: Individuals (context Xj)
# - Second dimension: Time steps (pi_t)
# - Third dimension: Arms (depth)
pi <- array(runif(T * T * K, 0.1, 1), dim = c(T, T, K))

# Normalize probabilities so that each row sums to 1 across arms
for (t in 1:T) {
  for (j in 1:T) {
    pi[j, t, ] <- pi[j, t, ] / sum(pi[j, t, ])  
  }
}

# Simulate arm selections (randomly choosing an arm)
arm <- sample(1:K, T, replace = TRUE)

# Simulate rewards (assume normally distributed rewards)
reward <- rnorm(T, mean = 1, sd = 0.5)

result <- cram_bandit(pi, arm, reward, batch=1, alpha=0.05)
result$raw_results
#>                        Metric   Value
#> 1       Policy Value Estimate 0.67621
#> 2 Policy Value Standard Error 0.04394
#> 3       Policy Value CI Lower 0.59008
#> 4       Policy Value CI Upper 0.76234
result$interactive_table
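
If raw_results is the data frame printed above (with the Metric and Value columns shown), individual quantities can be pulled out with standard data-frame indexing, for example:

# Extract the point estimate from the results table shown above
estimate <- result$raw_results$Value[result$raw_results$Metric == "Policy Value Estimate"]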

References