Introduction to LDATree

LDATree is an R package for fitting classification trees. If you are unfamiliar with classification trees, a tutorial on the traditional CART algorithm and its R implementation, rpart, is a good place to start.

Compared to other similar trees, LDATree sets itself apart in the following ways:

Build the Tree

Currently, LDATree offers two methods for building a tree:

  1. Direct stopping: the tree stops growing as soon as specific stopping conditions are met.

  2. Pruning: a larger tree is grown first and then pruned back, with cross-validation used to decide how much to prune.

library(LDATree)

# Build a tree using the direct-stopping rule
fit <- Treee(Species ~ ., data = iris)

# Build a tree pruned by cross-validation
set.seed(443)
fitCV <- Treee(Species ~ ., data = iris, pruneMethod = "CV")
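To make the "grow large, then prune with cross-validation" idea concrete, the sketch below does the CART-style equivalent with rpart, which ships with standard R installations. It is only an illustration of the general technique; it is not how LDATree prunes, and the settings (`cp = 0`, `minsplit = 2`, 10-fold `xval`) are choices made for this example.

```r
# CART-style pruning with rpart: an illustration of the general idea,
# NOT LDATree's pruning routine.
library(rpart)

set.seed(443)  # rpart's internal cross-validation assigns folds at random

# Grow a deliberately large tree: cp = 0 disables complexity-based stopping,
# minsplit = 2 lets nodes split down to very small sizes
big <- rpart(Species ~ ., data = iris, method = "class",
             control = rpart.control(cp = 0, minsplit = 2, xval = 10))

# cptable lists each candidate subtree with its cross-validated error (xerror)
cp_table <- big$cptable
best_cp <- cp_table[which.min(cp_table[, "xerror"]), "CP"]

# Prune back to the subtree with the smallest cross-validated error
pruned <- prune(big, cp = best_cp)
print(pruned)
```

Choosing the minimum `xerror` is the simplest rule; the one-standard-error rule (using the `xstd` column) is a common, more conservative alternative.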

Plot the Tree

LDATree offers two plotting methods:

  1. Call plot() on the fitted tree to view the full tree diagram.

  2. To view the plot for a single node of interest, pass the (training) data and the node index to plot().

Overall Plot

# View the overall tree
plot(fit)  # Tip: the diagram is interactive; try clicking on the nodes

Individual Plots

# Three types of individual plots
# 1. Scatter plot on first two LD scores
plot(fit, data = iris, node = 1)


# 2. Density plot on the first LD score
plot(fit, data = iris, node = 3)
#> Warning: Groups with fewer than two data points have been dropped.
#> Warning in max(ids, na.rm = TRUE): no non-missing arguments to max; returning
#> -Inf


# 3. A text message, when every observation in the node gets the same prediction
plot(fit, data = iris, node = 5)
#> [1] "Every observation in this node is predicted to be virginica"
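The first plot type uses linear discriminant (LD) scores. As a rough illustration of what those scores are, the sketch below computes them with MASS::lda on the full iris data and draws the same kind of scatter plot. This is an assumption-laden stand-in, not LDATree's plotting code; LDATree fits its own LDA (via GSVD) within each node.

```r
# Illustration of "LD scores" using MASS::lda; NOT LDATree's plotting code.
library(MASS)

lda_fit <- lda(Species ~ ., data = iris)

# With 3 classes there are at most 2 discriminant directions;
# predict() returns the scores in the $x matrix (columns LD1 and LD2)
scores <- predict(lda_fit)$x

# Scatter plot of the observations on the first two LD scores
plot(scores[, "LD1"], scores[, "LD2"],
     col = as.integer(iris$Species), pch = 19,
     xlab = "LD1", ylab = "LD2",
     main = "iris on the first two LD scores (MASS::lda)")
legend("topright", legend = levels(iris$Species),
       col = seq_along(levels(iris$Species)), pch = 19)
```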

Make Predictions

# Prediction only
predictions <- predict(fit, iris)
head(predictions)
#> [1] "setosa" "setosa" "setosa" "setosa" "setosa" "setosa"
# A more informative prediction
predictions <- predict(fit, iris, type = "all")
head(predictions)
#>   response node setosa versicolor virginica
#> 1   setosa   13      0          0         0
#> 2   setosa   13      0          0         0
#> 3   setosa   13      0          0         0
#> 4   setosa   13      0          0         0
#> 5   setosa   13      0          0         0
#> 6   setosa   13      0          0         0
# Experimental feature: LDAGrove
# If the tree was pruned with CV, you can also try an ensemble prediction
predictions <- predict(fitCV, iris, type = "grove")
head(predictions)
#> [1] "setosa" "setosa" "setosa" "setosa" "setosa" "setosa"

Missing Values

LDATree handles missing values automatically, so you do not need to specify anything unless you want to change the default. By default, missing values in a numerical variable are filled with that variable's mean, and a missing-value flag is added. Missing values in a factor variable are assigned to a new level. For more options, see help(Treee).

# Randomly set 20 values in each of the four predictors to NA
set.seed(443)
irisMissing <- iris
for (i in 1:4) irisMissing[sample(150, 20), i] <- NA
fitMissing <- Treee(Species ~ ., data = irisMissing)
plot(fitMissing, data = irisMissing, node = 1)
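The default strategy described above (mean fill plus a flag for numeric columns, a new level for factors) can be sketched in base R as follows. The helper `impute_with_flags()`, the `_missing` suffix, and the `"(missing)"` level name are all hypothetical choices for this illustration; this is not LDATree's internal code.

```r
# Hypothetical helper sketching the default missing-value strategy;
# NOT LDATree's implementation.
impute_with_flags <- function(df, new_level = "(missing)") {
  out <- df
  for (nm in names(df)) {
    x <- df[[nm]]
    miss <- is.na(x)
    if (!any(miss)) next
    if (is.numeric(x)) {
      # Fill numeric NAs with the column mean and record where they were
      x[miss] <- mean(x, na.rm = TRUE)
      out[[nm]] <- x
      out[[paste0(nm, "_missing")]] <- as.integer(miss)
    } else if (is.factor(x)) {
      # Give factor NAs their own level
      levels(x) <- c(levels(x), new_level)
      x[miss] <- new_level
      out[[nm]] <- x
    }
  }
  out
}

# Usage: knock out two values and inspect the result
dfMissing <- iris
dfMissing[c(1, 5), "Sepal.Length"] <- NA
head(impute_with_flags(dfMissing), 3)
```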

LDA/GSVD

Because LDATree re-implements LDA/GSVD (linear discriminant analysis based on the generalized singular value decomposition) for model fitting, it also exports a by-product: the ldaGSVD() function. Feel free to try it and compare it with MASS::lda.

fitLDAgsvd <- ldaGSVD(Species~., data = iris)
predictionsLDAgsvd <- predict(fitLDAgsvd, newdata = iris)
mean(predictionsLDAgsvd == iris$Species)
#> [1] 0.98
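For a side-by-side comparison, the same training-set accuracy can be computed with MASS::lda. The sketch below uses only MASS functions; note that predict() for an lda fit returns a list, so the class labels are taken from its `$class` element.

```r
# Baseline for comparison with ldaGSVD: classical LDA from MASS
library(MASS)

fitLDA <- lda(Species ~ ., data = iris)
predictionsLDA <- predict(fitLDA, newdata = iris)$class

# Training (resubstitution) accuracy
mean(predictionsLDA == iris$Species)

# Confusion matrix: rows are predictions, columns are true species
table(Predicted = predictionsLDA, Actual = iris$Species)
```

Both numbers are training-set accuracies, so they are optimistic; cross-validation gives a fairer comparison.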