Learn R Programming

qCBA (version 1.0.2)

sbrlModel2arcCBARuleModel: Convert a model created by sbrl so that it can be passed to qCBA

Description

Creates an instance of the CBARuleModel class from the arc package. This instance can then be passed to qcba().

Usage

sbrlModel2arcCBARuleModel(
  sbrl_model,
  cutPoints,
  rawDataset,
  classAtt,
  attTypes
)

Arguments

sbrl_model

object returned by sbrl::sbrl()

cutPoints

specification of the cutpoints applied to the data before it was passed to sbrl::sbrl() (for example, the cutp element of the result of discrNumeric())

rawDataset

the raw data (before discretization). If attTypes is not passed, this dataset is used to guess the attribute types.

classAtt

the name of the class attribute

attTypes

vector of attribute types of the original data. If set to NULL, rawDataset must be passed.

Examples

Run this code
# Run the example only when the optional dependencies rCBA and sbrl are
# installed; otherwise explain what is missing and do nothing else.
# (The if/else chain alone skips the example: a top-level return() would
# signal an error because there is no enclosing function to return from.)
if (!requireNamespace("rCBA", quietly = TRUE)) {
  message("Please install rCBA to allow for sbrl model conversion")
} else if (!requireNamespace("sbrl", quietly = TRUE)) {
  message("Please install sbrl to allow for postprocessing of sbrl models")
} else {
  # This will run only outside a CRAN test, if the environment variable
  # NOT_CRAN is set to "true". This environment variable is set by devtools.
  if (identical(Sys.getenv("NOT_CRAN"), "true")) {
    library(sbrl)
    library(rCBA)

    # sbrl handles only binary problems; iris has 3 target classes, so
    # remove one class
    set.seed(getOption("qcba.seed"))
    allData <- datasets::iris[sample(nrow(datasets::iris)), ]
    classToExclude <- "versicolor"
    allData <- allData[allData$Species != classToExclude, ]
    # drop the level of the removed class so only two levels remain
    allData$Species <- droplevels(allData$Species)
    trainFold <- allData[1:50, ]
    testFold <- allData[51:nrow(allData), ]

    # rename the class column to the fixed name "label" used for sbrl
    sbrlFixedLabel <- "label"
    origLabel <- "Species"
    orignames <- colnames(trainFold)
    orignames[orignames == origLabel] <- sbrlFixedLabel
    colnames(trainFold) <- orignames
    colnames(testFold) <- orignames

    # Recode the label to binary values: first create a dict mapping the
    # distinct original class values (in level order) to 0 and 1, then
    # apply the dict to the train and test fold
    origval <- levels(as.factor(trainFold$label))
    newval <- c(0, 1)
    dict <- data.frame(origval, newval)
    trainFold$label <- dict[match(trainFold$label, dict$origval), 2]
    testFold$label <- dict[match(testFold$label, dict$origval), 2]

    # discretize the training data; the cutpoints are kept so they can be
    # applied to the test data and later passed to sbrlModel2arcCBARuleModel()
    trainFoldDiscTemp <- discrNumeric(trainFold, sbrlFixedLabel)
    trainFoldDiscCutpoints <- trainFoldDiscTemp$cutp
    trainFoldDisc <- as.data.frame(
      lapply(trainFoldDiscTemp$Disc.data, as.factor)
    )

    # discretize the test data with the training cutpoints
    testFoldDisc <- applyCuts(
      testFold, trainFoldDiscCutpoints,
      infinite_bounds = TRUE, labels = TRUE
    )

    # SBRL 1.4 crashes if features contain a space, even if these features
    # are converted to factors. To circumvent this, remove the spaces and
    # convert every column back to a factor.
    trainFoldDisc <- as.data.frame(lapply(trainFoldDisc, function(x) {
      as.factor(gsub(" ", "", as.character(x)))
    }))

    # learn the sbrl model; rule_minlen is increased to demonstrate the
    # effect of postprocessing
    sbrl_model <- sbrl(
      trainFoldDisc, iters = 20000, pos_sign = "0", neg_sign = "1",
      rule_minlen = 3, rule_maxlen = 5,
      minsupport_pos = 0.05, minsupport_neg = 0.05,
      lambda = 20.0, eta = 5.0, nchain = 25
    )

    # apply the sbrl model on the test fold
    yhat <- predict(sbrl_model, testFoldDisc)
    yvals <- as.integer(yhat$V1 > 0.5)
    sbrl_acc <- mean(as.integer(yvals == testFoldDisc$label))
    message("SBRL RESULT")
    # print() is used for objects: message() would only paste the
    # as.character() conversions of the object's components
    print(sbrl_model)

    # convert the sbrl model into a rule model that qcba() accepts
    rm_sbrl <- sbrlModel2arcCBARuleModel(
      sbrl_model, trainFoldDiscCutpoints, trainFold, sbrlFixedLabel
    )
    sbrlAvgCondCount <- sum(rm_sbrl@rules@lhs@data) / length(rm_sbrl@rules)
    message(paste(
      "sbrl acc=", sbrl_acc, ", sbrl rule count=", nrow(sbrl_model$rs),
      ", avg condition count (incl. default rule)", sbrlAvgCondCount
    ))

    # postprocess the converted model with QCBA and evaluate it on the
    # (undiscretized) test fold
    rmQCBA_sbrl <- qcba(cbaRuleModel = rm_sbrl, datadf = trainFold)
    prediction <- predict(rmQCBA_sbrl, testFold)
    acc_qcba_sbrl <- CBARuleModelAccuracy(
      prediction, testFold[[rmQCBA_sbrl@classAtt]]
    )
    # average number of conditions per rule; assumes @rules is a data frame
    # with one row (and one condition_count value) per rule
    avg_rule_length <- mean(rmQCBA_sbrl@rules$condition_count)
    message("RESULT of QCBA postprocessing of SBRL")
    print(rmQCBA_sbrl@rules)
    message(paste(
      "QCBA after SBRL acc=", acc_qcba_sbrl, ", rule count=",
      rmQCBA_sbrl@ruleCount, ", avg condition count (incl. default rule)",
      avg_rule_length
    ))

    # delete temp files created by SBRL
    unlink(c("tdata_R.label", "tdata_R.out"))
  }
}

Run the code above in your browser using DataLab