library(ggplot2)
library(casimir)
# Gold-standard label set: one row per (doc_id, label_id) pair.
# Document A has 3 gold labels, B has 2, and C has 4.
gold <- tibble::tibble(
  doc_id = rep(c("A", "B", "C"), times = c(3, 2, 4)),
  label_id = c(
    "a", "b", "c",      # A
    "a", "d",           # B
    "a", "b", "d", "f"  # C
  )
)
# Predicted labels: one row per (doc_id, label_id) pair, each carrying a
# numeric `score` and a per-document `rank`.
# NOTE(review): for documents A and B, rank increases as score decreases, but
# for C the lowest score ("f", 0.1) has rank 1 and the tied scores ("c", "e",
# 0.2) share rank 2. Confirm whether this is intentional.
pred <- tibble::tibble(
  doc_id = rep(c("A", "B", "C"), times = c(4, 3, 3)),
  label_id = c(
    "a", "d", "f", "c", # A
    "a", "e", "d",      # B
    "f", "c", "e"       # C
  ),
  score = c(
    0.9, 0.7, 0.3, 0.1,
    0.8, 0.6, 0.1,
    0.1, 0.2, 0.2
  ),
  rank = c(
    1, 2, 3, 4,
    1, 2, 3,
    1, 2, 2
  )
)
# Document-averaged precision-recall curve for `pred` against `gold`.
# With `optimize_cutoff = TRUE`, the result also carries `opt_cutoff`, which
# the plot below reads.
pr_curve <- compute_pr_curve(pred, gold, mode = "doc-avg", optimize_cutoff = TRUE)

# Area under the curve, computed from the curve's `plot_data` table.
auc <- compute_pr_auc_from_curve(pr_curve[["plot_data"]])
# Plot the precision-recall curve.
# PR curves use the running maximum of precision (`prec_cummax`), not the raw
# precision at each cutoff.
# The curve is drawn first so that the optimum-cutoff marker and its label,
# added afterwards, are rendered on top of it instead of underneath.
ggplot(pr_curve$plot_data, aes(x = rec, y = prec_cummax)) +
  geom_path() +
  # Highlight the F1-optimal cutoff. x/y are inherited from the global aes();
  # only the data source changes.
  geom_point(
    data = pr_curve$opt_cutoff,
    color = "red",
    shape = "star"
  ) +
  # Annotate the optimal F1 value, shifted right of the marker by 0.2 on the
  # recall axis.
  geom_text(
    data = pr_curve$opt_cutoff,
    aes(label = paste("f1_opt =", round(f1_max, 3))),
    nudge_x = 0.2,
    color = "red"
  ) +
  coord_cartesian(xlim = c(0, 1), ylim = c(0, 1)) +
  labs(x = "Recall", y = "Precision (cumulative max)")
# Run the code above in your browser using DataLab