A generic S3 function to compute the cross-entropy score for a classification model. This function dispatches to the S3 methods of cross.entropy() and performs no input validation. If you supply NA values or vectors of unequal length (e.g. length(x) != length(y)), the underlying C++ code may trigger undefined behavior and crash your R session.

Because cross.entropy() operates on raw pointers, pointer-level faults (e.g. from NA values or mismatched lengths) occur before any R-level error handling, so wrapping calls in try() or tryCatch() will not prevent R-session crashes.

To guard against this, wrap cross.entropy() in a "safe" validator that checks for NA values and matching lengths, for example:
safe_cross.entropy <- function(x, y, ...) {
  stopifnot(
    !anyNA(x), !anyNA(y),
    length(x) == length(y)
  )
  cross.entropy(x, y, ...)
}
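As a quick check, a mismatched call now fails with an ordinary, catchable R error before the C++ code is reached (the toy matrices below are illustrative only):

## mismatched dimensions are caught at the R level,
## so try() works as expected
pk_bad <- matrix(c(0.2, 0.8, 0.5, 0.5), nrow = 2)
qk_bad <- matrix(c(0.3, 0.7), nrow = 1)
try(safe_cross.entropy(pk_bad, qk_bad))
# error: length(x) == length(y) is not TRUE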
Apply the same pattern to any custom metric functions to ensure input sanity before calling the underlying C++ code; one way to generalize it is sketched below.
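A minimal sketch of that generalization, assuming the metric takes two equal-length inputs as its first arguments; with_validation() is a hypothetical helper name, not part of the package:

## hypothetical helper: returns a validated version of any metric
## that takes two inputs of equal length as its first arguments
with_validation <- function(metric) {
  function(x, y, ...) {
    stopifnot(
      !anyNA(x), !anyNA(y),
      length(x) == length(y)
    )
    metric(x, y, ...)
  }
}

safe_logloss <- with_validation(logloss)  # e.g. another metric from this family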
# S3 method for matrix
cross.entropy(pk, qk, dim = 0L, normalize = FALSE, ...)
Value: A <double> value or vector.

Arguments:

pk, qk: A pair of <double> matrices of length n of empirical probabilities p and estimated probabilities q.
dim: An <integer> value of length 1 (default: 0L). Defines the dimension along which to calculate the entropy (0: total, 1: row-wise, 2: column-wise).
normalize: A <logical> value (default: FALSE). If TRUE, the mean cross-entropy across all observations is returned; otherwise, the sum of cross-entropies is returned.
...: Arguments passed into other methods.
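For reference, a plain-R sketch of what the total (dim = 0) score presumably computes, assuming the textbook definition of cross-entropy, H(p, q) = -sum(p * log(q)) (MacKay, 2003); this is an illustration, not the package's implementation:

## assumed textbook definition; the divisor for normalize = TRUE
## (number of observations, i.e. rows) is likewise an assumption
manual_cross_entropy <- function(pk, qk, normalize = FALSE) {
  ce <- -sum(pk * log(qk))              # total cross-entropy (dim = 0)
  if (normalize) ce / nrow(pk) else ce  # mean across observations
}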
References:

MacKay, David J. C. Information Theory, Inference and Learning Algorithms. Cambridge University Press, 2003.
Kramer, Oliver. "Scikit-learn." Machine Learning for Evolution Strategies (2016): 45-53.
Virtanen, Pauli, et al. "SciPy 1.0: fundamental algorithms for scientific computing in Python." Nature Methods 17.3 (2020): 261-272.
Other Classification: accuracy(), auc.pr.curve(), auc.roc.curve(), baccuracy(), brier.score(), ckappa(), cmatrix(), dor(), fbeta(), fdr(), fer(), fmi(), fpr(), hammingloss(), jaccard(), logloss(), mcc(), nlr(), npv(), plr(), pr.curve(), precision(), recall(), relative.entropy(), roc.curve(), shannon.entropy(), specificity(), zerooneloss()
Other Supervised Learning: accuracy(), auc.pr.curve(), auc.roc.curve(), baccuracy(), brier.score(), ccc(), ckappa(), cmatrix(), deviance.gamma(), deviance.poisson(), deviance.tweedie(), dor(), fbeta(), fdr(), fer(), fmi(), fpr(), gmse(), hammingloss(), huberloss(), jaccard(), logloss(), maape(), mae(), mape(), mcc(), mpe(), mse(), nlr(), npv(), pinball(), plr(), pr.curve(), precision(), rae(), recall(), relative.entropy(), rmse(), rmsle(), roc.curve(), rrmse(), rrse(), rsq(), shannon.entropy(), smape(), specificity(), zerooneloss()
Other Entropy: logloss(), relative.entropy(), shannon.entropy()
## generate valid probability distributions
rand.sum <- function(n) {
  x <- sort(runif(n - 1))
  c(x, 1) - c(0, x)  # successive differences of sorted draws sum to 1
}

## empirical and predicted probabilities
set.seed(1903)
pk <- t(replicate(200, rand.sum(5)))
qk <- t(replicate(200, rand.sum(5)))

## cross-entropy
cross.entropy(
  pk = pk,
  qk = qk
)
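Following the argument descriptions above, the same inputs can also be scored along a dimension or as a normalized mean; the two calls below simply exercise the documented dim and normalize arguments:

## row-wise cross-entropy: one value per observation
cross.entropy(pk = pk, qk = qk, dim = 1L)

## mean (rather than summed) cross-entropy across observations
cross.entropy(pk = pk, qk = qk, normalize = TRUE)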