mildsvm (version 0.4.1)

predict.smm: Predict method for smm object

Description

Predict method for smm object

Usage

# S3 method for smm
predict(
  object,
  new_data,
  type = c("class", "raw"),
  layer = "instance",
  new_instances = "instance_name",
  new_bags = "bag_name",
  kernel = NULL,
  ...
)

Value

A tibble with nrow(new_data) rows. If type = 'class', the tibble has a column named .pred_class. If type = 'raw', the tibble has a column named .pred.
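Because the result has one row per row of new_data, each instance's prediction is repeated for every sample in that instance. The following base R sketch shows one way to collapse such output to one row per instance. The data here are hypothetical stand-ins for real predict() output, and the values are made up for illustration.

```r
# Hypothetical per-row output: one prediction per row of new_data,
# repeated for every sample that belongs to the same instance
instance_name <- rep(c("11", "12"), each = 3)
pred <- data.frame(.pred = rep(c(0.8, -0.5), each = 3))

# Attach the instance names and keep one row per instance
per_instance <- unique(
  data.frame(instance_name = instance_name, pred, stringsAsFactors = FALSE)
)
nrow(per_instance)
```

The Examples section does the same thing with dplyr::distinct().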

Arguments

object

An object of class smm.

new_data

A data frame to predict from. This needs to have all of the features that the data was originally fitted with.

type

If 'class', return predicted classes as -1 or +1, obtained by thresholding the raw scores at 0. If 'raw', return the raw predicted scores.
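The relationship between the two types can be sketched in base R. This is only an illustration of thresholding at 0, not the package's internal code. The scores are hypothetical, and the handling of a score of exactly 0 is an assumption.

```r
# Hypothetical raw scores, standing in for the .pred column
raw_scores <- c(-1.2, -0.1, 0.4, 2.3)

# Threshold at 0: positive scores map to +1, all others to -1
# (assumption: a score of exactly 0 would map to -1)
pred_class <- ifelse(raw_scores > 0, 1, -1)
pred_class
```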

layer

If 'instance', return predictions at the instance level. If 'bag', return predictions at the bag level, which is only available if the model was fit with smm.mild_df().

new_instances

A character string or character vector. A single string gives the name of the column in new_data that holds the instance names (default 'instance_name'). A vector of length nrow(new_data) gives the instance name for each row directly.

new_bags

A character string or character vector. Only relevant when the model was fit with smm.mild_df(), which contains bag level information. A single string gives the name of the column in new_data that holds the bag names (default 'bag_name'). A vector of length nrow(new_data) gives the bag name for each row directly.

kernel

An optional pre-computed kernel matrix at the instance level, or NULL (default NULL). Rows should correspond to instances in the new data to predict, and columns to instances in the original training data. Such a matrix can be obtained from a call to kme().
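To make the expected orientation concrete, here is a minimal base R sketch that builds an instance-level kernel matrix with new instances in rows and training instances in columns. The helper mean_embedding_kernel() is hypothetical and is not part of mildsvm. It assumes an RBF kernel of the form exp(-sigma * squared distance), averaged over all sample pairs, and is not guaranteed to match the output of kme().

```r
# Hypothetical helper (not part of mildsvm): averages an RBF kernel over all
# sample pairs for each (new instance, training instance) combination
mean_embedding_kernel <- function(new_x, new_inst, train_x, train_inst, sigma) {
  new_x <- as.matrix(new_x)
  train_x <- as.matrix(train_x)
  new_ids <- unique(new_inst)
  train_ids <- unique(train_inst)
  k <- matrix(0, nrow = length(new_ids), ncol = length(train_ids),
              dimnames = list(new_ids, train_ids))
  for (i in seq_along(new_ids)) {
    a <- new_x[new_inst == new_ids[i], , drop = FALSE]
    for (j in seq_along(train_ids)) {
      b <- train_x[train_inst == train_ids[j], , drop = FALSE]
      # Squared Euclidean distances between every sample in a and in b
      d2 <- outer(rowSums(a^2), rowSums(b^2), "+") - 2 * a %*% t(b)
      k[i, j] <- mean(exp(-sigma * d2))
    }
  }
  k
}

set.seed(1)
train_x <- matrix(rnorm(40 * 3), ncol = 3)
train_inst <- rep(c("1", "2", "3", "4"), each = 10)
new_x <- matrix(rnorm(20 * 3), ncol = 3)
new_inst <- rep(c("11", "12"), each = 10)

k_new <- mean_embedding_kernel(new_x, new_inst, train_x, train_inst, sigma = 1/3)
dim(k_new)  # rows: 2 new instances, columns: 4 training instances
```

In practice the matrix would come from kme() rather than a hand-written helper. The point of the sketch is only the shape: one row per instance to predict and one column per training instance.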

...

Arguments passed to or from other methods.

Author

Sean Kent

Details

When the object was fitted using the formula method, the new_bags and new_instances parameters are not necessary, as long as the column names in new_data match those used in the original function call.

See Also

smm() for fitting the smm object.

Examples

library(mildsvm)
set.seed(8)
n_instances <- 10
n_samples <- 20
y <- rep(c(1, -1), each = n_samples * n_instances / 2)
instances <- as.character(rep(1:n_instances, each = n_samples))
x <- data.frame(x1 = rnorm(length(y), mean = 1*(y==1)),
                x2 = rnorm(length(y), mean = 2*(y==1)),
                x3 = rnorm(length(y), mean = 3*(y==1)))

mdl <- smm(x, y, instances, control = list(sigma = 1/3))

# instance level predictions (training data)
suppressWarnings(library(dplyr))
data.frame(instance_name = instances, y = y, x) %>%
  bind_cols(predict(mdl, type = "raw", new_data = x, new_instances = instances)) %>%
  bind_cols(predict(mdl, type = "class", new_data = x, new_instances = instances)) %>%
  distinct(instance_name, y, .pred, .pred_class)

# test data
new_inst <- rep(c("11", "12"), each = 30)
new_y <- rep(c(1, -1), each = 30)
new_x <- data.frame(x1 = rnorm(length(new_inst), mean = 1*(new_inst=="11")),
                    x2 = rnorm(length(new_inst), mean = 2*(new_inst=="11")),
                    x3 = rnorm(length(new_inst), mean = 3*(new_inst=="11")))

# instance level predictions (test data)
data.frame(instance_name = new_inst, y = new_y, new_x) %>%
  bind_cols(predict(mdl, type = "raw", new_data = new_x, new_instances = new_inst)) %>%
  bind_cols(predict(mdl, type = "class", new_data = new_x, new_instances = new_inst)) %>%
  distinct(instance_name, y, .pred, .pred_class)