threshold_perf() can take a set of class probability predictions and determine performance characteristics across different values of the probability threshold and any existing groups.

Usage

threshold_perf(.data, ...)

# S3 method for data.frame
threshold_perf(
  .data,
  truth,
  estimate,
  thresholds = NULL,
  metrics = NULL,
  na_rm = TRUE,
  event_level = "first",
  ...
)

Arguments

.data

A tibble, potentially grouped.

...

Currently unused.

truth

The column identifier for the true two-class results (that is a factor). This should be an unquoted column name.

estimate

The column identifier for the predicted class probabilities (that is a numeric). This should be an unquoted column name.

thresholds

A numeric vector of values for the probability threshold. If unspecified, a series of values between 0.5 and 1.0 is used. Note: if this argument is used, it must be named.

metrics

Either NULL or a yardstick::metric_set() with a list of performance metrics to calculate. The metrics should all be oriented towards hard class predictions (e.g. yardstick::sensitivity(), yardstick::accuracy(), yardstick::recall(), etc.) and not class probabilities. A set of default metrics is used when NULL (see Details below).
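
As an illustration of the metrics argument, the sketch below passes a custom yardstick::metric_set() built from hard class metrics only. The object names hard_class_metrics and custom_perf are made up for this example, and the two threshold values are arbitrary.

```r
library(probably)
library(yardstick)

data("segment_logistic")

# Hypothetical name: a metric set made only of hard class prediction metrics
hard_class_metrics <- metric_set(sensitivity, specificity, accuracy)

# Evaluate the custom metrics at two arbitrary thresholds;
# thresholds and metrics are both passed by name
custom_perf <- threshold_perf(
  segment_logistic,
  Class,
  .pred_good,
  thresholds = c(0.5, 0.7),
  metrics = hard_class_metrics
)

custom_perf
```

Class probability metrics such as yardstick::roc_auc() do not belong in this set, since each threshold first converts the probabilities into hard class predictions.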

na_rm

A single logical: should missing data be removed?

event_level

A single string. Either "first" or "second" to specify which level of truth to consider as the "event".

Value

A tibble with columns: .threshold, .estimator, .metric, .estimate and any existing groups.
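
Because the result is a long-format tibble with one row per threshold and metric, it can be filtered and ranked with ordinary dplyr verbs. The following is a minimal sketch of one possible way to pick the threshold with the largest J index; the names perf and best_j are hypothetical, and using the J index as the selection criterion is an assumption made for illustration only.

```r
library(dplyr)
library(probably)

data("segment_logistic")

# Performance over a small grid of thresholds, using the default metrics
perf <- threshold_perf(
  segment_logistic,
  Class,
  .pred_good,
  thresholds = seq(0.5, 0.9, by = 0.1)
)

# Keep only the J index rows and take the row(s) with the largest estimate
best_j <- perf %>%
  filter(.metric == "j_index") %>%
  slice_max(.estimate, n = 1)

best_j
```

Note that slice_max() keeps ties by default, so best_j can contain more than one row if several thresholds share the maximum.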

Details

Note that the global option yardstick.event_first will be used to determine which level is the event of interest. For more details, see the Relevant level section of yardstick::sens().

The default calculated metrics, as they appear in the example output below, are sensitivity, specificity, and j_index.

If a custom metric is passed that does not compute sensitivity and specificity, the distance metric is not computed.
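
The text above does not define the distance metric. As a rough sketch only, the code below assumes it is the squared Euclidean distance from the ideal point where sensitivity and specificity both equal 1, that is, (1 - sensitivity)^2 + (1 - specificity)^2, and computes that quantity by hand from the long-format output. The formula and the names sens_spec and manual_distance are assumptions for illustration, not something stated on this page.

```r
library(dplyr)
library(probably)
library(yardstick)

data("segment_logistic")

# Request only sensitivity and specificity so the inputs are explicit
sens_spec <- threshold_perf(
  segment_logistic,
  Class,
  .pred_good,
  thresholds = seq(0.5, 0.9, by = 0.1),
  metrics = metric_set(sensitivity, specificity)
) %>%
  filter(.metric %in% c("sensitivity", "specificity"))

# Assumed definition: squared distance to the point (sens = 1, spec = 1)
manual_distance <- sens_spec %>%
  group_by(.threshold) %>%
  summarise(
    distance = (1 - .estimate[.metric == "sensitivity"])^2 +
      (1 - .estimate[.metric == "specificity"])^2
  )

manual_distance
```

Under this assumed definition, smaller values indicate a threshold closer to perfect sensitivity and specificity, which also shows why the quantity cannot be formed when either of the two metrics is missing.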

Examples

library(dplyr)
data("segment_logistic")

# Set the threshold to 0.6
# > 0.6 = good
# < 0.6 = poor
threshold_perf(segment_logistic, Class, .pred_good, thresholds = 0.6)
#> # A tibble: 3 × 4
#>   .threshold .metric     .estimator .estimate
#>        <dbl> <chr>       <chr>          <dbl>
#> 1        0.6 sensitivity binary         0.639
#> 2        0.6 specificity binary         0.869
#> 3        0.6 j_index     binary         0.508

# Set the threshold to multiple values
thresholds <- seq(0.5, 0.9, by = 0.1)

segment_logistic %>%
  threshold_perf(Class, .pred_good, thresholds = thresholds)
#> # A tibble: 15 × 4
#>    .threshold .metric     .estimator .estimate
#>         <dbl> <chr>       <chr>          <dbl>
#>  1        0.5 sensitivity binary         0.714
#>  2        0.6 sensitivity binary         0.639
#>  3        0.7 sensitivity binary         0.561
#>  4        0.8 sensitivity binary         0.451
#>  5        0.9 sensitivity binary         0.249
#>  6        0.5 specificity binary         0.825
#>  7        0.6 specificity binary         0.869
#>  8        0.7 specificity binary         0.911
#>  9        0.8 specificity binary         0.937
#> 10        0.9 specificity binary         0.977
#> 11        0.5 j_index     binary         0.539
#> 12        0.6 j_index     binary         0.508
#> 13        0.7 j_index     binary         0.472
#> 14        0.8 j_index     binary         0.388
#> 15        0.9 j_index     binary         0.226

# ---------------------------------------------------------------------------

# It works with grouped data frames as well
# Let's mock some resampled data
resamples <- 5

mock_resamples <- resamples %>%
  replicate(
    expr = sample_n(segment_logistic, 100, replace = TRUE),
    simplify = FALSE
  ) %>%
  bind_rows(.id = "resample")

resampled_threshold_perf <- mock_resamples %>%
  group_by(resample) %>%
  threshold_perf(Class, .pred_good, thresholds = thresholds)

resampled_threshold_perf
#> # A tibble: 75 × 5
#>    resample .threshold .metric     .estimator .estimate
#>    <chr>         <dbl> <chr>       <chr>          <dbl>
#>  1 1               0.5 sensitivity binary         0.676
#>  2 1               0.6 sensitivity binary         0.595
#>  3 1               0.7 sensitivity binary         0.568
#>  4 1               0.8 sensitivity binary         0.405
#>  5 1               0.9 sensitivity binary         0.297
#>  6 2               0.5 sensitivity binary         0.794
#>  7 2               0.6 sensitivity binary         0.676
#>  8 2               0.7 sensitivity binary         0.618
#>  9 2               0.8 sensitivity binary         0.5  
#> 10 2               0.9 sensitivity binary         0.235
#> # ℹ 65 more rows

# Average over the resamples
resampled_threshold_perf %>%
  group_by(.metric, .threshold) %>%
  summarise(.estimate = mean(.estimate))
#> `summarise()` has grouped output by '.metric'. You can override using the
#> `.groups` argument.
#> # A tibble: 15 × 3
#> # Groups:   .metric [3]
#>    .metric     .threshold .estimate
#>    <chr>            <dbl>     <dbl>
#>  1 j_index            0.5     0.500
#>  2 j_index            0.6     0.480
#>  3 j_index            0.7     0.475
#>  4 j_index            0.8     0.377
#>  5 j_index            0.9     0.218
#>  6 sensitivity        0.5     0.697
#>  7 sensitivity        0.6     0.620
#>  8 sensitivity        0.7     0.555
#>  9 sensitivity        0.8     0.427
#> 10 sensitivity        0.9     0.240
#> 11 specificity        0.5     0.803
#> 12 specificity        0.6     0.860
#> 13 specificity        0.7     0.920
#> 14 specificity        0.8     0.950
#> 15 specificity        0.9     0.978