Let’s see an example of using the DALEX
package to explain a classification model for the survival problem on the Titanic data. Here we use the titanic dataset
available in the DALEX
package. Note that this data was copied from the stablelearner
package.
library("DALEX")
head(titanic)
#>   gender age class    embarked       country  fare sibsp parch survived
#> 1   male  42   3rd Southampton United States  7.11     0     0       no
#> 2   male  13   3rd Southampton United States 20.05     0     2       no
#> 3   male  16   3rd Southampton United States 20.05     1     1       no
#> 4 female  39   3rd Southampton       England 20.05     1     1      yes
#> 5 female  16   3rd Southampton        Norway  7.13     0     0      yes
#> 6   male  25   3rd Southampton United States  7.13     0     0      yes
Ok, now it’s time to create a model. Let’s use the Random Forest model.
# prepare model
library("randomForest")
titanic <- na.omit(titanic)
model_titanic_rf <- randomForest(survived == "yes" ~ gender + age + class + embarked +
fare + sibsp + parch, data = titanic)
model_titanic_rf
#>
#> Call:
#> randomForest(formula = survived == "yes" ~ gender + age + class + embarked + fare + sibsp + parch, data = titanic)
#> Type of random forest: regression
#> Number of trees: 500
#> No. of variables tried at each split: 2
#>
#> Mean of squared residuals: 0.143236
#> % Var explained: 34.65
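Note that the formula uses survived == "yes", so the target becomes a 0/1 variable and randomForest fits a regression forest whose predictions can be read as survival probabilities. If a classification forest were preferred, a factor target could be used instead; a minimal sketch (not used in the rest of this example):
# alternative sketch (not used below): a classification forest on a factor target
model_titanic_rf_class <- randomForest(factor(survived) ~ gender + age + class + embarked +
                                         fare + sibsp + parch, data = titanic)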
The third step (it’s optional but useful) is to create a DALEX
explainer for the random forest model.
library("DALEX")
explain_titanic_rf <- explain(model_titanic_rf,
                              data = titanic[, -9],
                              y = titanic$survived == "yes",
                              label = "Random Forest v7")
#> Preparation of a new explainer is initiated
#> -> model label : Random Forest v7
#> -> data : 2099 rows 8 cols
#> -> target variable : 2099 values
#> -> model_info : package randomForest , ver. 4.6.14 , task regression ( default )
#> -> predict function : yhat.randomForest will be used ( default )
#> -> predicted values : numerical, min = 0.01286123 , mean = 0.3248356 , max = 0.9912115
#> -> residual function : difference between y and yhat ( default )
#> -> residuals : numerical, min = -0.779851 , mean = -0.0003954087 , max = 0.9085878
#> A new explainer has been created!
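The explainer bundles the model, the data and a predict function (here the default yhat.randomForest). As a quick sanity check, new data can be scored directly through the explainer; a minimal sketch, not part of the original workflow:
# sketch: predictions for the first few passengers via the explainer
predict(explain_titanic_rf, titanic[1:3, ])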
Use the feature_importance()
function to present the importance of particular features. Note that type = "difference"
normalizes the drop-out losses so that they all start at 0.
library("ingredients")
fi_rf <- feature_importance(explain_titanic_rf)
head(fi_rf)
#>       variable mean_dropout_loss            label
#> 1 _full_model_         0.3332983 Random Forest v7
#> 2      country         0.3332983 Random Forest v7
#> 3        parch         0.3440449 Random Forest v7
#> 4        sibsp         0.3451616 Random Forest v7
#> 5     embarked         0.3503033 Random Forest v7
#> 6         fare         0.3733943 Random Forest v7
plot(fi_rf)
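The call above uses the default type = "raw". A minimal sketch of the type = "difference" variant mentioned earlier (assuming the same explainer), in which the full-model loss is subtracted so all bars start at 0:
# sketch: drop-out losses relative to the full model
fi_rf_diff <- feature_importance(explain_titanic_rf, type = "difference")
plot(fi_rf_diff)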
As we can see, the most important feature is gender.
The next three important features are class,
age
and fare.
Let’s see the link between the model response and these features.
Such univariate relations can be calculated with partial_dependence().
Kids 5 years old and younger have much higher survival probability.
pp_age <- partial_dependence(explain_titanic_rf, variables = c("age", "fare"))
head(pp_age)
#> Top profiles :
#>   _vname_          _label_       _x_    _yhat_ _ids_
#> 1    fare Random Forest v7 0.0000000 0.3241036     0
#> 2     age Random Forest v7 0.1666667 0.5364253     0
#> 3     age Random Forest v7 2.0000000 0.5607931     0
#> 4     age Random Forest v7 4.0000000 0.5750886     0
#> 5    fare Random Forest v7 6.1904000 0.3111265     0
#> 6     age Random Forest v7 7.0000000 0.5414633     0
plot(pp_age)
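Partial dependence profiles can also be computed for categorical variables; a minimal sketch, assuming the variable_type argument of partial_dependence() in the installed ingredients version:
# sketch: partial dependence for a categorical variable
pp_class <- partial_dependence(explain_titanic_rf, variables = "class",
                               variable_type = "categorical")
plot(pp_class)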
# conditional dependence profiles for the same variables
cp_age <- conditional_dependence(explain_titanic_rf, variables = c("age", "fare"))
plot(cp_age)
# accumulated dependence profiles for the same variables
ap_age <- accumulated_dependence(explain_titanic_rf, variables = c("age", "fare"))
plot(ap_age)
Let’s now explain the model prediction for a single passenger: an 8-year-old male from the 1st class who embarked in Southampton.
First, Ceteris Paribus profiles for the numerical variables.
new_passanger <- data.frame(
  class = factor("1st", levels = c("1st", "2nd", "3rd", "deck crew", "engineering crew", "restaurant staff", "victualling crew")),
  gender = factor("male", levels = c("female", "male")),
  age = 8,
  sibsp = 0,
  parch = 0,
  fare = 72,
  embarked = factor("Southampton", levels = c("Belfast", "Cherbourg", "Queenstown", "Southampton"))
)
sp_rf <- ceteris_paribus(explain_titanic_rf, new_passanger)
plot(sp_rf) +
show_observations(sp_rf)
And now for selected categorical variables. Note that sibsp is numerical, but here it is presented as a categorical variable.
plot(sp_rf,
     variables = c("class", "embarked", "gender", "sibsp"),
     variable_type = "categorical")
It looks like the most important features for this passenger are age
and gender.
Overall, his odds of survival are higher than for the average passenger, mainly because of his young age and despite being male.
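A quick numeric check of that claim, as a sketch on top of the code above: compare the model prediction for this passenger with the average prediction over the whole dataset.
# sketch: this passenger’s predicted survival vs. the average prediction
predict(explain_titanic_rf, new_passanger)
mean(predict(explain_titanic_rf, titanic))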
# sample 100 passengers, compute their ceteris-paribus profiles,
# and cluster the profiles into 3 groups
passangers <- select_sample(titanic, n = 100)
sp_rf <- ceteris_paribus(explain_titanic_rf, passangers)
clust_rf <- cluster_profiles(sp_rf, k = 3)
head(clust_rf)
#> Top profiles :
#>   _vname_            _label_       _x_ _cluster_    _yhat_ _ids_
#> 1    fare Random Forest v7_1 0.0000000         1 0.1935959     0
#> 2   sibsp Random Forest v7_1 0.0000000         1 0.1695383     0
#> 3   parch Random Forest v7_1 0.0000000         1 0.1672070     0
#> 4     age Random Forest v7_1 0.1666667         1 0.4664651     0
#> 5   parch Random Forest v7_1 0.2800000         1 0.1671393     0
#> 6   sibsp Random Forest v7_1 1.0000000         1 0.1608335     0
plot(sp_rf, alpha = 0.1) +
show_aggregated_profiles(clust_rf, color = "_label_", size = 2)
sessionInfo()
#> R version 3.6.1 (2019-07-05)
#> Platform: x86_64-apple-darwin15.6.0 (64-bit)
#> Running under: macOS Catalina 10.15.3
#>
#> Matrix products: default
#> BLAS: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRblas.0.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib
#>
#> locale:
#> [1] C/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] ggplot2_3.2.1 randomForest_4.6-14 ingredients_1.1
#> [4] DALEX_1.0
#>
#> loaded via a namespace (and not attached):
#> [1] Rcpp_1.0.3 pillar_1.4.3 compiler_3.6.1 tools_3.6.1
#> [5] digest_0.6.23 evaluate_0.14 lifecycle_0.1.0 tibble_2.1.3
#> [9] gtable_0.3.0 pkgconfig_2.0.3 rlang_0.4.2 yaml_2.2.0
#> [13] xfun_0.11 withr_2.1.2 stringr_1.4.0 dplyr_0.8.3
#> [17] knitr_1.28 grid_3.6.1 tidyselect_0.2.5 glue_1.3.1
#> [21] R6_2.4.1 rmarkdown_1.16 purrr_0.3.3 farver_2.0.3
#> [25] magrittr_1.5 scales_1.1.0 htmltools_0.4.0 assertthat_0.2.1
#> [29] colorspace_1.4-1 labeling_0.3 stringi_1.4.5 lazyeval_0.2.2
#> [33] munsell_0.5.0 crayon_1.3.4