| Title: | Rashomon Set of Optimal Trees |
| Version: | 0.1.1 |
| Description: | Implements a general framework for globally optimizing user-specified objective functionals over interpretable binary weight functions represented as sparse decision trees, called ROOT (Rashomon Set of Optimal Trees). It searches over candidate trees to construct a Rashomon set of near-optimal solutions and derives a summary tree highlighting stable patterns in the optimized weights. ROOT includes a built-in generalizability mode for identifying subgroups in trial settings for transportability analyses (Parikh et al. (2025) <doi:10.1080/01621459.2025.2495319>). |
| License: | MIT + file LICENSE |
| Encoding: | UTF-8 |
| RoxygenNote: | 7.3.3 |
| Suggests: | mlbench, testthat (≥ 3.0.0), knitr, rmarkdown, ragg |
| Config/testthat/edition: | 3 |
| Imports: | MASS, rpart, gbm, stats, withr, rpart.plot |
| URL: | https://github.com/peterliu599/ROOT |
| BugReports: | https://github.com/peterliu599/ROOT/issues |
| VignetteBuilder: | knitr |
| Depends: | R (≥ 3.5) |
| LazyData: | true |
| NeedsCompilation: | no |
| Packaged: | 2026-03-05 19:04:35 UTC; yh785 |
| Author: | Yiren Hou |
| Maintainer: | Peter Liu <bliu68@jh.edu> |
| Repository: | CRAN |
| Date/Publication: | 2026-03-10 11:40:12 UTC |
Rashomon Set of Optimal Trees (ROOT) for Functional Optimization
Description
ROOT (Rashomon Set of Optimal Trees) is a general-purpose functional
optimization algorithm that learns interpretable, tree-structured binary
weight functions w(X) \in \{0, 1\}. Given a dataset D_n and a
global objective function L(D_n, w), ROOT searches over the space of
decision trees to find weight assignments that minimize the objective function.
Usage
ROOT(
data,
global_objective_fn = NULL,
generalizability_path = FALSE,
leaf_proba = 0.25,
seed = NULL,
num_trees = 10,
vote_threshold = 2/3,
explore_proba = 0.05,
feature_est = "Ridge",
feature_est_args = list(),
top_k_trees = FALSE,
k = 10,
cutoff = "baseline",
max_depth = 8L,
min_leaf_n = 2L,
max_rejects_per_node = 10L,
verbose = FALSE
)
Arguments
data |
A data.frame containing the dataset. In general optimization mode ( In generalizability mode ( |
global_objective_fn |
A function with signature
|
generalizability_path |
|
leaf_proba |
A numeric tuning parameter that increases the chance
a node stops splitting by selecting a synthetic |
seed |
An optional numeric seed for reproducibility. |
num_trees |
An integer number of trees to grow. More trees explore the solution space more thoroughly. Default 10. |
vote_threshold |
Controls how per-observation votes from the Rashomon
set trees are aggregated into the final binary weight
|
explore_proba |
A numeric giving the exploration probability at
leaves. With probability |
feature_est |
Either |
feature_est_args |
A list of additional arguments passed to a
user-supplied |
top_k_trees |
|
k |
An integer giving the number of top trees when
|
cutoff |
A numeric or |
max_depth |
Maximum depth of each tree grown during the forest
construction stage. A node at |
min_leaf_n |
Minimum number of observations required in a node for
splitting to be attempted. If a node contains fewer than
|
max_rejects_per_node |
Maximum number of consecutive rejected splits
(splits that do not improve the objective) allowed at a single node
before the node is forced to become a leaf. This prevents infinite
recursion in pathological cases. Default |
verbose |
|
Value
An object of class "ROOT" (a list) with elements:
D_rash |
Data frame containing the Rashomon-set tree votes and the
final aggregated weight |
D_forest |
Data frame with all forest-level working columns. |
w_forest |
List of per-tree results from the tree-building routine. |
rashomon_set |
Integer vector of indices identifying which trees were selected into the Rashomon set. |
global_objective_fn |
The objective function used. |
f |
The characteristic (summary) tree fitted to |
testing_data |
Data frame of observations used for optimization
(trial units when |
estimate |
(Only when |
generalizability_path |
Logical flag echoing the input. |
The optimization problem
ROOT solves the functional optimization problem:
w^* \in \arg\min_w L(D_n, w)
where w: \mathbb{R}^p \to \{0, 1\} maps a p-dimensional
covariate vector to a binary include/exclude decision. The key challenge is
that, unlike the losses optimized by standard tree algorithms, the global loss
L(D_n, w) is not decomposable as a sum of losses over
disjoint subsets of the data. This means conventional greedy,
divide-and-conquer tree-building strategies do not apply. ROOT addresses
this through a randomization-based tree construction with an
explore-exploit strategy.
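To see why the loss does not decompose, consider an objective of the same form as my_objective in the Examples below. The sketch evaluates it on six units and on two halves of them. Because sum(w) sits in the denominator, every unit's weight changes every other unit's contribution, so the whole-data loss is not the sum of the per-subset losses. The vsq values and the helper name obj are illustrative only.

```r
# Illustrative per-unit variance proxies (made-up numbers)
vsq <- c(1, 4, 9, 1, 4, 9)
w   <- c(1, 1, 0, 1, 1, 0)   # include units 1, 2, 4 and 5

# Same form as my_objective in the Examples: sqrt(sum(w * vsq) / sum(w)^2)
obj <- function(w, vsq) {
  if (sum(w) == 0) return(Inf)
  sqrt(sum(w * vsq) / sum(w)^2)
}

whole <- obj(w, vsq)                                      # loss on all six units
parts <- obj(w[1:3], vsq[1:3]) + obj(w[4:6], vsq[4:6])    # sum of losses on two halves

cat("whole:", whole, " sum of parts:", parts, "\n")
# whole = sqrt(10)/4 (about 0.79), sum of parts = sqrt(5) (about 2.24)
```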
How ROOT works
The algorithm proceeds in several stages:
- Feature importance estimation: Split probabilities are estimated using Ridge regression, Gradient Boosting Machine (GBM), or a user-supplied function, biasing the search toward covariates likely to be informative.
- Stochastic tree construction: num_trees trees are grown. At each internal node, a feature is drawn according to the estimated split probabilities (or a "leaf" token is drawn, terminating the branch). Splits are made at the midpoint of the selected feature's observed range, (min + max)/2. An explore-exploit strategy assigns leaf weights: with probability explore_proba a random weight is chosen; otherwise the greedy optimal weight (the one that yields the lower global objective) is used.
- Rashomon set selection: Trees are ranked by their global objective values. The top-k trees (or all trees below a cutoff) form the Rashomon set: a collection of near-optimal but potentially different models, each providing a characterization of the optimal weight function.
- Aggregation: Per-observation votes from the Rashomon-set trees are combined by voting, with the rule controlled by vote_threshold (default 2/3), to produce the final weight vector w_opt.
- Characteristic tree: A single summary decision tree is fitted to the aggregated w_opt assignments, providing a concise, interpretable description of the weight function.
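The explore-exploit leaf assignment can be sketched as follows. This is a hypothetical helper (assign_leaf_weight is not a package function). Following the split_node() description of explore_proba as the probability of "flipping the exploit choice at a leaf", exploration is modeled here as flipping the greedy 0/1 choice; breaking ties toward 1 is an arbitrary assumption.

```r
# Hypothetical sketch of explore-exploit leaf-weight assignment; names are illustrative.
# D: data frame holding the working weights (column w) used by the objective.
# leaf_rows: row indices of the observations that fall in the current leaf.
assign_leaf_weight <- function(D, leaf_rows, objective, explore_proba = 0.05) {
  # Global loss if every unit in this leaf gets weight `val` (D is modified locally)
  loss_for <- function(val) {
    D$w[leaf_rows] <- val
    objective(D)
  }
  # Exploit: keep the weight with the lower global loss (ties go to 1)
  greedy <- if (loss_for(1) <= loss_for(0)) 1 else 0
  # Explore: with probability explore_proba, flip the greedy choice
  if (runif(1) < explore_proba) greedy <- 1 - greedy
  D$w[leaf_rows] <- greedy
  D
}

# Objective of the same form as my_objective in the Examples
objective <- function(D) {
  if (sum(D$w) == 0) return(Inf)
  sqrt(sum(D$w * D$vsq) / sum(D$w)^2)
}

D <- data.frame(w = 1, vsq = c(1, 1, 1, 100, 100))
set.seed(1)
D_new <- assign_leaf_weight(D, leaf_rows = 4:5, objective = objective, explore_proba = 0)
D_new$w   # 1 1 1 0 0: excluding the high-variance leaf lowers the loss
```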
Generalizability mode
When generalizability_path = TRUE, ROOT implements the methodology
of Parikh et al. (2025) for characterizing underrepresented subgroups in
trial-to-target generalizability analyses. In this mode:
- data must contain columns Y (outcome), Tr (treatment, 0/1), and S (sample indicator, 1 = trial, 0 = target). ROOT internally computes transportability scores based on inverse probability weighting (IPW), estimates the selection model P(S = 1 \mid X), and constructs Horvitz-Thompson-style influence scores.
- The default objective minimizes the variance of the weighted target average treatment effect (WTATE) estimator. This objective accounts for both the selection odds (trial participation probability) and treatment effect heterogeneity, so that subgroups are flagged as underrepresented only when they both lack trial representation and exhibit effect modification.
The output includes the unweighted sample average treatment effect (SATE) and the WTATE with standard errors.
See characterizing_underrep for a higher-level wrapper that
additionally produces a leaf-level summary table, and
vignette("generalizability_path_example") for a worked example.
General optimization mode
When generalizability_path = FALSE, ROOT operates as a general
functional optimizer. The user supplies any data.frame and
(optionally) a custom global_objective_fn. If no objective is
supplied, ROOT uses a default variance-based loss operating on the
vsq column (per-unit variance proxy). See
vignette("optimization_path_example") for an example.
References
Parikh H, Ross RK, Stuart E, Rudolph KE (2025). "Who Are We Missing?: A Principled Approach to Characterizing the Underrepresented Population." Journal of the American Statistical Association. doi:10.1080/01621459.2025.2495319
See Also
characterizing_underrep for a higher-level wrapper with
leaf-summary output; vignette("generalizability_path_example") for
the generalizability workflow;
vignette("optimization_path_example") for general optimization.
Examples
# --- Generalizability mode ---
data(diabetes_data, package = "ROOT")
root_fit <- ROOT(
data = diabetes_data,
generalizability_path = TRUE,
num_trees = 20,
top_k_trees = TRUE,
k = 10,
seed = 123
)
# --- General optimization mode (custom objective) ---
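# Objective: sqrt(sum(w * vsq) / sum(w)^2) over the selected units (w == 1);
# returns Inf when no unit is selected, so an empty selection is never preferred
my_objective <- function(D) {
  w <- D$w
  if (sum(w) == 0) return(Inf)
  sqrt(sum(w * D$vsq) / sum(w)^2)
}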
set.seed(123)
n_assets <- 100
# Asset features
volatility <- runif(n_assets, 0.05, 0.40) # annualised volatility
beta <- runif(n_assets, 0.5, 1.8) # market beta
sector <- sample(c("Tech", "Finance", "Energy", "Health"),
n_assets, replace = TRUE)
# Simulate returns: r_i = beta_i * r_market + epsilon_i
market <- rnorm(1000, 0.0005, 0.01)
returns_mat <- sapply(seq_len(n_assets), function(i)
beta[i] * market + rnorm(1000, 0, volatility[i] / sqrt(252))
)
# Per-asset return variance (the objective proxy ROOT will minimize)
vsq <- apply(returns_mat, 2, var)
my_data <- data.frame(
vsq = vsq,
vol = volatility,
beta = beta,
sector = as.integer(factor(sector))
)
opt_fit <- ROOT(
data = my_data,
global_objective_fn = my_objective,
num_trees = 20,
seed = 42
)
Fit a shallow decision tree to characterize learned weights w
Description
Fit a shallow decision tree to characterize learned weights w
Usage
characterize_tree(X, w, max_depth = 3)
Arguments
X |
A data frame of covariates (features). |
w |
A binary vector (0/1, TRUE/FALSE, or factor/character encoding 0/1) with exactly two classes present. |
max_depth |
Integer, the maximum tree depth (default 3). |
Value
An rpart object representing the fitted decision tree.
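Since the package imports rpart, a plausible sketch of this step is a depth-limited classification tree of w on X. The helper below is hypothetical (characterize_tree_sketch is not the package's function) and only illustrates the idea; the package's characterize_tree() may encode w or set other rpart controls differently.

```r
library(rpart)

# Hypothetical re-creation for illustration; not the package's characterize_tree()
characterize_tree_sketch <- function(X, w, max_depth = 3) {
  w_label <- factor(as.character(w))          # 0/1, TRUE/FALSE or "0"/"1" -> factor
  if (nlevels(w_label) != 2) stop("w must contain exactly two classes")
  df <- data.frame(X, w_label = w_label)
  rpart(w_label ~ ., data = df, method = "class",
        control = rpart.control(maxdepth = max_depth))
}

set.seed(42)
X <- data.frame(age = runif(200, 20, 80), bmi = runif(200, 18, 40))
w <- as.integer(X$age < 50)                   # toy weights: include younger units
fit <- characterize_tree_sketch(X, w, max_depth = 2)
print(fit)                                    # expected to split on age near 50
```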
Characterize Underrepresented Subgroups
Description
A high-level wrapper around ROOT() for identifying and
characterizing subgroups that are insufficiently represented in a
randomized controlled trial (RCT) relative to a target population. The
function returns an interpretable decision tree describing which subgroups
should be included (w(X) = 1) or excluded (w(X) = 0) from the
analysis, along with the corresponding target average treatment effect estimates.
Usage
characterizing_underrep(
data,
global_objective_fn = NULL,
generalizability_path = FALSE,
leaf_proba = 0.25,
seed = 123,
num_trees = 10,
vote_threshold = 2/3,
explore_proba = 0.05,
feature_est = "Ridge",
feature_est_args = list(),
top_k_trees = FALSE,
k = 10,
cutoff = "baseline",
max_depth = 8L,
min_leaf_n = 2L,
max_rejects_per_node = 10L,
verbose = FALSE
)
Arguments
data |
A |
global_objective_fn |
A function with signature
|
generalizability_path |
Logical. If |
leaf_proba |
A |
seed |
Random seed for reproducibility. |
num_trees |
Number of trees to grow in the ROOT forest. More trees explore the tree space more thoroughly but increase computation time. |
vote_threshold |
Controls how Rashomon-set tree votes are aggregated
into |
explore_proba |
Exploration probability in tree growth. Controls the
explore-exploit trade-off: with probability |
feature_est |
Either |
feature_est_args |
List of extra arguments passed to
|
top_k_trees |
Logical; if |
k |
Number of trees to retain when |
cutoff |
Numeric or |
max_depth |
Maximum depth of each tree grown during the forest
construction stage. A node at |
min_leaf_n |
Minimum number of observations required in a node for
splitting to be attempted. If a node contains fewer than
|
max_rejects_per_node |
Maximum number of consecutive rejected splits
(splits that do not improve the objective) allowed at a single node
before the node is forced to become a leaf. This prevents infinite
recursion in pathological cases. Default |
verbose |
Logical; if |
Value
A characterizing_underrep S3 object (a list) with:
root |
The |
combined |
The input |
leaf_summary |
A |
What does "underrepresented" mean?
In the context of generalizing treatment effects from a trial to a target population, a subgroup is considered underrepresented (or insufficiently represented) when it occupies a region of the covariate space that both (a) has limited overlap between the trial and the target population, and (b) exhibits heterogeneous treatment effects.
Formally, the contribution of a unit with covariates X = x to the
variance of the target average treatment effect (TATE) estimator depends on
both the selection ratio \ell(x) = P(S=1 \mid X=x) / P(S=0 \mid X=x)
and the conditional average treatment effect. Subgroups where \ell(x)
is small and the conditional average treatment effect deviates from the
overall TATE contribute disproportionately to estimator variance. These are
the subgroups that characterizing_underrep() identifies and
characterizes. The sample average treatment effect (SATE) is the
finite-sample analogue of the TATE.
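The link between a small selection ratio and large variance can be made concrete with arithmetic. The density ratio r(x) = P(S = 0 | x) / P(S = 1 | x) used by compute_transport_scores equals 1/\ell(x), and it multiplies the transported score v, so for a fixed bracketed term the squared score vsq scales like 1/\ell(x)^2. The probabilities below are made up for illustration.

```r
# Illustrative trial-participation probabilities P(S = 1 | X = x) for three profiles
p_trial <- c(0.5, 0.2, 0.05)
ell <- p_trial / (1 - p_trial)   # selection ratio l(x): 1, 0.25, about 0.053
r   <- 1 / ell                   # density ratio P(S = 0 | x) / P(S = 1 | x): 1, 4, 19
r^2                              # factor by which vsq grows: 1, 16, 361
```

A profile with a 5% chance of trial participation therefore inflates its squared score by a factor of about 361 relative to a profile that is equally likely to be in the trial or the target.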
The generalizability workflow
When generalizability_path = TRUE, this function implements the
two-stage approach of Parikh et al. (2025):
- Design stage: ROOT learns binary weights w(X) that minimize the variance of the weighted target average treatment effect (WTATE) estimator, subject to interpretability constraints (tree structure). The resulting decision tree characterizes which subgroups are well represented (w = 1) and which are underrepresented (w = 0).
- Analysis stage: The WTATE is estimated on the refined target population that excludes the underrepresented subgroups. This estimand trades some generality for greater precision and credibility.
The key estimands are:
- SATE (Sample Average Treatment Effect): the treatment effect for the full target population, estimated from the trial sample; it may be imprecise if certain subgroups are underrepresented. It is the finite-sample analogue of the TATE.
- WTATE (Weighted Target Average Treatment Effect): the treatment effect restricted to the sufficiently represented subpopulation, estimated with lower variance.
General optimization mode
When generalizability_path = FALSE, this function behaves as a
convenience wrapper around ROOT() for arbitrary binary weight
optimization. The user can supply a custom objective function via
global_objective_fn; ROOT will learn an interpretable tree-based
weight function minimizing that objective. See
vignette("optimization_path_example") for an example.
Data requirements
When generalizability_path = TRUE, data must contain the
following standardized columns:
- Y: numeric outcome variable.
- Tr: binary treatment indicator (0 = control, 1 = treated).
- S: binary sample indicator (1 = trial/RCT, 0 = target population).
All remaining columns are treated as pretreatment covariates X
available for splitting.
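The column requirements above can be checked before fitting. The helper below is hypothetical (check_generalizability_data is not part of the package); it encodes only the documented rules and returns the names of the columns that would be treated as covariates.

```r
# Hypothetical helper encoding the documented column requirements (not package code)
check_generalizability_data <- function(data) {
  required <- c("Y", "Tr", "S")
  missing_cols <- setdiff(required, names(data))
  if (length(missing_cols) > 0) {
    stop("missing required column(s): ", paste(missing_cols, collapse = ", "))
  }
  if (!is.numeric(data$Y)) stop("Y must be numeric")
  for (col in c("Tr", "S")) {
    # %in% returns FALSE for NA, so missing codes are rejected too
    if (!all(data[[col]] %in% c(0, 1))) stop(col, " must be coded 0/1")
  }
  # All remaining columns are pretreatment covariates available for splitting
  setdiff(names(data), required)
}

toy <- data.frame(Y = c(1.2, 0.3, 2.1), Tr = c(1, 0, 1), S = c(1, 1, 0),
                  Age45 = c(0, 1, 1))
check_generalizability_data(toy)   # "Age45"
```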
References
Parikh H, Ross RK, Stuart E, Rudolph KE (2025). "Who Are We Missing?: A Principled Approach to Characterizing the Underrepresented Population." Journal of the American Statistical Association. doi:10.1080/01621459.2025.2495319
See Also
ROOT for the underlying optimization engine;
vignette("generalizability_path_example") for a detailed worked
example of the generalizability workflow;
vignette("optimization_path_example") for general optimization mode.
Examples
# --- Generalizability analysis ---
# diabetes_data has columns Y, Tr, S, and covariates
data(diabetes_data, package = "ROOT")
char_fit <- characterizing_underrep(
data = diabetes_data,
generalizability_path = TRUE,
num_trees = 20,
top_k_trees = TRUE,
k = 10,
seed = 123
)
# View the characterization tree
plot(char_fit)
# Inspect which subgroups are underrepresented
char_fit$leaf_summary
# Treatment effect estimates (SATE and WTATE)
char_fit$root$estimate
Randomly choose a split feature based on provided probabilities
Description
Given a probability distribution over features (and possibly a "leaf" option), selects one feature at random according to those probabilities.
Usage
choose_feature(split_feature, depth)
Arguments
split_feature |
A named numeric vector of feature selection probabilities. Names should correspond to feature IDs (and may include a special "leaf" entry). |
depth |
Current tree depth (an integer). It is accepted for parity with an equivalent Python implementation and does not affect the probabilities in this implementation. |
Value
A single feature name (or "leaf") chosen randomly according to the provided probability weights.
Note
The factor 2^{(0 * depth / 4)} present in the code equals 2^0 = 1 for every depth, so it leaves the first element's weight unchanged; it is included only for parity with an equivalent Python implementation. All probabilities are normalized to sum to 1 before sampling.
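The sampling step can be sketched in base R with sample(). The name choose_feature_sketch is hypothetical; the depth factor is kept only to show that it has no effect.

```r
# Hypothetical re-implementation for illustration; not the package's internal code
choose_feature_sketch <- function(split_feature, depth) {
  p <- split_feature
  p[1] <- p[1] * 2^(0 * depth / 4)   # always multiplies by 1; mirrors the note above
  p <- p / sum(p)                    # normalize to sum to 1
  sample(names(p), size = 1, prob = p)
}

set.seed(1)
probs <- c(leaf = 0.25, age = 0.5, bmi = 0.25)
draws <- replicate(2000, choose_feature_sketch(probs, depth = 3))
round(table(draws) / length(draws), 2)   # roughly 0.25 leaf, 0.50 age, 0.25 bmi
```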
Compute transport influence scores for generalizability mode
Description
Internal helper used in the generalizability path to construct
generalized linear model (glm)-based, inverse probability weighting
(IPW)-style scores for transporting trial effects to a target population.
Usage
compute_transport_scores(data, outcome, treatment, sample)
Arguments
data |
A |
outcome |
A length-1 character string giving the name of the outcome
column in |
treatment |
A length-1 character string giving the name of the
treatment indicator column in |
sample |
A length-1 character string giving the name of the sample
indicator column in |
Details
The function treats data as a stacked dataset with a sample
indicator S (sample) taking value 1 in the randomized trial
and 0 in the target sample. It proceeds in three steps:
- Fit a sampling model P(S = 1 | X) using logistic regression on all rows of data.
- Within the trial subset S = 1, fit a treatment model P(T = 1 | X, S = 1) using logistic regression.
- For each row, form the density ratio r(X) = P(S = 0 | X) / P(S = 1 | X) and compute a Horvitz-Thompson-style transported score
v(X, T, Y) = r(X) \left[ \frac{T Y}{e(X)} - \frac{(1 - T) Y}{1 - e(X)} \right],
where e(X) = P(T = 1 | X, S = 1).
The resulting score vector v and its squared version vsq
can be used as pseudo-outcomes for tree-based search over transported
treatment effects.
Value
A list with two numeric vectors of length nrow(data):
- v: The transported influence-style score.
- vsq: The element-wise square of v, i.e., v^2.
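The three steps can be sketched with stats::glm(). This is a hypothetical, simplified version (transport_scores_sketch is not the package's function). It assumes that every column other than the outcome, treatment, and sample indicator is a covariate entering both logistic models linearly. Following the text, it evaluates the score on every row; the simulated data fill in treatment and outcome for all rows so the formula can be evaluated everywhere, which is a simplification.

```r
# Hypothetical sketch following the three documented steps; not the package's code
transport_scores_sketch <- function(data, outcome = "Y", treatment = "Tr", sample = "S") {
  covs <- setdiff(names(data), c(outcome, treatment, sample))

  # Step 1: sampling model P(S = 1 | X) on all rows
  s_fit <- glm(reformulate(covs, response = sample), family = binomial(), data = data)
  p_s <- predict(s_fit, newdata = data, type = "response")

  # Step 2: treatment model e(X) = P(T = 1 | X, S = 1) on trial rows only
  trial <- data[data[[sample]] == 1, , drop = FALSE]
  t_fit <- glm(reformulate(covs, response = treatment), family = binomial(), data = trial)
  e <- predict(t_fit, newdata = data, type = "response")

  # Step 3: density ratio r(X) and the Horvitz-Thompson-style score v
  r  <- (1 - p_s) / p_s
  Tr <- data[[treatment]]
  Y  <- data[[outcome]]
  v  <- unname(r * (Tr * Y / e - (1 - Tr) * Y / (1 - e)))
  list(v = v, vsq = v^2)
}

# Simulated stacked data (illustrative only)
set.seed(7)
n  <- 400
df <- data.frame(x1 = rnorm(n), x2 = rbinom(n, 1, 0.4))
df$S  <- rbinom(n, 1, plogis(-0.3 + 0.8 * df$x1))
df$Tr <- rbinom(n, 1, 0.5)
df$Y  <- 1 + df$Tr * (0.5 + df$x2) + rnorm(n)
scores <- transport_scores_sketch(df)
str(scores)
```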
Simulated diabetes dataset for examples
Description
A toy dataset for illustrating ROOT examples and tests.
Usage
data(diabetes_data)
Format
A data.frame with one row per individual and the columns:
- Age45: Indicator in 0/1: age >= 45.
- DietYes: Indicator in 0/1: on a diet program.
- Race_Black: Indicator in 0/1: race is Black.
- S: Sample indicator in 0/1: 1 means RCT or source, 0 means target.
- Sex_Male: Indicator in 0/1: male.
- Tr: Treatment assignment in 0/1.
- Y: Observed outcome (numeric or 0/1).
Compute the midpoint of a numeric vector
Description
Calculates midpoint = (min + max)/2 using finite values only. If there are no finite values, returns NA with a single warning.
Usage
midpoint(X)
Arguments
X |
A numeric vector. |
Value
Numeric scalar midpoint, or NA_real_ if no finite values.
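A base-R sketch of the documented behavior (finite values only; a single warning and NA_real_ otherwise). The name midpoint_sketch is illustrative, not the package's function.

```r
# Hypothetical re-implementation of the documented behavior (not package code)
midpoint_sketch <- function(X) {
  x <- X[is.finite(X)]                  # drops NA, NaN, Inf and -Inf
  if (length(x) == 0) {
    warning("no finite values; returning NA")
    return(NA_real_)
  }
  (min(x) + max(x)) / 2
}

midpoint_sketch(c(2, 10, Inf, NA))      # 6: the Inf and NA are ignored
midpoint_sketch(c(NA, -Inf))            # NA_real_, with one warning
```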
Plot the ROOT summary tree
Description
Visualizes the decision tree that characterizes the weighted subgroup
(the weight function w(X) in {0, 1}) identified by ROOT(),
using rpart.plot::prp().
Usage
## S3 method for class 'ROOT'
plot(x, ...)
Arguments
x |
A |
... |
Additional arguments passed to |
Value
No return value; the plot is drawn to the active graphics device.
Examples
ROOT.output <- ROOT(diabetes_data, generalizability_path = TRUE, seed = 123)
plot(ROOT.output)
Plot Underrepresented Population Characterization
Description
Visualizes the decision tree derived from the ROOT analysis. In generalizability
mode, it highlights which subgroups are well represented (w = 1) and which are
underrepresented (w = 0); in general optimization mode, it simply displays
w(X) in {0, 1}.
Usage
## S3 method for class 'characterizing_underrep'
plot(
x,
main = "Final Characterized Tree from Rashomon Set",
cex.main = 1.2,
...
)
Arguments
x |
A |
main |
Character string for the plot title. Default is
|
cex.main |
Numeric scaling factor for the title text size. Default is |
... |
Additional arguments passed to |
Value
NULL. The plot is drawn to the active graphics device.
Examples
char.output <- characterizing_underrep(diabetes_data, generalizability_path = TRUE, seed = 123)
plot(char.output)
Print a ROOT fit
Description
Provides a brief, human-readable summary of a ROOT object, including:
- the summary characterization tree f, and
- in generalizability mode (generalizability_path = TRUE), the unweighted and weighted estimands with their standard errors, plus an explanatory note on the weighted standard error (SE).
Usage
## S3 method for class 'ROOT'
print(x, ...)
Arguments
x |
A |
... |
Currently unused and included for S3 compatibility. |
Value
The input x is returned invisibly; the printed output is for inspection.
Abbreviations
- ATE: Average treatment effect.
- RCT: Randomized controlled trial.
- SE: Standard error.
- TATE: Target average treatment effect.
- WTATE: Weighted target average treatment effect.
- SATE: Sample average treatment effect.
When generalizability_path = TRUE, the unweighted estimand corresponds
to a SATE-type quantity and the weighted estimand to a WTATE-type
quantity for the transported target population. When generalizability_path = FALSE,
ROOT is used for general functional optimization and no causal labels
are imposed.
Examples
ROOT.output <- ROOT(diabetes_data, generalizability_path = TRUE, seed = 123)
print(ROOT.output)
Print a characterizing_underrep fit
Description
Prints the ROOT summary, which includes unweighted and (in
generalizability mode) weighted estimates with standard errors, as reported by
print.ROOT().
Usage
## S3 method for class 'characterizing_underrep'
print(x, ...)
Arguments
x |
A |
... |
Currently unused. Included for S3 compatibility. |
Details
Delegates core statistics and estimands to print(x$root).
Value
The input x is returned invisibly; the printed output is a brief, readable summary.
Abbreviations
- ATE: Average treatment effect.
- RCT: Randomized controlled trial.
- SE: Standard error.
- TATE: Target average treatment effect.
- WTATE: Weighted target average treatment effect.
- SATE: Sample average treatment effect.
Examples
char.output <- characterizing_underrep(diabetes_data, generalizability_path = TRUE, seed = 123)
print(char.output)
Reduce a feature's selection weight by half and renormalize
Description
Lowers the probability weight of a given feature by 50%, and then re-normalizes the entire probability vector.
Usage
reduce_weight(fj, split_feature)
Arguments
fj |
A feature name (character string) present in the names of |
split_feature |
A named numeric vector of probabilities for features (as used in splitting). |
Details
This is typically used when a particular feature split was rejected; the feature's probability is halved to reduce its chance of being chosen again immediately, encouraging exploration of other features. If fj is "leaf", its weight is also halved similarly.
Value
A numeric vector of the same length as split_feature, giving the updated probabilities that sum to 1.
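A minimal base-R sketch of the halve-and-renormalize step; reduce_weight_sketch is a hypothetical name, not the package's function.

```r
# Hypothetical re-implementation for illustration; not the package's internal code
reduce_weight_sketch <- function(fj, split_feature) {
  stopifnot(fj %in% names(split_feature))
  split_feature[fj] <- split_feature[fj] / 2    # halve the rejected feature (or "leaf")
  split_feature / sum(split_feature)            # renormalize to sum to 1
}

p <- c(leaf = 0.25, age = 0.5, bmi = 0.25)
out <- reduce_weight_sketch("age", p)
out   # age drops from 0.5 to 0.25 before renormalizing, so all three become 1/3
```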
Recursive split builder for weighted tree
Description
Recursively builds a weighted decision tree to optimize a global objective, using an exploration/exploitation trade-off.
Usage
split_node(
split_feature,
X,
D,
parent_loss,
depth,
explore_proba = 0.05,
choose_feature_fn = choose_feature,
reduce_weight_fn = reduce_weight,
global_objective_fn = objective_default,
max_depth = 8,
min_leaf_n = 5,
log_fn = function(...) {
},
max_rejects_per_node = 1000
)
Arguments
split_feature |
Named numeric vector of feature selection probabilities (must include "leaf"). |
X |
Data frame of current observations (includes candidate split feature columns; may include a working copy of weights |
D |
Data frame representing the global state (must include columns |
parent_loss |
Numeric, the loss value of the parent node (used to decide if a split improves the objective). |
depth |
Integer, current tree depth. |
explore_proba |
Numeric, the probability (between 0 and 1) of flipping the exploit choice at a leaf. |
choose_feature_fn |
Function to choose next feature (default |
reduce_weight_fn |
Function to penalize last-tried feature on rejected split (default |
global_objective_fn |
Function |
max_depth |
Integer max depth (stop and make leaf at this depth). |
min_leaf_n |
Integer min rows to attempt a split; else make leaf. |
log_fn |
Function for logging; default no-op. |
max_rejects_per_node |
Safety budget of rejected splits before forcing a leaf. |
Value
A list representing the (sub)tree; includes updated D and local objective.
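One split attempt can be sketched under stated assumptions: the node's rows are split at the midpoint of the chosen feature, each child receives the 0/1 weight pair that minimizes the global objective, and the split is accepted only if that loss beats parent_loss. The helper names (try_split, objective_sketch) and the exhaustive search over the four child weight pairs are illustrative choices, not the package's split_node(). According to the argument descriptions above, on rejection the real routine penalizes the feature via reduce_weight_fn and tries again, up to max_rejects_per_node times.

```r
# Objective of the same form as my_objective in the ROOT() Examples
objective_sketch <- function(D) {
  if (sum(D$w) == 0) return(Inf)
  sqrt(sum(D$w * D$vsq) / sum(D$w)^2)
}

# Hypothetical single split attempt; not the package's split_node()
try_split <- function(D, rows, feature, parent_loss, objective = objective_sketch) {
  x <- D[[feature]][rows]
  cut <- (min(x) + max(x)) / 2               # midpoint of the observed range
  left <- rows[x <= cut]
  right <- rows[x > cut]
  if (length(left) == 0 || length(right) == 0) return(list(accepted = FALSE))

  best <- list(loss = Inf, wl = NA, wr = NA)
  for (wl in 0:1) {
    for (wr in 0:1) {                         # try every 0/1 weight pair for the children
      D2 <- D
      D2$w[left] <- wl
      D2$w[right] <- wr
      loss <- objective(D2)
      if (loss < best$loss) best <- list(loss = loss, wl = wl, wr = wr)
    }
  }
  c(list(accepted = best$loss < parent_loss, cut = cut), best)
}

D <- data.frame(x = 1:10, vsq = c(rep(1, 5), rep(100, 5)), w = 1)
parent <- objective_sketch(D)                # about 2.25 with every unit included
res <- try_split(D, rows = 1:10, feature = "x", parent_loss = parent)
str(res)   # accepted TRUE, cut 5.5, wl 1, wr 0, loss about 0.45
```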
Summarize a ROOT fit
Description
Provides a readable summary of a ROOT object, including:
- the summary characterization tree f,
- whether the user supplied a custom global_objective_fn (Yes/No), and
- in generalizability mode (generalizability_path = TRUE), the unweighted and weighted estimands with their standard errors.
Usage
## S3 method for class 'ROOT'
summary(object, ...)
Arguments
object |
A |
... |
Currently unused and included for S3 compatibility. |
Value
object returned invisibly. Printed output is for inspection.
Abbreviations
- ATE: Average treatment effect.
- RCT: Randomized controlled trial.
- SE: Standard error.
- TATE: Target average treatment effect.
- WTATE: Weighted target average treatment effect.
- SATE: Sample average treatment effect.
When generalizability_path = TRUE, the unweighted estimand corresponds
to a SATE-type quantity and the weighted estimand to a WTATE-type
quantity for the transported target population. When generalizability_path = FALSE,
ROOT is used for general functional optimization and no causal labels
are imposed; the summary focuses on the tree and diagnostics.
Diagnostics
The summary also reports:
- the number of trees grown,
- the size of the Rashomon set, and
- the percentage of observations with ensemble vote w_opt == 1.
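These diagnostics, together with the Rashomon-set selection and vote aggregation described under ROOT(), can be sketched on toy inputs. The function name, the top-k rule, and in particular the rule "w_opt = 1 when the share of Rashomon-set trees voting 1 is at least vote_threshold" are assumptions for illustration; the package's exact aggregation rule and its cutoff option are not reproduced here.

```r
# Hypothetical sketch; not the package's implementation.
# losses: global objective value of each grown tree (lower is better)
# votes:  n x num_trees 0/1 matrix, one column of per-observation weights per tree
rashomon_vote_sketch <- function(losses, votes, k = 10, vote_threshold = 2/3) {
  rashomon_set <- head(order(losses), k)                   # top-k trees by objective
  share <- rowMeans(votes[, rashomon_set, drop = FALSE])   # fraction of trees voting 1
  w_opt <- as.integer(share >= vote_threshold)             # ASSUMED threshold rule
  list(
    rashomon_set = rashomon_set,
    w_opt = w_opt,
    diagnostics = c(trees_grown   = length(losses),
                    rashomon_size = length(rashomon_set),
                    pct_w_opt_1   = 100 * mean(w_opt == 1))
  )
}

losses <- c(3, 1, 2, 5)                                    # four toy trees
votes <- matrix(c(1, 1, 0,                                 # tree 1
                  1, 0, 0,                                 # tree 2
                  1, 0, 0,                                 # tree 3
                  0, 0, 0), nrow = 3)                      # tree 4
res <- rashomon_vote_sketch(losses, votes, k = 3)
res$rashomon_set   # 2 3 1
res$w_opt          # 1 0 0
```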
Examples
ROOT.output <- ROOT(diabetes_data, generalizability_path = TRUE, seed = 123)
summary(ROOT.output)
Summarize a characterizing_underrep fit
Description
Summarizes the underlying ROOT fit, including unweighted and (in
generalizability mode) weighted estimates with standard errors, as reported by
summary.ROOT(). When an annotated summary tree is available, it also previews
its terminal rules.
Usage
## S3 method for class 'characterizing_underrep'
summary(object, ...)
Arguments
object |
A |
... |
Currently unused. Included for S3 compatibility. |
Details
Delegates core statistics and estimands to summary(object$root).
Previews up to ten terminal rules when a summary tree exists.
Value
object returned invisibly. Printed output is a readable summary.
Abbreviations
- ATE: Average treatment effect.
- RCT: Randomized controlled trial.
- SE: Standard error.
- TATE: Target average treatment effect.
- WTATE: Weighted target average treatment effect.
- SATE: Sample average treatment effect.
Examples
char.output <- characterizing_underrep(diabetes_data, generalizability_path = TRUE, seed = 123)
summary(char.output)