library(dplyr)
library(prettyglm)
# Load the `titanic` example data set into the workspace.
# NOTE(review): presumably shipped with prettyglm -- confirm the source package.
data("titanic")

# Columns that are re-encoded as factors before modelling.
columns_to_factor <- c(
  "Pclass",
  "Sex",
  "Cabin",
  "Embarked",
  "Cabintype",
  "Survived"
)

# Mean of the observed (non-missing) ages; used below to impute missing ages.
meanage <- base::mean(titanic$Age, na.rm = TRUE)

titanic <- titanic %>%
  # across() + all_of() replaces the superseded mutate_at(); each listed
  # column is converted to a factor in place.
  dplyr::mutate(dplyr::across(dplyr::all_of(columns_to_factor), factor)) %>%
  # is.na() already returns a logical vector, so no `== TRUE` is needed.
  dplyr::mutate(Age = base::ifelse(is.na(Age), meanage, Age)) %>%
  # Derived columns built by prettyglm::splineit() over the given bounds.
  # NOTE(review): presumably piecewise spline segments -- confirm in the
  # prettyglm documentation.
  dplyr::mutate(
    Age_0_25 = prettyglm::splineit(Age, 0, 25),
    Age_25_50 = prettyglm::splineit(Age, 25, 50),
    Age_50_120 = prettyglm::splineit(Age, 50, 120)
  ) %>%
  dplyr::mutate(
    Fare_0_250 = prettyglm::splineit(Fare, 0, 250),
    Fare_250_600 = prettyglm::splineit(Fare, 250, 600)
  )
# Fit a logistic regression (binomial family, logit link) of `Survived` on the
# prepared `titanic` data frame. `Sex:Age` enters only as an interaction term;
# the remaining predictors enter as main effects.
survival_model <- stats::glm(
  formula = Survived ~ Sex:Age + Fare + Embarked + SibSp + Parch + Cabintype,
  data = titanic,
  family = binomial(link = "logit")
)
# Continuous Variable Example
# Plots the numeric feature `Age` for `survival_model`, requesting 20 buckets
# and passing 0.1 for both the lower and upper percentile cut-offs.
prettyglm::one_way_ave(
  data_set = titanic,
  model_object = survival_model,
  target_variable = "Survived",
  feature_to_plot = "Age",
  number_of_buckets = 20,
  lower_percentile_to_cut = 0.1,
  upper_percentile_to_cut = 0.1
)
# Discrete Variable Example
# Plots the factor feature `Pclass`; bucket and percentile arguments are left
# at the function's defaults.
prettyglm::one_way_ave(
  data_set = titanic,
  model_object = survival_model,
  target_variable = "Survived",
  feature_to_plot = "Pclass"
)
# Custom Predict Function and facet
#' Custom prediction function that caps predicted values at 0.3
#'
#' Extracts the actual target values from `dataset`, scores `dataset` with
#' `model_object` on the response scale, and returns both side by side with
#' every prediction above 0.3 replaced by 0.3.
#'
#' @param target Name (single string) of the target column in `dataset`.
#' @param model_object Fitted model accepted by `stats::predict()` with
#'   `type = "response"`.
#' @param dataset Data to score; coerced with `as.data.frame()`.
#' @return A data.frame with numeric columns `Actual_Values` and
#'   `Predicted_Values`, one row per row of `dataset`.
a_custom_predict_function <- function(target, model_object, dataset) {
  dataset <- base::as.data.frame(dataset)

  # Fail loudly on a missing column: `[[` would otherwise return NULL and
  # silently drop `Actual_Values` from the result.
  if (!is.character(target) || length(target) != 1L ||
      !(target %in% base::names(dataset))) {
    stop("`target` must be the name of a single column in `dataset`.",
         call. = FALSE)
  }

  Actual_Values <- dataset[[target]]

  # is.factor() instead of `class(x) == 'factor'`: an ordered factor has class
  # c("ordered", "factor"), which made the old comparison length 2 and is an
  # error inside if() on R >= 4.2.
  if (is.factor(Actual_Values)) {
    # Go through character so the level labels (e.g. "0"/"1") are converted,
    # not the internal integer codes.
    Actual_Values <- base::as.numeric(base::as.character(Actual_Values))
  }

  Predicted_Values <- base::as.numeric(
    stats::predict(model_object, dataset, type = "response")
  )

  base::data.frame(
    Actual_Values = Actual_Values,
    # Cap predictions at 0.3.
    Predicted_Values = base::pmin(Predicted_Values, 0.3)
  )
}
# Same `Age` plot as the continuous example, but predictions come from
# `a_custom_predict_function` and the output is faceted by `Pclass`.
prettyglm::one_way_ave(
  data_set = titanic,
  model_object = survival_model,
  target_variable = "Survived",
  feature_to_plot = "Age",
  facetby = "Pclass",
  predict_function = a_custom_predict_function,
  number_of_buckets = 20,
  lower_percentile_to_cut = 0.1,
  upper_percentile_to_cut = 0.1
)
# Run the code above in your browser using DataLab