Before loading the package, we should allocate enough memory for Java. Here we allocate 10GB of memory for Java.
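The allocation step itself is not shown above. A minimal sketch, assuming the usual rJava mechanism (the `java.parameters` option, which only takes effect if it is set before the JVM starts, i.e. before the package is loaded); the `requireNamespace()` guard is an addition here so that the snippet also runs where iBART is not installed:

```r
# Allocate 10GB of memory for the JVM; this must run before iBART (and rJava) is loaded
options(java.parameters = "-Xmx10g")

# Load the package only after the option has been set
if (requireNamespace("iBART", quietly = TRUE)) {
  library(iBART)
}
```

If the session has already started a JVM, restarting R before running these lines is likely necessary for the new limit to apply.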
In this vignette, we will run iBART on the complex model described in Section 3.4 of the paper, i.e., the data-generating model is $y = 15\{\exp(x_1) - \exp(x_2)\}^2 + 20\sin(\pi x_3 x_4) + \varepsilon$, $\varepsilon \sim \mathcal{N}_n(0, \sigma^2 I)$. The primary features are $X = (x_1, \ldots, x_p)$, where $x_1,...,x_p \overset{\text{iid}}\sim\text{Unif}_n(-1,1)$. We will use the following setting: $n = 250$, $p = 10$, and $\sigma = 0.5$. The goal in OIS is to identify the 2 true descriptors, $f_1(X) = \{\exp(x_1) - \exp(x_2)\}^2$ and $f_2(X) = \sin(\pi x_3 x_4)$, using only $(y, X)$ as input. Let’s generate the primary features $X$ and the response variable $y$.
#### Simulation Parameters ####
n <- 250 # Change n to 100 here to reproduce result in Supplementary Materials A.2.3
p <- 10 # Number of primary features
#### Generate Data ####
X <- matrix(runif(n * p, min = -1, max = 1), nrow = n, ncol = p)
colnames(X) <- paste("x.", seq(from = 1, to = p, by = 1), sep = "")
y <- 15 * (exp(X[, 1]) - exp(X[, 2]))^2 + 20 * sin(pi * X[, 3] * X[, 4]) + rnorm(n, mean = 0, sd = 0.5)
Note that the input data to iBART are only $y$ and $X = (x_1, \ldots, x_{10})$; iBART needs to construct and select the true descriptors $f_1(X)$ and $f_2(X)$ from these inputs alone.
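As a point of reference, the following sketch fits an "oracle" least-squares model that is handed the two true descriptors directly. It is not part of the iBART workflow; the seed and the names `f1`, `f2`, and `oracle_fit` are assumptions made for this illustration. The data are regenerated so that the snippet is self-contained.

```r
set.seed(123)  # arbitrary seed, used for this illustration only
n <- 250
p <- 10
X <- matrix(runif(n * p, min = -1, max = 1), nrow = n, ncol = p)
y <- 15 * (exp(X[, 1]) - exp(X[, 2]))^2 + 20 * sin(pi * X[, 3] * X[, 4]) +
  rnorm(n, mean = 0, sd = 0.5)

# True descriptors, known here only because the data are simulated
f1 <- (exp(X[, 1]) - exp(X[, 2]))^2
f2 <- sin(pi * X[, 3] * X[, 4])

# Oracle fit: the best a descriptor search could hope to recover
oracle_fit <- lm(y ~ f1 + f2)
round(coef(oracle_fit), 2)
```

The estimated coefficients of `f1` and `f2` should land close to 15 and 20, which gives a benchmark for judging the model that iBART selects without knowing the descriptors in advance.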
At Iteration 0 (base iteration), iBART determines which of the primary features, $(x_1, \ldots, x_{10})$, are useful and applies operators only to the useful ones. If successful, iBART should keep $(x_1, \ldots, x_4)$ in the loop and discard $(x_5, \ldots, x_{10})$. Let’s run iBART for 1 iteration (Iteration 0 + Iteration 1) and examine its outputs.
#### iBART ####
iBART_sim <- iBART(X = X, y = y,
                   name = colnames(X),
                   num_burn_in = 5000,                   # lower value for faster run
                   num_iterations_after_burn_in = 1000,  # lower value for faster run
                   num_permute_samples = 20,             # lower value for faster run
                   opt = c("unary"), # only apply unary operators after base iteration
                   sin_cos = TRUE,
                   apply_pos_opt_on_neg_x = FALSE,
                   Lzero = TRUE,
                   K = 4,
                   standardize = FALSE,
                   seed = 123)
#> Start iBART descriptor generation and selection...
#> Iteration 1
#> iBART descriptor selection...
#> avg..........null....................
#> Constructing descriptors using unary operators...
#> Building X.unary... Initial p = 4; New p = 28
#> BART iteration done!
#> LASSO descriptor selection...
#> L-zero regression...
#> Total time: 23.2862904071808 secs
`iBART()` returns a list object that contains many interesting outputs; see `?iBART::iBART` for a full list of return values. Here we focus on 2 return values:

- `iBART_sim$descriptor_names`: the descriptors selected by iBART
- `iBART_sim$iBART_model`: the selected model, a `cv.glmnet` object

We can use the iBART model the same way we would use a `glmnet` model. For instance, we can print out the coefficients using `coef()`.
# iBART selected descriptors
iBART_sim$descriptor_names
#> [1] "(x.1)^2" "(x.2)^2" "(x.4)^2" "exp(x.1)" "exp(x.2)"
#> [6] "exp(x.3)" "exp(x.4)" "sin(pi*x.1)" "sin(pi*x.2)" "sin(pi*x.3)"
#> [11] "sin(pi*x.4)" "cos(pi*x.3)" "(1/x.1)" "(1/x.2)" "(1/x.3)"
#> [16] "(1/x.4)" "abs(x.2)"
# iBART model
coef(iBART_sim$iBART_model, s = "lambda.min")
#> 29 x 1 sparse Matrix of class "dgCMatrix"
#> s1
#> (Intercept) -8.98377149
#> x.1 .
#> x.2 .
#> x.3 .
#> x.4 .
#> (x.1)^2 18.48979389
#> (x.2)^2 8.44944371
#> (x.3)^2 .
#> (x.4)^2 -2.91764266
#> exp(x.1) 5.51582288
#> exp(x.2) 9.74121269
#> exp(x.3) -0.49444789
#> exp(x.4) -5.47704491
#> sin(pi*x.1) -2.22769248
#> sin(pi*x.2) -7.06001109
#> sin(pi*x.3) -1.39995518
#> sin(pi*x.4) 5.10808705
#> cos(pi*x.1) .
#> cos(pi*x.2) .
#> cos(pi*x.3) 0.23919051
#> cos(pi*x.4) .
#> (1/x.1) -0.05683143
#> (1/x.2) 0.04017341
#> (1/x.3) -0.01126814
#> (1/x.4) -0.03267275
#> abs(x.1) .
#> abs(x.2) 2.96458665
#> abs(x.3) .
#> abs(x.4) .
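Because the returned model is a `cv.glmnet` object, the usual glmnet tools apply to it. The sketch below illustrates this on a small toy fit rather than on `iBART_sim` itself: the data `Z` and `w` and the object `toy_fit` are hypothetical and exist only for this example. It shows how to list the nonzero coefficients of a sparse coefficient matrix and how to obtain predictions.

```r
library(glmnet)

set.seed(1)  # arbitrary seed for the toy data
Z <- matrix(rnorm(100 * 5), nrow = 100, ncol = 5)
colnames(Z) <- paste0("z.", 1:5)
w <- 3 * Z[, 1] - 2 * Z[, 2] + rnorm(100, sd = 0.1)

# Cross-validated LASSO, the same class of object as iBART_sim$iBART_model
toy_fit <- cv.glmnet(Z, w)

# Convert the sparse coefficient matrix to a named vector and keep nonzero entries
b <- as.matrix(coef(toy_fit, s = "lambda.min"))[, 1]
active <- names(b)[b != 0]
active

# Predictions at the same penalty level
predict(toy_fit, newx = Z[1:3, ], s = "lambda.min")
```

The same calls should work with `iBART_sim$iBART_model` in place of `toy_fit`, provided `newx` contains the descriptor columns the model was trained on.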
`iBART_sim$descriptor_names` contains the names of the selected descriptors at the last iteration (Iteration 1), and `coef(iBART_sim$iBART_model, s = "lambda.min")` shows the input descriptors at the last iteration (Iteration 1) and their coefficients. Notice that the first 4 descriptors in `coef(iBART_sim$iBART_model, s = "lambda.min")` are $x_1, \ldots, x_4$. This indicates that iBART discarded $x_5, \ldots, x_{10}$ and kept $x_1, \ldots, x_4$ in the loop at Iteration 0.
At Iteration 1, iBART applied unary operators to $x_1, \ldots, x_4$, yielding $x_i$, $x_i^2$, $\exp(x_i)$, $\sin(\pi x_i)$, $\cos(\pi x_i)$, $x_i^{-1}$, $|x_i|$, for $i = 1, 2, 3, 4$. The descriptors selected from these include the 2 active intermediate descriptors, $\exp(x_1)$ and $\exp(x_2)$, which are needed to generate $f_1(X) = \{\exp(x_1) - \exp(x_2)\}^2$.
This is very promising. Note that we don’t have $\sqrt{x_i}$ and $\log(x_i)$ here because $\sqrt{\cdot}$ and $\log(\cdot)$ are only defined if the $x_i$’s are positive. We can override this by setting `apply_pos_opt_on_neg_x = TRUE`; this effectively generates $\sqrt{|x_i|}$ and $\log(|x_i|)$.
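The unary expansion described above can be sketched in base R. The helper `make_unary()` below is hypothetical and is not iBART's internal implementation; it only reproduces the count of 4 features times 7 operators and shows why the positive-only operators need the absolute value when the features can be negative.

```r
# Hypothetical helper, not part of iBART: apply the 7 unary operators listed above
make_unary <- function(X) {
  ops <- list(
    function(x) x,            # identity (the feature itself)
    function(x) x^2,
    function(x) exp(x),
    function(x) sin(pi * x),
    function(x) cos(pi * x),
    function(x) 1 / x,
    function(x) abs(x)
  )
  # Each operator acts elementwise on the whole matrix; bind the results column-wise
  do.call(cbind, lapply(ops, function(f) f(X)))
}

set.seed(1)  # arbitrary seed for this illustration
X4 <- matrix(runif(250 * 4, min = -1, max = 1), ncol = 4)
ncol(make_unary(X4))  # 4 features x 7 operators

# sqrt() is undefined for negative inputs and returns NaN there ...
x <- c(-0.5, 0.25)
suppressWarnings(sqrt(x))
# ... whereas the absolute-value versions are defined for every nonzero x
sqrt(abs(x))
log(abs(x))
```

The column count agrees with the `Initial p = 4; New p = 28` line in the run log above.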
To save time, we cached the result of a complete run in `data("iBART_sim", package = "iBART")`, which can be replicated by using the following code.
iBART_sim <- iBART(X = X, y = y,
                   name = colnames(X),
                   opt = c("unary", "binary", "unary"),
                   sin_cos = TRUE,
                   apply_pos_opt_on_neg_x = FALSE,
                   Lzero = TRUE,
                   K = 4,
                   standardize = FALSE,
                   seed = 123)
Let’s load the full result and see how iBART did.
load("../data/iBART_sim.rda") # load full result
iBART_sim$descriptor_names # iBART selected descriptors
#> [1] "(exp(x.1)-exp(x.2))^2" "sin(pi*(x.3*x.4))"
coef(iBART_sim$iBART_model, s = "lambda.min") # iBART model
#> 146 x 1 sparse Matrix of class "dgCMatrix"
#> s1
#> (Intercept) 0.1928037
#> x.1 .
#> x.2 .
#> x.3 .
#> x.4 .
#> exp(x.1) .
#> exp(x.2) .
#> exp(x.3) .
#> exp(x.4) .
#> (x.2+exp(x.2)) .
#> (x.1-exp(x.2)) .
#> (x.2-exp(x.1)) .
#> (exp(x.1)-exp(x.2)) .
#> (x.1*x.2) .
#> (x.2*exp(x.1)) .
#> (x.3*x.4) .
#> (exp(x.1)*exp(x.2)) .
#> (x.2/exp(x.1)) .
#> (exp(x.2)/exp(x.1)) .
#> |x.1-x.2| .
#> |x.3-x.4| .
#> |exp(x.1)-exp(x.2)| .
#> exp(x.1)^0.5 .
#> exp(x.2)^0.5 .
#> exp(x.3)^0.5 .
#> exp(x.4)^0.5 .
#> (exp(x.1)*exp(x.2))^0.5 .
#> (exp(x.2)/exp(x.1))^0.5 .
#> |x.1-x.2|^0.5 .
#> |x.3-x.4|^0.5 .
#> |exp(x.1)-exp(x.2)|^0.5 .
#> x.1^2 .
#> x.2^2 .
#> x.3^2 .
#> x.4^2 .
#> exp(x.1)^2 .
#> exp(x.2)^2 .
#> exp(x.3)^2 .
#> exp(x.4)^2 .
#> (x.2+exp(x.2))^2 .
#> (x.1-exp(x.2))^2 .
#> (x.2-exp(x.1))^2 .
#> (exp(x.1)-exp(x.2))^2 14.7643022
#> (x.1*x.2)^2 .
#> (x.2*exp(x.1))^2 .
#> (x.3*x.4)^2 .
#> (exp(x.1)*exp(x.2))^2 .
#> (x.2/exp(x.1))^2 .
#> (exp(x.2)/exp(x.1))^2 .
#> |x.1-x.2|^2 .
#> |x.3-x.4|^2 .
#> log((exp(x.1)*exp(x.2))) .
#> log((exp(x.2)/exp(x.1))) .
#> log(|x.1-x.2|) .
#> log(|x.3-x.4|) .
#> log(|exp(x.1)-exp(x.2)|) .
#> exp(exp(x.1)) .
#> exp(exp(x.2)) .
#> exp(exp(x.3)) .
#> exp(exp(x.4)) .
#> exp((x.2+exp(x.2))) .
#> exp((x.1-exp(x.2))) .
#> exp((x.2-exp(x.1))) .
#> exp((exp(x.1)-exp(x.2))) .
#> exp((x.1*x.2)) .
#> exp((x.2*exp(x.1))) .
#> exp((x.3*x.4)) .
#> exp((exp(x.1)*exp(x.2))) .
#> exp((x.2/exp(x.1))) .
#> exp((exp(x.2)/exp(x.1))) .
#> exp(|x.1-x.2|) .
#> exp(|x.3-x.4|) .
#> exp(|exp(x.1)-exp(x.2)|) .
#> sin(pi*x.1) .
#> sin(pi*x.2) .
#> sin(pi*x.3) .
#> sin(pi*x.4) .
#> sin(pi*exp(x.1)) .
#> sin(pi*exp(x.2)) .
#> sin(pi*exp(x.3)) .
#> sin(pi*exp(x.4)) .
#> sin(pi*(x.2+exp(x.2))) .
#> sin(pi*(x.1-exp(x.2))) .
#> sin(pi*(x.2-exp(x.1))) .
#> sin(pi*(exp(x.1)-exp(x.2))) .
#> sin(pi*(x.1*x.2)) .
#> sin(pi*(x.2*exp(x.1))) .
#> sin(pi*(x.3*x.4)) 19.5876303
#> sin(pi*(exp(x.1)*exp(x.2))) .
#> sin(pi*(x.2/exp(x.1))) .
#> sin(pi*(exp(x.2)/exp(x.1))) .
#> sin(pi*|x.1-x.2|) .
#> sin(pi*|x.3-x.4|) .
#> sin(pi*|exp(x.1)-exp(x.2)|) .
#> cos(pi*x.1) .
#> cos(pi*x.2) .
#> cos(pi*x.3) .
#> cos(pi*x.4) .
#> cos(pi*exp(x.1)) .
#> cos(pi*exp(x.2)) .
#> cos(pi*exp(x.3)) .
#> cos(pi*exp(x.4)) .
#> cos(pi*(x.2+exp(x.2))) .
#> cos(pi*(x.1-exp(x.2))) .
#> cos(pi*(x.2-exp(x.1))) .
#> cos(pi*(exp(x.1)-exp(x.2))) .
#> cos(pi*(x.1*x.2)) .
#> cos(pi*(x.2*exp(x.1))) .
#> cos(pi*(x.3*x.4)) .
#> cos(pi*(exp(x.1)*exp(x.2))) .
#> cos(pi*(x.2/exp(x.1))) .
#> cos(pi*(exp(x.2)/exp(x.1))) .
#> cos(pi*|x.1-x.2|) .
#> cos(pi*|x.3-x.4|) .
#> x.1^(-1) .
#> x.2^(-1) .
#> x.3^(-1) .
#> x.4^(-1) .
#> exp(x.1)^(-1) .
#> exp(x.2)^(-1) .
#> exp(x.3)^(-1) .
#> exp(x.4)^(-1) .
#> (x.2+exp(x.2))^(-1) .
#> (x.1-exp(x.2))^(-1) .
#> (x.2-exp(x.1))^(-1) .
#> (exp(x.1)-exp(x.2))^(-1) .
#> (x.1*x.2)^(-1) .
#> (x.2*exp(x.1))^(-1) .
#> (x.3*x.4)^(-1) .
#> (exp(x.1)*exp(x.2))^(-1) .
#> (x.2/exp(x.1))^(-1) .
#> (exp(x.2)/exp(x.1))^(-1) .
#> |x.1-x.2|^(-1) .
#> |x.3-x.4|^(-1) .
#> |exp(x.1)-exp(x.2)|^(-1) .
#> abs(x.1) .
#> abs(x.2) .
#> abs(x.3) .
#> abs(x.4) .
#> abs((x.2+exp(x.2))) .
#> abs((x.1-exp(x.2))) .
#> abs((x.2-exp(x.1))) .
#> abs((x.1*x.2)) .
#> abs((x.2*exp(x.1))) .
#> abs((x.3*x.4)) .
#> abs((x.2/exp(x.1))) .
Here iBART generated 145 descriptors in the last iteration, and it correctly identified the true descriptors $f_1(X)$ and $f_2(X)$ without selecting any false positives. This is very reassuring, especially since some of these descriptors are highly correlated with $f_1(X)$ or $f_2(X)$. For instance, $\tilde{f}_1(X) = |\exp(x_1) - \exp(x_2)|$ in the descriptor space is highly correlated with $f_1(X)$.
f1_true <- (exp(X[,1]) - exp(X[,2]))^2
f1_cor <- abs(exp(X[,1]) - exp(X[,2]))
cor(f1_true, f1_cor)
#> [1] 0.9517217
`iBART()` also returns other useful and interesting outputs, such as `iBART_sim$iBART_gen_size` and `iBART_sim$iBART_sel_size`. They store the dimension of the newly generated / selected descriptor space for each iteration. Let’s plot them and see how iBART uses nonparametric variable selection for dimension reduction. In each iteration, we keep the dimension of the intermediate descriptor space under $\mathcal{O}(p^2)$, leading to a progressive dimension reduction.
library(ggplot2)
df_dim <- data.frame(dim = c(iBART_sim$iBART_sel_size, iBART_sim$iBART_gen_size),
                     iter = rep(0:3, 2),
                     type = rep(c("Selected", "Generated"), each = 4))
ggplot(df_dim, aes(x = iter, y = dim, colour = type, group = type)) +
  theme(text = element_text(size = 15), legend.title = element_blank()) +
  geom_line(linewidth = 1) + # `linewidth` replaces `size` for lines as of ggplot2 3.4.0
  geom_point(size = 3, shape = 21, fill = "white") +
  geom_text(data = df_dim, aes(label = dim, y = dim + 10, group = type),
            position = position_dodge(0), size = 5, colour = "blue") +
  labs(x = "Iteration", y = "Number of descriptors") +
  scale_x_continuous(breaks = c(0, 1, 2, 3))
sessionInfo()
#> R version 4.4.2 (2024-10-31)
#> Platform: x86_64-pc-linux-gnu
#> Running under: Ubuntu 24.04.1 LTS
#>
#> Matrix products: default
#> BLAS: /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3
#> LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.26.so; LAPACK version 3.12.0
#>
#> locale:
#> [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
#> [3] LC_TIME=en_US.UTF-8 LC_COLLATE=C
#> [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
#> [7] LC_PAPER=en_US.UTF-8 LC_NAME=en_US.UTF-8
#> [9] LC_ADDRESS=en_US.UTF-8 LC_TELEPHONE=en_US.UTF-8
#> [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=en_US.UTF-8
#>
#> time zone: Etc/UTC
#> tzcode source: system (glibc)
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] ggpubr_0.6.0 ggplot2_3.5.1 iBART_1.0.2 rmarkdown_2.29
#>
#> loaded via a namespace (and not attached):
#> [1] tidyr_1.3.1 sass_0.4.9 utf8_1.2.4
#> [4] generics_0.1.3 bartMachineJARs_1.2.1 rstatix_0.7.2
#> [7] shape_1.4.6.1 lattice_0.22-6 digest_0.6.37
#> [10] magrittr_2.0.3 evaluate_1.0.1 grid_4.4.2
#> [13] iterators_1.0.14 fastmap_1.2.0 foreach_1.5.2
#> [16] jsonlite_1.8.9 glmnet_4.1-8 Matrix_1.7-1
#> [19] missForest_1.5 backports_1.5.0 Formula_1.2-5
#> [22] survival_3.7-0 gridExtra_2.3 bartMachine_1.3.4.1
#> [25] purrr_1.0.2 fansi_1.0.6 doRNG_1.8.6
#> [28] itertools_0.1-3 scales_1.3.0 codetools_0.2-20
#> [31] jquerylib_0.1.4 abind_1.4-8 cli_3.6.3
#> [34] rlang_1.1.4 cowplot_1.1.3 munsell_0.5.1
#> [37] splines_4.4.2 withr_3.0.2 cachem_1.1.0
#> [40] yaml_2.3.10 tools_4.4.2 parallel_4.4.2
#> [43] ggsignif_0.6.4 dplyr_1.1.4 colorspace_2.1-1
#> [46] rngtools_1.5.2 broom_1.0.7 buildtools_1.0.0
#> [49] vctrs_0.6.5 R6_2.5.1 lifecycle_1.0.4
#> [52] randomForest_4.7-1.2 car_3.1-3 pkgconfig_2.0.3
#> [55] rJava_1.0-11 bslib_0.8.0 pillar_1.9.0
#> [58] gtable_0.3.6 glue_1.8.0 Rcpp_1.0.13-1
#> [61] tidyselect_1.2.1 xfun_0.49 tibble_3.2.1
#> [64] sys_3.4.3 knitr_1.49 farver_2.1.2
#> [67] htmltools_0.5.8.1 carData_3.0-5 labeling_0.4.3
#> [70] maketools_1.3.1 compiler_4.4.2