Complex Model Simulation

Package Loading

Before loading the package, we should allocate enough memory for Java. Here we allocate 10GB of memory for Java.

set.seed(123)
# Allocate 10GB of memory for Java. Must be called before library(iBART)
options(java.parameters = "-Xmx10g") 
library(iBART)
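The Java heap size is read only when the Java virtual machine starts, so setting java.parameters after the JVM is already running has no effect. The snippet below is a small base-R guard of our own, not part of iBART. It warns if the rJava namespace is already loaded, which we treat as a conservative proxy for a running JVM, and then sets the option.

```r
# Sketch (not part of iBART): warn if rJava is already loaded, because a JVM
# started earlier in this session keeps its original maximum heap size.
if ("rJava" %in% loadedNamespaces()) {
  warning("rJava is already loaded; restart R before changing java.parameters.")
}
options(java.parameters = "-Xmx10g")  # request a 10GB maximum Java heap
getOption("java.parameters")
#> [1] "-Xmx10g"
```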

Complex Model

In this vignette, we run iBART on the complex model described in Section 3.4 of the paper, i.e., the data-generating model is $y = 15\{\exp(x_1) - \exp(x_2)\}^2 + 20\sin(\pi x_3 x_4) + \varepsilon$, $\varepsilon \sim \mathcal{N}_n(0, \sigma^2 I)$. The primary features are $X = (x_1, \ldots, x_p)$, where $x_1, \ldots, x_p \overset{\text{iid}}{\sim} \text{Unif}_n(-1, 1)$. We use the following setting: $n = 250$, $p = 10$, and $\sigma = 0.5$. The goal in OIS is to identify the 2 true descriptors $f_1(X) = \{\exp(x_1) - \exp(x_2)\}^2$ and $f_2(X) = \sin(\pi x_3 x_4)$ using only $(y, X)$ as input. Let's generate the primary features $X$ and the response variable $y$.

#### Simulation Parameters ####
n <- 250 # Change n to 100 to reproduce the result in Supplementary Materials A.2.3
p <- 10  # Number of primary features

#### Generate Data ####
X <- matrix(runif(n * p, min = -1, max = 1), nrow = n, ncol = p)
colnames(X) <- paste("x.", seq(from = 1, to = p, by = 1), sep = "")
y <- 15 * (exp(X[, 1]) - exp(X[, 2]))^2 + 20 * sin(pi * X[, 3] * X[, 4]) + rnorm(n, mean = 0, sd = 0.5)
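As a quick sanity check on the simulated data (our own addition, not an iBART step), we can fit an ordinary linear model on the two true descriptors. Because $y$ is linear in $f_1(X)$ and $f_2(X)$, the estimated coefficients should be close to 15 and 20, and the residual standard error close to $\sigma = 0.5$. The block regenerates the data with its own seed so that it runs on its own.

```r
set.seed(123)
n <- 250; p <- 10
X <- matrix(runif(n * p, min = -1, max = 1), nrow = n, ncol = p)
colnames(X) <- paste0("x.", 1:p)
y <- 15 * (exp(X[, 1]) - exp(X[, 2]))^2 + 20 * sin(pi * X[, 3] * X[, 4]) +
  rnorm(n, mean = 0, sd = 0.5)

# True descriptors f1(X) and f2(X)
f1 <- (exp(X[, 1]) - exp(X[, 2]))^2
f2 <- sin(pi * X[, 3] * X[, 4])

# y = b0 + b1 * f1 + b2 * f2 + noise, so OLS should give b1 near 15, b2 near 20
fit_true <- lm(y ~ f1 + f2)
round(coef(fit_true), 2)
sigma(fit_true)  # residual standard error, expected to be near 0.5
```

Once the right descriptors exist, plain linear regression is enough; the hard part that iBART addresses is constructing and selecting them from $(y, X)$ alone.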

Note that the input data to iBART are only $y$ and $X = (x_1, \ldots, x_{10})$, and iBART needs to

  1. Generate the correct descriptors $f_1(X)$ and $f_2(X)$
  2. Select the correct descriptors

At Iteration 0 (base iteration), iBART determines which of the primary features $(x_1, \ldots, x_{10})$ are useful and applies operators only to the useful ones. If successful, iBART should keep $(x_1, \ldots, x_4)$ in the loop and discard $(x_5, \ldots, x_{10})$. Let's run iBART for 1 iteration (Iteration 0 + Iteration 1) and examine its output.

#### iBART ####
iBART_sim <- iBART(X = X, y = y,
                   name = colnames(X),
                   num_burn_in = 5000,                   # lowered for a faster run
                   num_iterations_after_burn_in = 1000,  # lowered for a faster run
                   num_permute_samples = 20,             # lowered for a faster run
                   opt = c("unary"), # only apply unary operators after base iteration
                   sin_cos = TRUE,
                   apply_pos_opt_on_neg_x = FALSE,
                   Lzero = TRUE,
                   K = 4,
                   standardize = FALSE,
                   seed = 123)
#> Start iBART descriptor generation and selection... 
#> Iteration 1 
#> iBART descriptor selection... 
#> avg..........null....................
#> Constructing descriptors using unary operators... 
#> Building X.unary... Initial p = 4; New p = 28 
#> BART iteration done! 
#> LASSO descriptor selection... 
#> L-zero regression... 
#> Total time: 23.2862904071808 secs

iBART() returns a list object that contains many interesting outputs; see ?iBART::iBART for a full list of return values. Here we focus on 2 return values:

  • iBART_sim$descriptor_names: the descriptors selected by iBART
  • iBART_sim$iBART_model: the selected model—a cv.glmnet object

We can use the iBART model the same way we would use a glmnet model. For instance, we can print out the coefficients using coef().

# iBART selected descriptors
iBART_sim$descriptor_names
#>  [1] "(x.1)^2"     "(x.2)^2"     "(x.4)^2"     "exp(x.1)"    "exp(x.2)"   
#>  [6] "exp(x.3)"    "exp(x.4)"    "sin(pi*x.1)" "sin(pi*x.2)" "sin(pi*x.3)"
#> [11] "sin(pi*x.4)" "cos(pi*x.3)" "(1/x.1)"     "(1/x.2)"     "(1/x.3)"    
#> [16] "(1/x.4)"     "abs(x.2)"

# iBART model
coef(iBART_sim$iBART_model, s = "lambda.min")
#> 29 x 1 sparse Matrix of class "dgCMatrix"
#>                      s1
#> (Intercept) -8.98377149
#> x.1          .         
#> x.2          .         
#> x.3          .         
#> x.4          .         
#> (x.1)^2     18.48979389
#> (x.2)^2      8.44944371
#> (x.3)^2      .         
#> (x.4)^2     -2.91764266
#> exp(x.1)     5.51582288
#> exp(x.2)     9.74121269
#> exp(x.3)    -0.49444789
#> exp(x.4)    -5.47704491
#> sin(pi*x.1) -2.22769248
#> sin(pi*x.2) -7.06001109
#> sin(pi*x.3) -1.39995518
#> sin(pi*x.4)  5.10808705
#> cos(pi*x.1)  .         
#> cos(pi*x.2)  .         
#> cos(pi*x.3)  0.23919051
#> cos(pi*x.4)  .         
#> (1/x.1)     -0.05683143
#> (1/x.2)      0.04017341
#> (1/x.3)     -0.01126814
#> (1/x.4)     -0.03267275
#> abs(x.1)     .         
#> abs(x.2)     2.96458665
#> abs(x.3)     .         
#> abs(x.4)     .

iBART_sim$descriptor_names contains the names of the descriptors selected at the last iteration (Iteration 1), and coef(iBART_sim$iBART_model, s = "lambda.min") lists all input descriptors at the last iteration (Iteration 1) together with their coefficients. Notice that the first 4 descriptors after the intercept in coef(iBART_sim$iBART_model, s = "lambda.min") are $x_1, \ldots, x_4$. This indicates that iBART discarded $x_5, \ldots, x_{10}$ and kept $x_1, \ldots, x_4$ in the loop at Iteration 0.
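The two outputs are directly related: in the printout above, the rows with nonzero coefficients, excluding the intercept, are exactly the 17 entries of iBART_sim$descriptor_names, in the same order. The sketch below checks this on coefficients copied by hand from the printout, writing the "." entries (zeros in the sparse matrix) as 0. With the live object, a similar named vector could be obtained with as.matrix(coef(iBART_sim$iBART_model, s = "lambda.min"))[, 1]; that conversion is our assumption, not an iBART helper.

```r
# Coefficients copied from the coef(..., s = "lambda.min") printout above;
# "." entries are written as 0.
b <- c("(Intercept)" = -8.98377149,
       "x.1" = 0, "x.2" = 0, "x.3" = 0, "x.4" = 0,
       "(x.1)^2" = 18.48979389, "(x.2)^2" = 8.44944371,
       "(x.3)^2" = 0, "(x.4)^2" = -2.91764266,
       "exp(x.1)" = 5.51582288, "exp(x.2)" = 9.74121269,
       "exp(x.3)" = -0.49444789, "exp(x.4)" = -5.47704491,
       "sin(pi*x.1)" = -2.22769248, "sin(pi*x.2)" = -7.06001109,
       "sin(pi*x.3)" = -1.39995518, "sin(pi*x.4)" = 5.10808705,
       "cos(pi*x.1)" = 0, "cos(pi*x.2)" = 0,
       "cos(pi*x.3)" = 0.23919051, "cos(pi*x.4)" = 0,
       "(1/x.1)" = -0.05683143, "(1/x.2)" = 0.04017341,
       "(1/x.3)" = -0.01126814, "(1/x.4)" = -0.03267275,
       "abs(x.1)" = 0, "abs(x.2)" = 2.96458665,
       "abs(x.3)" = 0, "abs(x.4)" = 0)

# Names of the nonzero coefficients, dropping the intercept
active <- setdiff(names(b)[b != 0], "(Intercept)")
length(active)
#> [1] 17

# The descriptor_names printout above, for comparison
descriptor_names <- c("(x.1)^2", "(x.2)^2", "(x.4)^2", "exp(x.1)", "exp(x.2)",
                      "exp(x.3)", "exp(x.4)", "sin(pi*x.1)", "sin(pi*x.2)",
                      "sin(pi*x.3)", "sin(pi*x.4)", "cos(pi*x.3)", "(1/x.1)",
                      "(1/x.2)", "(1/x.3)", "(1/x.4)", "abs(x.2)")
identical(active, descriptor_names)
#> [1] TRUE
```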

At Iteration 1, iBART applied unary operators to $x_1, \ldots, x_4$, yielding $x_i, x_i^2, \exp(x_i), \sin(\pi x_i), \cos(\pi x_i), x_i^{-1}, |x_i|$ for $i = 1, 2, 3, 4$, i.e., $7 \times 4 = 28$ descriptors (the "New p = 28" in the log above). Among them, iBART selected 17, including $\exp(x_1)$ and $\exp(x_2)$, the two intermediate descriptors needed to generate $f_1(X) = \{\exp(x_1) - \exp(x_2)\}^2$. This is very promising. Note that $\sqrt{x_i}$ and $\log(x_i)$ are absent here because $\sqrt{\cdot}$ and $\log(\cdot)$ are not defined for negative inputs, and the $x_i$ take values in $(-1, 1)$. We can override this by setting apply_pos_opt_on_neg_x = TRUE; this effectively generates $\sqrt{|x_i|}$ and $\log(|x_i|)$.
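The unary step can be mimicked in a few lines of base R. The helper apply_unary() below is hypothetical: it only reproduces the seven operators and the column labels seen in the printout and is not iBART's internal code. Applied to four primary features it yields $7 \times 4 = 28$ columns, matching "New p = 28" in the log. Square root and log are left out, as with apply_pos_opt_on_neg_x = FALSE, because the features take negative values.

```r
# Hypothetical helper (not iBART's implementation): apply the seven unary
# operators x, x^2, exp(x), sin(pi*x), cos(pi*x), 1/x, |x| to every column.
apply_unary <- function(X) {
  nm <- colnames(X)
  out <- cbind(X, X^2, exp(X), sin(pi * X), cos(pi * X), 1 / X, abs(X))
  colnames(out) <- c(nm,
                     sprintf("(%s)^2", nm),
                     sprintf("exp(%s)", nm),
                     sprintf("sin(pi*%s)", nm),
                     sprintf("cos(pi*%s)", nm),
                     sprintf("(1/%s)", nm),
                     sprintf("abs(%s)", nm))
  out
}

set.seed(123)
X4 <- matrix(runif(250 * 4, min = -1, max = 1), ncol = 4,
             dimnames = list(NULL, paste0("x.", 1:4)))  # the 4 kept features
X_unary <- apply_unary(X4)
dim(X_unary)
#> [1] 250  28

# sqrt() would return NaN for the negative entries, hence it is excluded
any(is.nan(suppressWarnings(sqrt(X4))))
#> [1] TRUE
```

The column order (identity, square, exp, sin, cos, inverse, absolute value) mirrors the row order of the coefficient printout above.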

To save time, we cached the result of a complete run in data("iBART_sim", package = "iBART"), which can be replicated by using the following code.

iBART_sim <- iBART(X = X, y = y,
                   name = colnames(X),
                   opt = c("unary", "binary", "unary"), 
                   sin_cos = TRUE,
                   apply_pos_opt_on_neg_x = FALSE,
                   Lzero = TRUE,
                   K = 4,
                   standardize = FALSE,
                   seed = 123)

Let’s load the full result and see how iBART did.

data("iBART_sim", package = "iBART")          # load full result

iBART_sim$descriptor_names                    # iBART selected descriptors
#> [1] "(exp(x.1)-exp(x.2))^2" "sin(pi*(x.3*x.4))"
coef(iBART_sim$iBART_model, s = "lambda.min") # iBART model
#> 146 x 1 sparse Matrix of class "dgCMatrix"
#>                                     s1
#> (Intercept)                  0.1928037
#> x.1                          .        
#> x.2                          .        
#> x.3                          .        
#> x.4                          .        
#> exp(x.1)                     .        
#> exp(x.2)                     .        
#> exp(x.3)                     .        
#> exp(x.4)                     .        
#> (x.2+exp(x.2))               .        
#> (x.1-exp(x.2))               .        
#> (x.2-exp(x.1))               .        
#> (exp(x.1)-exp(x.2))          .        
#> (x.1*x.2)                    .        
#> (x.2*exp(x.1))               .        
#> (x.3*x.4)                    .        
#> (exp(x.1)*exp(x.2))          .        
#> (x.2/exp(x.1))               .        
#> (exp(x.2)/exp(x.1))          .        
#> |x.1-x.2|                    .        
#> |x.3-x.4|                    .        
#> |exp(x.1)-exp(x.2)|          .        
#> exp(x.1)^0.5                 .        
#> exp(x.2)^0.5                 .        
#> exp(x.3)^0.5                 .        
#> exp(x.4)^0.5                 .        
#> (exp(x.1)*exp(x.2))^0.5      .        
#> (exp(x.2)/exp(x.1))^0.5      .        
#> |x.1-x.2|^0.5                .        
#> |x.3-x.4|^0.5                .        
#> |exp(x.1)-exp(x.2)|^0.5      .        
#> x.1^2                        .        
#> x.2^2                        .        
#> x.3^2                        .        
#> x.4^2                        .        
#> exp(x.1)^2                   .        
#> exp(x.2)^2                   .        
#> exp(x.3)^2                   .        
#> exp(x.4)^2                   .        
#> (x.2+exp(x.2))^2             .        
#> (x.1-exp(x.2))^2             .        
#> (x.2-exp(x.1))^2             .        
#> (exp(x.1)-exp(x.2))^2       14.7643022
#> (x.1*x.2)^2                  .        
#> (x.2*exp(x.1))^2             .        
#> (x.3*x.4)^2                  .        
#> (exp(x.1)*exp(x.2))^2        .        
#> (x.2/exp(x.1))^2             .        
#> (exp(x.2)/exp(x.1))^2        .        
#> |x.1-x.2|^2                  .        
#> |x.3-x.4|^2                  .        
#> log((exp(x.1)*exp(x.2)))     .        
#> log((exp(x.2)/exp(x.1)))     .        
#> log(|x.1-x.2|)               .        
#> log(|x.3-x.4|)               .        
#> log(|exp(x.1)-exp(x.2)|)     .        
#> exp(exp(x.1))                .        
#> exp(exp(x.2))                .        
#> exp(exp(x.3))                .        
#> exp(exp(x.4))                .        
#> exp((x.2+exp(x.2)))          .        
#> exp((x.1-exp(x.2)))          .        
#> exp((x.2-exp(x.1)))          .        
#> exp((exp(x.1)-exp(x.2)))     .        
#> exp((x.1*x.2))               .        
#> exp((x.2*exp(x.1)))          .        
#> exp((x.3*x.4))               .        
#> exp((exp(x.1)*exp(x.2)))     .        
#> exp((x.2/exp(x.1)))          .        
#> exp((exp(x.2)/exp(x.1)))     .        
#> exp(|x.1-x.2|)               .        
#> exp(|x.3-x.4|)               .        
#> exp(|exp(x.1)-exp(x.2)|)     .        
#> sin(pi*x.1)                  .        
#> sin(pi*x.2)                  .        
#> sin(pi*x.3)                  .        
#> sin(pi*x.4)                  .        
#> sin(pi*exp(x.1))             .        
#> sin(pi*exp(x.2))             .        
#> sin(pi*exp(x.3))             .        
#> sin(pi*exp(x.4))             .        
#> sin(pi*(x.2+exp(x.2)))       .        
#> sin(pi*(x.1-exp(x.2)))       .        
#> sin(pi*(x.2-exp(x.1)))       .        
#> sin(pi*(exp(x.1)-exp(x.2)))  .        
#> sin(pi*(x.1*x.2))            .        
#> sin(pi*(x.2*exp(x.1)))       .        
#> sin(pi*(x.3*x.4))           19.5876303
#> sin(pi*(exp(x.1)*exp(x.2)))  .        
#> sin(pi*(x.2/exp(x.1)))       .        
#> sin(pi*(exp(x.2)/exp(x.1)))  .        
#> sin(pi*|x.1-x.2|)            .        
#> sin(pi*|x.3-x.4|)            .        
#> sin(pi*|exp(x.1)-exp(x.2)|)  .        
#> cos(pi*x.1)                  .        
#> cos(pi*x.2)                  .        
#> cos(pi*x.3)                  .        
#> cos(pi*x.4)                  .        
#> cos(pi*exp(x.1))             .        
#> cos(pi*exp(x.2))             .        
#> cos(pi*exp(x.3))             .        
#> cos(pi*exp(x.4))             .        
#> cos(pi*(x.2+exp(x.2)))       .        
#> cos(pi*(x.1-exp(x.2)))       .        
#> cos(pi*(x.2-exp(x.1)))       .        
#> cos(pi*(exp(x.1)-exp(x.2)))  .        
#> cos(pi*(x.1*x.2))            .        
#> cos(pi*(x.2*exp(x.1)))       .        
#> cos(pi*(x.3*x.4))            .        
#> cos(pi*(exp(x.1)*exp(x.2)))  .        
#> cos(pi*(x.2/exp(x.1)))       .        
#> cos(pi*(exp(x.2)/exp(x.1)))  .        
#> cos(pi*|x.1-x.2|)            .        
#> cos(pi*|x.3-x.4|)            .        
#> x.1^(-1)                     .        
#> x.2^(-1)                     .        
#> x.3^(-1)                     .        
#> x.4^(-1)                     .        
#> exp(x.1)^(-1)                .        
#> exp(x.2)^(-1)                .        
#> exp(x.3)^(-1)                .        
#> exp(x.4)^(-1)                .        
#> (x.2+exp(x.2))^(-1)          .        
#> (x.1-exp(x.2))^(-1)          .        
#> (x.2-exp(x.1))^(-1)          .        
#> (exp(x.1)-exp(x.2))^(-1)     .        
#> (x.1*x.2)^(-1)               .        
#> (x.2*exp(x.1))^(-1)          .        
#> (x.3*x.4)^(-1)               .        
#> (exp(x.1)*exp(x.2))^(-1)     .        
#> (x.2/exp(x.1))^(-1)          .        
#> (exp(x.2)/exp(x.1))^(-1)     .        
#> |x.1-x.2|^(-1)               .        
#> |x.3-x.4|^(-1)               .        
#> |exp(x.1)-exp(x.2)|^(-1)     .        
#> abs(x.1)                     .        
#> abs(x.2)                     .        
#> abs(x.3)                     .        
#> abs(x.4)                     .        
#> abs((x.2+exp(x.2)))          .        
#> abs((x.1-exp(x.2)))          .        
#> abs((x.2-exp(x.1)))          .        
#> abs((x.1*x.2))               .        
#> abs((x.2*exp(x.1)))          .        
#> abs((x.3*x.4))               .        
#> abs((x.2/exp(x.1)))          .

Here iBART generated 145 descriptors in the last iteration, and it correctly identified the true descriptors $f_1(X)$ and $f_2(X)$ without selecting any false positives. This is reassuring, especially because some of these descriptors are highly correlated with $f_1(X)$ or $f_2(X)$. For instance, $\tilde{f}_1(X) = |\exp(x_1) - \exp(x_2)|$ in the descriptor space is highly correlated with $f_1(X)$.

f1_true <- (exp(X[,1]) - exp(X[,2]))^2
f1_cor <- abs(exp(X[,1]) - exp(X[,2]))
cor(f1_true, f1_cor)
#> [1] 0.9517217
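High correlation does not make the two descriptors interchangeable. As an illustration of our own (not an iBART output), the sketch below regenerates the data and compares ordinary least squares fits that pair $f_2(X)$ with either the true $f_1(X)$ or the correlated descriptor $|\exp(x_1) - \exp(x_2)|$. The fit with the true descriptor should have a residual standard error near $\sigma = 0.5$, while the fit with the correlated one leaves a much larger residual error.

```r
set.seed(123)
n <- 250; p <- 10
X <- matrix(runif(n * p, min = -1, max = 1), nrow = n, ncol = p)
y <- 15 * (exp(X[, 1]) - exp(X[, 2]))^2 + 20 * sin(pi * X[, 3] * X[, 4]) +
  rnorm(n, mean = 0, sd = 0.5)

f1_true <- (exp(X[, 1]) - exp(X[, 2]))^2   # f1(X)
f1_cor  <- abs(exp(X[, 1]) - exp(X[, 2]))  # highly correlated with f1(X)
f2      <- sin(pi * X[, 3] * X[, 4])       # f2(X)

fit_true  <- lm(y ~ f1_true + f2)
fit_decoy <- lm(y ~ f1_cor + f2)
c(true = sigma(fit_true), decoy = sigma(fit_decoy))  # residual standard errors
```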

iBART() also returns other useful outputs, such as iBART_sim$iBART_gen_size and iBART_sim$iBART_sel_size. They store the dimensions of the newly generated and selected descriptor spaces, respectively, at each iteration. Let's plot them to see how iBART uses nonparametric variable selection for dimension reduction. In each iteration, the dimension of the intermediate descriptor space is kept below $\mathcal{O}(p^2)$, leading to progressive dimension reduction.

library(ggplot2)
df_dim <- data.frame(dim = c(iBART_sim$iBART_sel_size, iBART_sim$iBART_gen_size),
                     iter = rep(0:3, 2),
                     type = rep(c("Selected", "Generated"), each = 4))
ggplot(df_dim, aes(x = iter, y = dim, colour = type, group = type)) +
  theme(text = element_text(size = 15), legend.title = element_blank()) +
  geom_line(linewidth = 1) +  # ggplot2 >= 3.4.0 uses linewidth for lines
  geom_point(size = 3, shape = 21, fill = "white") +
  geom_text(data = df_dim, aes(label = dim, y = dim + 10, group = type),
            position = position_dodge(0), size = 5, colour = "blue") +
  labs(x = "Iteration", y = "Number of descriptors") +
  scale_x_continuous(breaks = c(0, 1, 2, 3))

R Session Info

sessionInfo()
#> R version 4.4.2 (2024-10-31)
#> Platform: x86_64-pc-linux-gnu
#> Running under: Ubuntu 24.04.1 LTS
#> 
#> Matrix products: default
#> BLAS:   /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3 
#> LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.26.so;  LAPACK version 3.12.0
#> 
#> locale:
#>  [1] LC_CTYPE=en_US.UTF-8          LC_NUMERIC=C                 
#>  [3] LC_TIME=en_US.UTF-8           LC_COLLATE=C                 
#>  [5] LC_MONETARY=en_US.UTF-8       LC_MESSAGES=en_US.UTF-8      
#>  [7] LC_PAPER=en_US.UTF-8          LC_NAME=en_US.UTF-8          
#>  [9] LC_ADDRESS=en_US.UTF-8        LC_TELEPHONE=en_US.UTF-8     
#> [11] LC_MEASUREMENT=en_US.UTF-8    LC_IDENTIFICATION=en_US.UTF-8
#> 
#> time zone: Etc/UTC
#> tzcode source: system (glibc)
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] ggpubr_0.6.0   ggplot2_3.5.1  iBART_1.0.2    rmarkdown_2.29
#> 
#> loaded via a namespace (and not attached):
#>  [1] tidyr_1.3.1           sass_0.4.9            utf8_1.2.4           
#>  [4] generics_0.1.3        bartMachineJARs_1.2.1 rstatix_0.7.2        
#>  [7] shape_1.4.6.1         lattice_0.22-6        digest_0.6.37        
#> [10] magrittr_2.0.3        evaluate_1.0.1        grid_4.4.2           
#> [13] iterators_1.0.14      fastmap_1.2.0         foreach_1.5.2        
#> [16] jsonlite_1.8.9        glmnet_4.1-8          Matrix_1.7-1         
#> [19] missForest_1.5        backports_1.5.0       Formula_1.2-5        
#> [22] survival_3.7-0        gridExtra_2.3         bartMachine_1.3.4.1  
#> [25] purrr_1.0.2           fansi_1.0.6           doRNG_1.8.6          
#> [28] itertools_0.1-3       scales_1.3.0          codetools_0.2-20     
#> [31] jquerylib_0.1.4       abind_1.4-8           cli_3.6.3            
#> [34] rlang_1.1.4           cowplot_1.1.3         munsell_0.5.1        
#> [37] splines_4.4.2         withr_3.0.2           cachem_1.1.0         
#> [40] yaml_2.3.10           tools_4.4.2           parallel_4.4.2       
#> [43] ggsignif_0.6.4        dplyr_1.1.4           colorspace_2.1-1     
#> [46] rngtools_1.5.2        broom_1.0.7           buildtools_1.0.0     
#> [49] vctrs_0.6.5           R6_2.5.1              lifecycle_1.0.4      
#> [52] randomForest_4.7-1.2  car_3.1-3             pkgconfig_2.0.3      
#> [55] rJava_1.0-11          bslib_0.8.0           pillar_1.9.0         
#> [58] gtable_0.3.6          glue_1.8.0            Rcpp_1.0.13-1        
#> [61] tidyselect_1.2.1      xfun_0.49             tibble_3.2.1         
#> [64] sys_3.4.3             knitr_1.49            farver_2.1.2         
#> [67] htmltools_0.5.8.1     carData_3.0-5         labeling_0.4.3       
#> [70] maketools_1.3.1       compiler_4.4.2