Assessment Metrics for Use in Training Loop

library(geodl)
library(terra)
library(yardstick)

Terminology

During the training process, you may choose to monitor assessment metrics alongside the loss function. Further, the assessment metrics can be calculated separately for the training and validation data for comparison and to check for overfitting. The geodl package provides several classification assessment metrics implemented using the luz_metric() function from the luz package. This function defines metrics that can be monitored, calculated, and aggregated over multiple mini-batches during the training process.
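
For context, the following sketch shows how one of these metrics might be passed to the metrics argument of luz::setup() so that it is tracked for both the training and validation data during fitting. The model, data loaders, loss, and optimizer used here (myModel, trainDL, and valDL) are hypothetical placeholders and are not part of geodl.

# Minimal sketch (placeholder model and data loaders): a geodl metric supplied
# to luz so that it is monitored and aggregated over mini-batches each epoch.
fitted <- myModel |>
  luz::setup(
    loss = torch::nn_cross_entropy_loss(),
    optimizer = torch::optim_adam,
    metrics = list(
      luz_metric_overall_accuracy(nCls=5,
                                  smooth=1e-8,
                                  mode = "multiclass",
                                  zeroStart=TRUE,
                                  usedDS=FALSE)
    )
  ) |>
  luz::fit(trainDL, epochs=10, valid_data=valDL)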

The following metrics are provided:

  1. luz_metric_overall_accuracy(): overall accuracy
  2. luz_metric_f1score(): macro-averaged F1-score
  3. luz_metric_recall(): macro-averaged recall
  4. luz_metric_precision(): macro-averaged precision

Note that macro-averaging is used because micro-averaged recall, precision, and F1-score are all equivalent to each other and to overall accuracy, so there is no reason to calculate them alongside overall accuracy. The mode argument can be set to either “binary” or “multiclass”. If “binary”, only the logit for the positive-class prediction should be provided. If logits for both the positive and negative (background) classes are provided for a binary classification, the “multiclass” mode should be used. Note that geodl is designed to treat all predictions as multiclass; the “binary” mode is only provided for use outside of the standard geodl workflow. For a binary classification as implemented with geodl, if you want to calculate precision, recall, and/or F1-score using only the positive case, you can set the weight for the background class to 0 using the clsWghts parameter.
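
As a sketch of that positive-class-only configuration (assuming a binary problem with the background coded as 0 and the positive case coded as 1, and with logits supplied for both classes), an F1-score metric could be defined as follows:

# Sketch: a weight of 0 on the background class (class code 0) removes it from
# the macro-average, so only the positive class contributes. The predictions
# would need to provide logits for both classes (two channels).
posF1 <- luz_metric_f1score(nCls=2,
                            smooth=1e-8,
                            mode = "multiclass",
                            zeroStart=TRUE,
                            clsWghts = c(0,1),
                            usedDS=FALSE)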

The biThresh parameter allows you to select a probability threshold for differentiating positive and negative class predictions. The default is 0.5, meaning that any pixel with a positive-case probability of at least 0.5, as approximated by applying a sigmoid function to the positive-case logit, will be coded to the positive case. All other pixels will be coded to the negative case. This parameter is not used for multiclass problems.
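
Conceptually, the binary-mode thresholding amounts to the following (a minimal sketch using a placeholder posLogits tensor, not the internal implementation):

# Sketch: convert positive-class logits to probabilities with a sigmoid, then
# apply the threshold; 1 = positive case, 0 = negative case.
probs <- torch::torch_sigmoid(posLogits)
predCls <- (probs >= 0.5)$to(dtype=torch::torch_long())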

The smooth parameter is used to avoid divide-by-zero errors and to improve numeric stability.
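
For example, a smoothed precision of the general form below (a conceptual sketch, not necessarily the exact internal formula) cannot divide by zero even when a class receives no predicted samples:

# Conceptual smoothing: the smooth term keeps the denominator positive when
# TP + FP = 0 for a class.
smoothPrecision <- function(tp, fp, smooth=1e-8){
  (tp + smooth)/(tp + fp + smooth)
}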

The following two examples demonstrate the metrics. These metrics are also used within the example classification workflows provided in other articles.

Example 1: Binary Classification

This first example demonstrates calculating the assessment metrics for a binary classification where the background class is coded as 0 and the positive case is coded as 1. First, the categorical raster grids are read in as SpatRaster objects using the rast() function from terra, and the bands are renamed to aid in interpretability. The refC object represents the reference class numeric codes, the predL object represents the predicted logits (before the data are rescaled with a sigmoid function), and the predC object represents the predicted class indices.

refC <- terra::rast("C:/myFiles/data/metricCheck/binary_reference.tif")

predL <- terra::rast("C:/myFiles/data/metricCheck/binary_logits.tif")

predC <- terra::rast("C:/myFiles/data/metricCheck/binary_prediction.tif")

names(refC) <- "reference"
names(predC) <- "prediction"

For comparison, we first create a confusion matrix by cross-tabulating the grids using the crosstab() function from terra, then calculate assessment metrics using the yardstick package from tidymodels.

cm <- terra::crosstab(c(predC, refC))
yardstick::f_meas(cm, estimator="binary", event_level="second")
# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 f_meas  binary         0.951
yardstick::accuracy(cm, estimator="binary", event_level="second")
# A tibble: 1 × 3
  .metric  .estimator .estimate
  <chr>    <chr>          <dbl>
1 accuracy binary         0.994
yardstick::recall(cm, estimator="binary", event_level="second")
# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 recall  binary         0.973
yardstick::precision(cm, estimator="binary", event_level="second")
# A tibble: 1 × 3
  .metric   .estimator .estimate
  <chr>     <chr>          <dbl>
1 precision binary         0.929

Before the assessment metrics can be calculated, the SpatRaster objects need to be converted to torch tensors with the correct dimensionality and data types. This is accomplished as follows:

  1. The SpatRaster objects are converted to R arrays.
  2. The R arrays are converted to torch tensors, with the target codes represented using a long integer data type and the logits represented using a 32-bit float data type.
  3. The tensors are permuted so that the channel dimension is first, as is the standard within torch.
  4. A mini-batch dimension is added at the first position, as is required for use inside training loops.
  5. Two copies of the tensors are concatenated to simulate a mini-batch with two samples.

predL <- terra::as.array(predL)
refC <- terra::as.array(refC)

target <- torch::torch_tensor(refC, dtype=torch::torch_long())
pred <- torch::torch_tensor(predL, dtype=torch::torch_float32())
target <- target$permute(c(3,1,2))
pred <- pred$permute(c(3,1,2))

target <- target$unsqueeze(1)
pred <- pred$unsqueeze(1)

target <- torch::torch_cat(list(target, target), dim=1)
pred <- torch::torch_cat(list(pred, pred), dim=1)

In the following blocks of code, we calculate overall accuracy, F1-score, recall, and precision. This requires (1) instantiating an instance of the metric, (2) updating the metric using the predicted logits and reference class codes, and (3) computing the metric.

metric<-luz_metric_overall_accuracy(nCls=1,
                                    smooth=1e-8,
                                    mode = "binary",
                                    biThresh = 0.5,
                                    zeroStart=TRUE,
                                    usedDS=FALSE)
metric<-metric$new()
metric$update(pred,target)
metric$compute()
[1] 0.9945027
metric<-luz_metric_f1score(nCls=1,
                           smooth=1e-8,
                           mode = "binary",
                           biThresh = 0.5,
                           zeroStart=TRUE,
                           usedDS=FALSE)
metric<-metric$new()
metric$update(pred,target)
metric$compute()
[1] 0.9521416
metric<-luz_metric_recall(nCls=1,
                          smooth=1e-8,
                          mode = "binary",
                          biThresh = 0.5,
                          zeroStart=TRUE,
                          usedDS=FALSE)
metric<-metric$new()
metric$update(pred,target)
metric$compute()
[1] 0.9782254
metric<-luz_metric_precision(nCls=1,
                             smooth=1e-8,
                             mode = "binary",
                             biThresh = 0.5,
                             zeroStart=TRUE,
                             usedDS=FALSE)
metric<-metric$new()
metric$update(pred,target)
metric$compute()
[1] 0.9274127

Example 2: Multiclass Classification

A multiclass classification is simulated in this example. The process is very similar to the binary example other than how the metrics are configured. Now, the “multiclass” mode is used. The zeroStart parameter is set to TRUE, which indicates that the class codes start at 0 as opposed to 1. The zeroStart parameter is used within many of geodl’s functions. Due to how one-hot encoding is implemented by the torch package, class codes starting at 0 can cause issues. By indicating that indexing starts at 0, the workflow will adjust the codes so that one-hot encoding does not generate an error.
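
As a toy illustration of that adjustment (hypothetical values, not geodl’s internal code), codes that begin at 0 are simply shifted up by 1 so that the smallest code is a valid 1-based index before one-hot encoding is applied:

# Hypothetical example: class codes 0 through 4 are shifted to 1 through 5
# (matching nCls = 5) before one-hot encoding is performed internally.
codes <- torch::torch_tensor(c(0, 1, 2, 3, 4), dtype=torch::torch_long())
codes + 1L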

refC <- terra::rast("C:/myFiles/data/metricCheck/multiclass_reference.tif")
predL <- terra::rast("C:/myFiles/data/metricCheck/multiclass_logits.tif")
predC <- terra::rast("C:/myFiles/data/metricCheck/multiclass_prediction.tif")

names(refC) <- "reference"
names(predC) <- "prediction"
cm <- terra::crosstab(c(predC, refC))
yardstick::f_meas(cm, estimator="macro")
# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 f_meas  macro          0.822
yardstick::accuracy(cm, estimator="micro")
# A tibble: 1 × 3
  .metric  .estimator .estimate
  <chr>    <chr>          <dbl>
1 accuracy multiclass     0.913
yardstick::recall(cm, estimator="macro")
# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 recall  macro          0.827
yardstick::precision(cm, estimator="macro")
# A tibble: 1 × 3
  .metric   .estimator .estimate
  <chr>     <chr>          <dbl>
1 precision macro          0.821

As in the binary example, the SpatRaster objects are converted to arrays and then to torch tensors, the dimensions are rearranged, a mini-batch dimension is added, and two copies are concatenated to simulate a mini-batch with two samples.

predL <- terra::as.array(predL)
refC <- terra::as.array(refC)

target <- torch::torch_tensor(refC, dtype=torch::torch_long())
pred <- torch::torch_tensor(predL, dtype=torch::torch_float32())
target <- target$permute(c(3,1,2))
pred <- pred$permute(c(3,1,2))

target <- target$unsqueeze(1)
pred <- pred$unsqueeze(1)

target <- torch::torch_cat(list(target, target), dim=1)
pred <- torch::torch_cat(list(pred, pred), dim=1)

The metrics are then instantiated, updated, and computed as in the binary example, but now with nCls set to 5 and the “multiclass” mode.

metric<-luz_metric_overall_accuracy(nCls=5,
                                    smooth=1e-8,
                                    mode = "multiclass",
                                    zeroStart=TRUE,
                                    usedDS=FALSE)
metric<-metric$new()
metric$update(pred,target)
metric$compute()
[1] 0.9134674
metric<-luz_metric_f1score(nCls=5,
                           smooth=1e-8,
                           mode = "multiclass",
                           zeroStart=TRUE,
                           clsWghts = c(1,1,1,1,1),
                           usedDS=FALSE)
metric<-metric$new()
metric$update(pred,target)
metric$compute()
[1] 0.8237987
metric<-luz_metric_recall(nCls=5,
                          smooth=1e-8,
                          mode = "multiclass",
                          zeroStart=TRUE,
                          clsWghts = c(1,1,1,1,1),
                          usedDS=FALSE)
metric<-metric$new()
metric$update(pred,target)
metric$compute()
[1] 0.8265532
metric<-luz_metric_precision(nCls=5,
                             smooth=1e-8,
                             mode = "multiclass",
                             zeroStart=TRUE,
                             clsWghts = c(1,1,1,1,1),
                             usedDS=FALSE)
metric<-metric$new()
metric$update(pred,target)
metric$compute()
[1] 0.8210625

Again, you will see more demonstrations of these assessment metrics within the example training loops.