First, install the blurr module, which provides the Transformers integration:
reticulate::py_install('git+https://github.com/ohmeow/blurr', pip = TRUE)
Grab the data and keep 1% of it for fast training:
Select multiple outputs/columns:
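The code for these two steps is not shown in this section. As a minimal sketch only, the snippet below illustrates the idea in base R on a hypothetical toy data frame: `toy`, `small`, the random scores, and the 0.5 cut-off used to turn scores into 0/1 labels are all assumptions. Only the six label names (taken from the prediction output at the end) and the `text` column (used by `ColReader('text')`) come from this document.

```r
# Hypothetical toy data standing in for the real data set.
set.seed(1)
n <- 200
lbl_cols <- c("severe_toxicity", "obscene", "threat",
              "insult", "identity_attack", "sexual_explicit")
toy <- data.frame(text = paste("comment", seq_len(n)), stringsAsFactors = FALSE)
for (col in lbl_cols) toy[[col]] <- runif(n)  # fractional scores in [0, 1]

# Keep 1% of the rows for fast training.
keep <- sample(nrow(toy), size = max(1, round(0.01 * nrow(toy))))

# Select the text column plus the multiple output columns.
small <- toy[keep, c("text", lbl_cols)]

# Assumed step: binarise the scores so each label column is 0 or 1,
# which is the form an already-encoded multi-label target takes.
small[lbl_cols] <- lapply(small[lbl_cols], function(x) as.integer(x >= 0.5))

str(small)
```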
Load DistilRoBERTa:
task = HF_TASKS_ALL()$SequenceClassification
pretrained_model_name = "distilroberta-base"
config = AutoConfig()$from_pretrained(pretrained_model_name)
config$num_labels = length(lbl_cols)
c(hf_arch, hf_config, hf_tokenizer, hf_model) %<-% get_hf_objects(
  pretrained_model_name,
  task = task,
  config = config
)
Downloading: 100%|██████████| 899k/899k [00:00<00:00, 961kB/s]
Downloading: 100%|██████████| 456k/456k [00:00<00:00, 597kB/s]
Downloading: 100%|██████████| 331M/331M [03:26<00:00, 1.61MB/s]
Create data blocks:
blocks = list(
  HF_TextBlock(hf_arch = hf_arch, hf_tokenizer = hf_tokenizer),
  MultiCategoryBlock(encoded = TRUE, vocab = lbl_cols)
)
dblock = DataBlock(blocks = blocks,
                   get_x = ColReader('text'),
                   get_y = ColReader(lbl_cols),
                   splitter = RandomSplitter())
dls = dblock %>% dataloaders(df, bs=8)
dls %>% one_batch()
[[1]]
[[1]]$input_ids
tensor([[ 0, 24268, 5257, ..., 1, 1, 1],
[ 0, 287, 4505, ..., 1, 1, 1],
[ 0, 38, 437, ..., 1, 1, 1],
...,
[ 0, 152, 1129, ..., 1, 1, 1],
[ 0, 85, 18, ..., 1, 1, 1],
[ 0, 22014, 31, ..., 1, 1, 1]], device='cuda:0')
[[1]]$attention_mask
tensor([[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0],
...,
[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0]], device='cuda:0')
[[2]]
TensorMultiCategory([[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0.]], device='cuda:0')
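In the batch above, `input_ids` is padded with the value 1 and `attention_mask` is 0 at those same positions. The relationship can be sketched in base R as follows. The two short rows are hypothetical, shortened stand-ins for the truncated tensors, and treating 1 as the padding id is an assumption read off the printed output.

```r
# Assumption: 1 is the padding token id, as the trailing 1s in input_ids suggest.
pad_id <- 1L

# Hypothetical shortened rows; the real rows are much longer and elided above.
input_ids <- rbind(
  c(0L, 24268L, 5257L, 1L, 1L, 1L),
  c(0L,   287L, 4505L, 1L, 1L, 1L)
)

# 1 where there is a real token, 0 where the sequence is padded.
attention_mask <- (input_ids != pad_id) * 1L
attention_mask
```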
model = HF_BaseModelWrapper(hf_model)
learn = Learner(dls,
                model,
                opt_func = partial(Adam),
                loss_func = BCEWithLogitsLossFlat(),
                metrics = partial(accuracy_multi(), thresh = 0.2),
                cbs = HF_BaseModelCallback(),
                splitter = hf_splitter())
learn$loss_func$thresh = 0.2
learn$create_opt() # -> will create your layer groups based on your "splitter" function
learn$freeze()
learn %>% summary()
This prints the model summary:
epoch train_loss valid_loss accuracy_multi time
------ ----------- ----------- --------------- ------
HF_BaseModelWrapper (Input shape: 8 x 391)
================================================================
Layer (type) Output Shape Param # Trainable
================================================================
Embedding 8 x 391 x 768 38,603,520 False
________________________________________________________________
Embedding 8 x 391 x 768 394,752 False
________________________________________________________________
Embedding 8 x 391 x 768 768 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 391 x 391 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 3072 2,362,368 False
________________________________________________________________
Linear 8 x 391 x 768 2,360,064 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 391 x 391 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 3072 2,362,368 False
________________________________________________________________
Linear 8 x 391 x 768 2,360,064 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 391 x 391 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 3072 2,362,368 False
________________________________________________________________
Linear 8 x 391 x 768 2,360,064 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 391 x 391 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 3072 2,362,368 False
________________________________________________________________
Linear 8 x 391 x 768 2,360,064 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 391 x 391 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 3072 2,362,368 False
________________________________________________________________
Linear 8 x 391 x 768 2,360,064 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 391 x 391 0 False
________________________________________________________________
Linear 8 x 391 x 768 590,592 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 391 x 3072 2,362,368 False
________________________________________________________________
Linear 8 x 391 x 768 2,360,064 False
________________________________________________________________
LayerNorm 8 x 391 x 768 1,536 True
________________________________________________________________
Dropout 8 x 391 x 768 0 False
________________________________________________________________
Linear 8 x 768 590,592 True
________________________________________________________________
Dropout 8 x 768 0 False
________________________________________________________________
Linear 8 x 6 4,614 True
________________________________________________________________
Total params: 82,123,014
Total trainable params: 615,174
Total non-trainable params: 81,507,840
Optimizer used: functools.partial(<function make_python_function.<locals>.python_function at 0x7fee7e8166a8>)
Loss function: FlattenedLoss of BCEWithLogitsLoss()
Model frozen up to parameter group #2
Callbacks:
- TrainEvalCallback
- Recorder
- ProgressCallback
- HF_BaseModelCallback
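The parameter counts in the summary can be checked by hand. The sketch below recomputes them in base R from the shapes printed above. The embedding row counts are inferred by dividing the reported sizes by 768, and the helper names (`linear`, `layer_norm`, `block`, `head`) are illustrative only.

```r
hidden   <- 768   # hidden size (output width of most layers)
ffn      <- 3072  # feed-forward width
n_lbl    <- 6     # final Linear outputs 8 x 6
n_layers <- 6     # six repeated transformer blocks appear in the summary

linear <- function(n_in, n_out) n_in * n_out + n_out  # weights + biases
layer_norm <- 2 * hidden                              # scale + shift

# Embedding row counts inferred from the reported parameter counts.
vocab_rows    <- 38603520 / hidden
position_rows <- 394752 / hidden

# Three embedding tables (the third has a single row) plus one LayerNorm.
embeddings <- (vocab_rows + position_rows + 1) * hidden + layer_norm

# One transformer block: four attention projections, two feed-forward
# layers, and two LayerNorms.
block <- 4 * linear(hidden, hidden) +
  linear(hidden, ffn) + linear(ffn, hidden) +
  2 * layer_norm

# Classification head: the last two Linear layers in the summary.
head <- linear(hidden, hidden) + linear(hidden, n_lbl)

total     <- embeddings + n_layers * block + head
trainable <- (1 + 2 * n_layers) * layer_norm + head

format(c(total = total, trainable = trainable,
         frozen = total - trainable), big.mark = ",")
```

The trainable figure is the two head layers plus the thirteen LayerNorm layers, which are the rows marked `True` in the table.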
Finally, fit the model:
epoch train_loss valid_loss accuracy_multi time
------ ----------- ----------- --------------- ------
0 0.040617 0.034286 0.993257 01:21
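To help read the `accuracy_multi` column, here is a minimal base R sketch of a thresholded multi-label accuracy: apply a sigmoid to the logits, compare with the threshold, and take the share of label positions that match the targets. The logits and targets are made-up values, and the function below is an illustration of the idea rather than the library's own code.

```r
# Illustrative re-implementation; not the fastai function itself.
accuracy_multi_sketch <- function(logits, targets, thresh = 0.5) {
  probs <- plogis(logits)            # element-wise sigmoid
  mean((probs > thresh) == (targets == 1))
}

# Hypothetical logits for 2 samples x 3 labels, with 0/1 targets.
logits  <- rbind(c(-3, -2, 0.5),
                 c(-4,  1, -5))
targets <- rbind(c(0, 0, 1),
                 c(0, 0, 0))

acc <- accuracy_multi_sketch(logits, targets, thresh = 0.2)
acc
```

Every label position counts separately, so when most targets are 0 (as in the batch printed earlier) a model that predicts few labels can already score very high on this metric.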
Predict:
learn$loss_func$thresh = 0.02
learn %>% predict("Those damned affluent white people should only eat their own food, like cod cakes and boiled potatoes.
No enchiladas for them!")
$probabilities
severe_toxicity obscene threat insult identity_attack sexual_explicit
1 9.302437e-07 0.004268706 0.0007849637 0.02687055 0.003282947 0.00232468
$labels
[1] "insult"
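The returned label follows from the threshold: "insult" is the only probability above 0.02. A small base R sketch using the probabilities printed above makes this explicit. The strict `>` comparison is an assumption about how the labels are selected.

```r
# Probabilities copied from the prediction output above.
probs <- c(severe_toxicity = 9.302437e-07,
           obscene         = 0.004268706,
           threat          = 0.0007849637,
           insult          = 0.02687055,
           identity_attack = 0.003282947,
           sexual_explicit = 0.00232468)

thresh <- 0.02
labels <- names(probs)[probs > thresh]
labels

# With the training-time threshold of 0.2, no label would be returned.
labels_strict <- names(probs)[probs > 0.2]
length(labels_strict)
```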