An R implementation of: TabNet: Attentive Interpretable
Tabular Learning (Sercan O. Arik, Tomas
Pfister).
The code in this repository is an R port, using the torch package, of the dreamquark-ai/tabnet
PyTorch implementation.
TabNet is augmented with Coherent
Hierarchical Multi-label Classification Networks (Eleonora Giunchiglia
et al.) for hierarchical outcomes.
You can install the released version from CRAN with:
install.packages("tabnet")
The development version can be installed from GitHub with:
# install.packages("remotes")
remotes::install_github("mlverse/tabnet")
Here we show a binary classification example on the
attrition
dataset, using a recipe to
specify the dataset input.
library(tabnet)
suppressPackageStartupMessages(library(recipes))
library(yardstick)
library(ggplot2)
set.seed(1)
data("attrition", package = "modeldata")
test_idx <- sample.int(nrow(attrition), size = 0.2 * nrow(attrition))
train <- attrition[-test_idx,]
test <- attrition[test_idx,]
rec <- recipe(Attrition ~ ., data = train) %>%
  step_normalize(all_numeric(), -all_outcomes())
fit <- tabnet_fit(rec, train, epochs = 30, valid_split=0.1, learn_rate = 5e-3)
autoplot(fit)
The plot gives you immediate insight into model over-fitting and, if any occurs, shows the model checkpoints available from before the over-fitting.
Keep in mind that regression and multi-class classification are also available, and that you can also specify the dataset through a data.frame or a formula. You will find examples in the package vignettes.
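As a minimal sketch of those alternative interfaces, the block below fits a multi-class classifier through a formula and through an x / y data.frame pair, then a regression through a formula. It uses the `iris` and `mtcars` datasets shipped with R rather than `attrition`, and an arbitrarily small `epochs = 3` to keep the run short; both are illustrative choices, not recommendations. The fits only run when {tabnet} and the torch backend are installed.

```r
set.seed(1)

# Predictors and outcome for the data.frame (x / y) interface
x <- iris[, setdiff(names(iris), "Species")]
y <- iris$Species

if (requireNamespace("tabnet", quietly = TRUE) && torch::torch_is_installed()) {
  # Formula interface, multi-class classification (3 species)
  fit_formula <- tabnet::tabnet_fit(Species ~ ., data = iris, epochs = 3)

  # data.frame interface: predictors and outcome passed separately
  fit_xy <- tabnet::tabnet_fit(x, y, epochs = 3)

  # Regression: a numeric outcome is all that changes
  fit_reg <- tabnet::tabnet_fit(mpg ~ ., data = mtcars, epochs = 3)

  print(head(predict(fit_formula, iris)))   # class predictions
  print(head(predict(fit_reg, mtcars)))     # numeric predictions
}
```

With so few epochs the predictions are not meaningful; the point is only the shape of the calls.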
As the standard predict() method
is used, you can rely
on your usual metric functions to assess model performance. Here we
use {yardstick}:
metrics <- metric_set(accuracy, precision, recall)
cbind(test, predict(fit, test)) %>%
  metrics(Attrition, estimate = .pred_class)
#> # A tibble: 3 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 accuracy binary 0.837
#> 2 precision binary 0.837
#> 3 recall binary 1
cbind(test, predict(fit, test, type = "prob")) %>%
roc_auc(Attrition, .pred_No)
#> # A tibble: 1 × 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 roc_auc binary 0.546
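When reading these numbers, it helps to recompute them by hand. The base-R sketch below uses hypothetical confusion counts (246 true "No" and 48 true "Yes" rows, all predicted as "No"), chosen only because they are consistent with the rounded values printed above; they are an assumption, not the actual test-set counts. It treats "No", the first factor level, as the event, which is the {yardstick} default.

```r
# Hypothetical counts, assumed for illustration only
n_no  <- 246   # rows whose true class is "No"
n_yes <- 48    # rows whose true class is "Yes"

truth    <- factor(rep(c("No", "Yes"), c(n_no, n_yes)), levels = c("No", "Yes"))
estimate <- factor(rep("No", n_no + n_yes), levels = c("No", "Yes"))  # always predicts "No"

cm <- table(estimate, truth)   # rows: prediction, columns: truth
tp <- cm["No", "No"]           # "No" is the event (first level)
fp <- cm["No", "Yes"]
fn <- cm["Yes", "No"]
tn <- cm["Yes", "Yes"]

accuracy  <- (tp + tn) / sum(cm)
precision <- tp / (tp + fp)
recall    <- tp / (tp + fn)

round(c(accuracy = accuracy, precision = precision, recall = recall), 3)
```

A recall of 1 combined with precision equal to accuracy is the pattern produced by a classifier that predicts the majority class for every row, so it is worth inspecting the predicted class distribution before trusting the accuracy alone.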
TabNet offers intrinsic explainability through the visualization of attention maps, either aggregated:
explain <- tabnet_explain(fit, test)
autoplot(explain)
or at each layer through the
type = "steps"
option:
autoplot(explain, type = "steps")
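Beyond the plots, the explanation object can be summarised numerically. The sketch below defines `rank_features()`, a hypothetical helper (not part of {tabnet}) that averages per-observation importances by column and sorts them. It is first applied to a small made-up data frame, then, if {tabnet} and torch are installed, to the `M_explain` element of a `tabnet_explain()` result. The element names `M_explain` and `masks` are taken from the package documentation as understood here and should be checked against your installed version.

```r
# Hypothetical helper: mean importance per predictor, largest first
rank_features <- function(importance) {
  sort(colMeans(as.matrix(importance)), decreasing = TRUE)
}

# Made-up importances for two observations and three predictors
toy <- data.frame(
  age    = c(0.5, 0.7),
  income = c(0.1, 0.2),
  tenure = c(0.4, 0.1)
)
rank_features(toy)

if (requireNamespace("tabnet", quietly = TRUE) && torch::torch_is_installed()) {
  set.seed(1)
  fit_iris <- tabnet::tabnet_fit(Species ~ ., data = iris, epochs = 3)
  ex <- tabnet::tabnet_explain(fit_iris, iris)

  # Aggregated mask: one row per observation, one column per predictor (assumed layout)
  print(rank_features(ex$M_explain))

  # Per-step masks, as plotted by autoplot(explain, type = "steps")
  print(length(ex$masks))
}
```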
When a substantial part of your dataset has no outcome, TabNet offers a self-supervised training step that allows the model to capture the predictors' intrinsic features and interactions ahead of the supervised task.
pretrain <- tabnet_pretrain(rec, train, epochs = 50, valid_split=0.1, learn_rate = 1e-2)
autoplot(pretrain)
The example here is a toy example, as the train
dataset
does actually contain outcomes. The vignette on Self-supervised
training and fine-tuning will give you the complete, correct
workflow step by step.
{tabnet} leverages the masking mechanism to deal with missing data, so you don’t have to remove the entries in your dataset that have missing values in the predictor variables.
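To make this concrete, the sketch below blanks out roughly 10% of the predictor cells in `iris` (an arbitrary rate, chosen for illustration), counts how many rows a complete-case filter would keep, and then, if {tabnet} and torch are installed, fits directly on the incomplete predictors through the x / y interface. The small `epochs = 3` is only there to keep the run short.

```r
set.seed(1)

x_na <- iris[, 1:4]
y    <- iris$Species

# Set about 10% of the predictor cells to NA
holes <- matrix(runif(nrow(x_na) * ncol(x_na)) < 0.1, nrow = nrow(x_na))
x_na[holes] <- NA

# Share of rows that would survive dropping incomplete cases
n_complete <- sum(complete.cases(x_na))
n_complete / nrow(x_na)

if (requireNamespace("tabnet", quietly = TRUE) && torch::torch_is_installed()) {
  # No imputation and no row filtering before the fit
  fit_na <- tabnet::tabnet_fit(x_na, y, epochs = 3)
  print(head(predict(fit_na, x_na)))
}
```

Keeping every row matters most when missingness is spread across many columns, because a complete-case filter then discards a large share of otherwise usable observations.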
| Group | Feature | {tabnet} | dreamquark-ai | fast-tabnet |
|---|---|---|---|---|
| Input format | data-frame | ✅ | ✅ | ✅ |
| | formula | ✅ | | |
| | recipe | ✅ | | |
| | Node | ✅ | | |
| | missings in predictor | ✅ | | |
| Output format | data-frame | ✅ | ✅ | ✅ |
| | workflow | ✅ | | |
| ML Tasks | self-supervised learning | ✅ | ✅ | |
| | classification (binary, multi-class) | ✅ | ✅ | ✅ |
| | regression | ✅ | ✅ | ✅ |
| | multi-outcome | ✅ | ✅ | |
| | hierarchical multi-label classif. | ✅ | | |
| Model management | from / to file | ✅ | ✅ | v |
| | resume from snapshot | ✅ | | |
| | training diagnostic | ✅ | | |
| Interpretability | | ✅ | ✅ | ✅ |
| Performance | | 1 x | 2 - 4 x | |
| Code quality | test coverage | 85% | | |
| | continuous integration | 4 OS including GPU | | |
Alternative TabNet implementation features