
Generate performance metrics across probability thresholds
Source: R/threshold_perf.R
threshold_perf() can take a set of class probability predictions and determine performance characteristics across different values of the probability threshold and any existing groups.
Usage
threshold_perf(.data, ...)
# S3 method for class 'data.frame'
threshold_perf(
  .data,
  truth,
  estimate,
  thresholds = NULL,
  metrics = NULL,
  na_rm = TRUE,
  event_level = "first",
  ...
)
Arguments
- .data
A tibble, potentially grouped.
- ...
Currently unused.
- truth
The column identifier for the true two-class results (that is a factor). This should be an unquoted column name.
- estimate
The column identifier for the predicted class probabilities (that is a numeric). This should be an unquoted column name.
- thresholds
A numeric vector of values for the probability threshold. If unspecified, a series of values between 0.5 and 1.0 is used. Note: if this argument is used, it must be named.
- metrics
Either NULL or a yardstick::metric_set() with a list of performance metrics to calculate. The metrics should all be oriented towards hard class predictions (e.g. yardstick::sensitivity(), yardstick::accuracy(), yardstick::recall(), etc.) and not class probabilities. A set of default metrics is used when NULL (see Details below).
- na_rm
A single logical: should missing data be removed?
- event_level
A single string. Either "first" or "second" to specify which level of truth to consider as the "event".
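To make the roles of thresholds and event_level concrete, the following base R sketch hard-classifies a small set of probabilities at several thresholds and computes sensitivity and specificity for the event level. It is an illustration only, not the package's implementation: the toy data, the helper perf_at(), and the use of `>=` at the threshold boundary are all assumptions of this sketch.

```r
# Hypothetical toy data: `truth` is a two-level factor and `prob_good` is the
# predicted probability of the first level ("good"), i.e. the "event" when
# event_level = "first".
truth <- factor(
  c("good", "good", "good", "poor", "poor", "poor"),
  levels = c("good", "poor")
)
prob_good <- c(0.90, 0.70, 0.40, 0.65, 0.30, 0.10)

# Illustrative helper (not part of any package): convert probabilities to hard
# classes at one threshold, then compute sensitivity and specificity.
# Assumption: a probability equal to the threshold counts as the event.
perf_at <- function(threshold, truth, prob, event = levels(truth)[1]) {
  other <- setdiff(levels(truth), event)
  pred <- factor(ifelse(prob >= threshold, event, other), levels = levels(truth))
  c(
    threshold   = threshold,
    sensitivity = mean(pred[truth == event] == event),
    specificity = mean(pred[truth != event] != event)
  )
}

# One row per threshold
toy_perf <- t(sapply(c(0.5, 0.7, 0.8), perf_at, truth = truth, prob = prob_good))
toy_perf
```

Raising the threshold trades sensitivity for specificity, which is the pattern threshold_perf() reports across its .threshold column.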
Details
Note that the global option yardstick.event_first will be used to determine which level is the event of interest. For more details, see the Relevant level section of yardstick::sens().
The default calculated metrics are:
- yardstick::j_index()
- yardstick::sens()
- yardstick::spec()
- distance = (1 - sens) ^ 2 + (1 - spec) ^ 2
If a custom metric is passed that does not compute sensitivity and specificity, the distance metric is not computed.
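As a quick check of the distance formula, the arithmetic can be reproduced by hand in base R. The inputs below are the rounded sensitivity and specificity printed for threshold 0.6 in the first example on this page; the j_index line uses the standard Youden's J definition (sens + spec - 1), which is not spelled out on this page.

```r
# Rounded values copied from the printed example output at threshold = 0.6
sens <- 0.639
spec <- 0.869

# Squared distance to the ideal point (sens = 1, spec = 1)
distance <- (1 - sens)^2 + (1 - spec)^2

# Youden's J statistic
j_index <- sens + spec - 1

round(c(distance = distance, j_index = j_index), 3)
#> distance  j_index 
#>    0.147    0.508 
```

The printed distance in the example is 0.148 rather than 0.147, presumably because the function works from unrounded sensitivity and specificity.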
Examples
library(probably)
library(dplyr)
data("segment_logistic")
# Set the threshold to 0.6
# > 0.6 = good
# < 0.6 = poor
threshold_perf(segment_logistic, Class, .pred_good, thresholds = 0.6)
#> # A tibble: 4 × 4
#> .threshold .metric .estimator .estimate
#> <dbl> <chr> <chr> <dbl>
#> 1 0.6 sensitivity binary 0.639
#> 2 0.6 specificity binary 0.869
#> 3 0.6 j_index binary 0.508
#> 4 0.6 distance binary 0.148
# Set the threshold to multiple values
thresholds <- seq(0.5, 0.9, by = 0.1)
segment_logistic |>
  threshold_perf(Class, .pred_good, thresholds = thresholds)
#> # A tibble: 20 × 4
#> .threshold .metric .estimator .estimate
#> <dbl> <chr> <chr> <dbl>
#> 1 0.5 sensitivity binary 0.714
#> 2 0.6 sensitivity binary 0.639
#> 3 0.7 sensitivity binary 0.561
#> 4 0.8 sensitivity binary 0.451
#> 5 0.9 sensitivity binary 0.249
#> 6 0.5 specificity binary 0.825
#> 7 0.6 specificity binary 0.869
#> 8 0.7 specificity binary 0.911
#> 9 0.8 specificity binary 0.937
#> 10 0.9 specificity binary 0.977
#> 11 0.5 j_index binary 0.539
#> 12 0.6 j_index binary 0.508
#> 13 0.7 j_index binary 0.472
#> 14 0.8 j_index binary 0.388
#> 15 0.9 j_index binary 0.226
#> 16 0.5 distance binary 0.112
#> 17 0.6 distance binary 0.148
#> 18 0.7 distance binary 0.201
#> 19 0.8 distance binary 0.306
#> 20 0.9 distance binary 0.565
# ---------------------------------------------------------------------------
# It works with grouped data frames as well
# Let's mock some resampled data
resamples <- 5
mock_resamples <- resamples |>
  replicate(
    expr = sample_n(segment_logistic, 100, replace = TRUE),
    simplify = FALSE
  ) |>
  bind_rows(.id = "resample")
resampled_threshold_perf <- mock_resamples |>
  group_by(resample) |>
  threshold_perf(Class, .pred_good, thresholds = thresholds)
resampled_threshold_perf
#> # A tibble: 100 × 5
#> resample .threshold .metric .estimator .estimate
#> <chr> <dbl> <chr> <chr> <dbl>
#> 1 1 0.5 sensitivity binary 0.676
#> 2 1 0.6 sensitivity binary 0.595
#> 3 1 0.7 sensitivity binary 0.568
#> 4 1 0.8 sensitivity binary 0.405
#> 5 1 0.9 sensitivity binary 0.297
#> 6 2 0.5 sensitivity binary 0.794
#> 7 2 0.6 sensitivity binary 0.676
#> 8 2 0.7 sensitivity binary 0.618
#> 9 2 0.8 sensitivity binary 0.5
#> 10 2 0.9 sensitivity binary 0.235
#> # ℹ 90 more rows
# Average over the resamples
resampled_threshold_perf |>
  group_by(.metric, .threshold) |>
  summarise(.estimate = mean(.estimate))
#> `summarise()` has grouped output by '.metric'. You can override using
#> the `.groups` argument.
#> # A tibble: 20 × 3
#> # Groups: .metric [4]
#> .metric .threshold .estimate
#> <chr> <dbl> <dbl>
#> 1 distance 0.5 0.138
#> 2 distance 0.6 0.172
#> 3 distance 0.7 0.211
#> 4 distance 0.8 0.338
#> 5 distance 0.9 0.582
#> 6 j_index 0.5 0.500
#> 7 j_index 0.6 0.480
#> 8 j_index 0.7 0.475
#> 9 j_index 0.8 0.377
#> 10 j_index 0.9 0.218
#> 11 sensitivity 0.5 0.697
#> 12 sensitivity 0.6 0.620
#> 13 sensitivity 0.7 0.555
#> 14 sensitivity 0.8 0.427
#> 15 sensitivity 0.9 0.240
#> 16 specificity 0.5 0.803
#> 17 specificity 0.6 0.860
#> 18 specificity 0.7 0.920
#> 19 specificity 0.8 0.950
#> 20 specificity 0.9 0.978
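One possible follow-up, sketched here in base R, is to pick the threshold with the largest averaged j_index. The values are copied by hand from the averaged output above because the resampling in the example is random and will differ between runs; the object names avg_j and best are hypothetical.

```r
# Averaged j_index values copied from the printed output above
avg_j <- data.frame(
  .threshold = c(0.5, 0.6, 0.7, 0.8, 0.9),
  .estimate  = c(0.500, 0.480, 0.475, 0.377, 0.218)
)

# Row with the largest averaged j_index
best <- avg_j[which.max(avg_j$.estimate), ]
best$.threshold
#> [1] 0.5
```

Whether j_index, distance, or another metric is the right selection criterion depends on the relative cost of false positives and false negatives in the application.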