Turn regression rule models into tidy tibbles
# S3 method for cubist
tidy(x, ...)

# S3 method for xrf
tidy(x, penalty = NULL, unit = c("rules", "columns"), ...)
x
A Cubist or xrf object.
...
Not currently used.
penalty
A single numeric value for the lambda penalty value.
unit
What data should be returned? For unit = 'rules', each row corresponds to a rule. For unit = 'columns', each row is a predictor column. The latter can be helpful when determining variable importance.
The Cubist method returns a tibble with columns committee, rule_num, rule, estimate, and statistic. The latter two are nested tibbles: estimate contains the parameter estimates for each term in the regression model, and statistic has statistics about the data selected by the rules and the model fit.
The xrf results have columns rule_id, rule, and estimate. The rule_id column has the rule identifier (e.g., "r0_21") or the feature column name when the column is added directly into the model. For multiclass models, a class column is included.
In each case, the rule column has a character string with the rule conditions. These can be converted to an R expression using rlang::parse_expr().
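As a minimal sketch of what that conversion allows, a parsed rule can be evaluated against a data frame to see which rows it covers. The rule string and the toy data below are made up for illustration; only rlang::parse_expr() and rlang::eval_tidy() are real functions, and the names rule_chr, toy, and covered are hypothetical.

```r
library(rlang)

# Hypothetical rule string, written in the style of the `rule` column
rule_chr <- "( Central_Air == 'N' ) & ( Gr_Liv_Area <= 3.0326188 )"

# Toy data with made-up values
toy <- data.frame(
  Central_Air = c("N", "Y", "N"),
  Gr_Liv_Area = c(2.9, 2.9, 3.2)
)

# Convert the string to an R expression, then evaluate it using the
# columns of `toy` as variables
rule_expr <- parse_expr(rule_chr)
covered <- eval_tidy(rule_expr, data = toy)

# Keep only the rows that satisfy the rule
toy[covered, ]
```

Only the first toy row satisfies both conditions, so it is the only row kept.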
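library(dplyr)
library(rules)  # provides cubist_rules() and rule_fit(); attaches parsnip

data(ames, package = "modeldata")

# Put the outcome and living area on the log10 scale
ames <- ames %>%
  mutate(
    Sale_Price = log10(Sale_Price),
    Gr_Liv_Area = log10(Gr_Liv_Area)
  )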
# ------------------------------------------------------------------------------
cb_fit <-
  cubist_rules(committees = 10) %>%
  set_engine("Cubist") %>%
  fit(Sale_Price ~ Neighborhood + Longitude + Latitude +
        Gr_Liv_Area + Central_Air,
      data = ames)

cb_res <- tidy(cb_fit)
cb_res
## # A tibble: 157 × 5
##    committee rule_num rule                                    estimate statistic
##        <int>    <int> <chr>                                   <list>   <list>
##  1         1        1 ( Central_Air == 'N' ) & ( Gr_Liv_Area… <tibble> <tibble>
##  2         1        2 ( Gr_Liv_Area <= 3.0326188 ) & ( Neigh… <tibble> <tibble>
##  3         1        3 ( Neighborhood %in% c( 'Old_Town','Ed…  <tibble> <tibble>
##  4         1        4 ( Neighborhood %in% c( 'Old_Town','Ed…  <tibble> <tibble>
##  5         1        5 ( Central_Air == 'N' ) & ( Gr_Liv_Area… <tibble> <tibble>
##  6         1        6 ( Longitude <= -93.652023 ) & ( Neighb… <tibble> <tibble>
##  7         1        7 ( Gr_Liv_Area > 3.2284005 ) & ( Neighb… <tibble> <tibble>
##  8         1        8 ( Neighborhood %in% c( 'North_Ames','…  <tibble> <tibble>
##  9         1        9 ( Latitude <= 42.009399 ) & ( Neighbor… <tibble> <tibble>
## 10         1       10 ( Neighborhood %in% c( 'College_Creek…  <tibble> <tibble>
## # … with 147 more rows
cb_res$estimate[[1]]
## # A tibble: 4 × 2
##   term        estimate
##   <chr>          <dbl>
## 1 (Intercept)  -408.
## 2 Longitude      -1.43
## 3 Latitude        6.6
## 4 Gr_Liv_Area     0.7
cb_res$statistic[[1]]
## # A tibble: 1 × 6
##   num_conditions coverage  mean   min   max  error
##            <dbl>    <dbl> <dbl> <dbl> <dbl>  <dbl>
## 1              2      154  4.94  4.11  5.31 0.0956
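Instead of indexing one nested element at a time, the estimate list column can be expanded so that each row is a rule/term pair. The sketch below uses tidyr::unnest() on a small hand-built tibble that only mimics the shape of the Cubist results; toy_res, toy_long, and all values in them are made up for illustration.

```r
library(tibble)
library(tidyr)

# Toy stand-in for the tidy() result of a Cubist fit; all values are made up
toy_res <- tibble(
  committee = c(1L, 1L),
  rule_num  = c(1L, 2L),
  rule      = c("( x1 > 0 )", "( x1 <= 0 )"),
  estimate  = list(
    tibble(term = c("(Intercept)", "x1"), estimate = c(1.5, 0.2)),
    tibble(term = c("(Intercept)", "x2"), estimate = c(-0.3, 4.0))
  )
)

# Expand the nested tibbles: one row per (rule, term) pair
toy_long <- unnest(toy_res, cols = estimate)
toy_long
```

The same call should apply to a real result such as cb_res, after which the coefficients can be filtered or summarized with ordinary dplyr verbs.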
# ------------------------------------------------------------------------------

library(recipes)

xrf_reg_mod <-
  rule_fit(trees = 10, penalty = .001) %>%
  set_engine("xrf") %>%
  set_mode("regression")

# Make dummy variables since xgboost will not
ames_rec <-
  recipe(Sale_Price ~ Neighborhood + Longitude + Latitude +
           Gr_Liv_Area + Central_Air,
         data = ames) %>%
  step_dummy(Neighborhood, Central_Air) %>%
  step_zv(all_predictors())

ames_processed <- prep(ames_rec) %>% bake(new_data = NULL)

set.seed(1)
xrf_reg_fit <-
  xrf_reg_mod %>%
  fit(Sale_Price ~ ., data = ames_processed)

xrf_rule_res <- tidy(xrf_reg_fit)
xrf_rule_res$rule[nrow(xrf_rule_res)] %>% rlang::parse_expr()
## (Gr_Liv_Area < 3.30210185) & (Gr_Liv_Area < 3.38872266) & (Gr_Liv_Area >=
## 2.94571471) & (Gr_Liv_Area >= 3.24870872) & (Latitude < 42.0271072) &
## (Neighborhood_Old_Town >= -9.53674316e-07)
xrf_col_res <- tidy(xrf_reg_fit, unit = "columns")
xrf_col_res
## # A tibble: 149 × 3
##    rule_id term           estimate
##    <chr>   <chr>             <dbl>
##  1 r0_1    Gr_Liv_Area   -1.27e- 2
##  2 r2_4    Gr_Liv_Area   -3.70e-10
##  3 r2_2    Gr_Liv_Area    7.59e- 3
##  4 r2_4    Central_Air_Y -3.70e-10
##  5 r3_5    Longitude      1.06e- 1
##  6 r3_6    Longitude      2.65e- 2
##  7 r3_5    Latitude       1.06e- 1
##  8 r3_6    Latitude       2.65e- 2
##  9 r3_5    Longitude      1.06e- 1
## 10 r3_6    Longitude      2.65e- 2
## # … with 139 more rows
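Column-level results like these can be rolled up into a rough per-predictor score, which is one way to use unit = "columns" when thinking about variable importance. The sketch below sums absolute coefficients by term on a hand-built tibble with the same three columns. The scoring rule is an assumption made for illustration, not a definition taken from the package, and toy_cols, importance, and their values are made up.

```r
library(dplyr)

# Toy stand-in for tidy(<xrf fit>, unit = "columns"); values are made up
toy_cols <- tibble(
  rule_id  = c("r0_1", "r0_2", "r0_2", "r1_3"),
  term     = c("x1", "x1", "x2", "x2"),
  estimate = c(-0.5, 0.25, 0.25, 0.1)
)

# Assumed scoring rule: total absolute coefficient over the rules that
# use each predictor, plus a count of those rules
importance <- toy_cols %>%
  group_by(term) %>%
  summarize(
    score   = sum(abs(estimate)),
    n_rules = n_distinct(rule_id)
  ) %>%
  arrange(desc(score))

importance
```

On real output it may be worth calling distinct() first, since the printed rows above suggest that the same rule/term pair can appear more than once.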