from fastai2.basics import *
from fastai2.callback.all import *
from fastcook.utils import *
There is a callback for every step of the training loop.
Check all the available callbacks here or check their definitions in the source code by running the cells below:
Learner.one_batch??
Learner.fit??
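Before diving into real callbacks, it can help to see the idea in miniature. Below is an illustrative toy sketch (not the actual fastai2 loop, which has more events and more machinery) of how a training loop fires named events on registered callbacks; all names here are simplified stand-ins:

```python
# Toy sketch only: a simplified callback-driven training loop.
# The real event sequence lives in Learner.fit / Learner.one_batch.
class ToyLearner:
    def __init__(self, cbs):
        self.cbs = cbs

    def __call__(self, event_name):
        # Fire `event_name` on every callback that defines it
        for cb in self.cbs:
            method = getattr(cb, event_name, None)
            if method is not None: method()

    def fit(self, n_epoch):
        self('begin_fit')
        for epoch in range(n_epoch):
            self('begin_epoch')
            self('begin_batch'); self('after_pred'); self('after_loss'); self('after_batch')
            self('after_epoch')
        self('after_fit')

class Tracer:
    # A callback only implements the events it cares about
    def __init__(self): self.events = []
    def begin_fit(self):   self.events.append('begin_fit')
    def after_epoch(self): self.events.append('after_epoch')

tracer = Tracer()
ToyLearner([tracer]).fit(2)
print(tracer.events)  # ['begin_fit', 'after_epoch', 'after_epoch']
```

The key design point: the loop knows nothing about any specific callback; it just broadcasts event names, and callbacks opt in by defining matching methods.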
To create a callback, you simply inherit from Callback
and define methods named after the events you want to interact with.
class PrintCallback(Callback):
    def begin_fit(self):   print('Beginning fit')
    def after_epoch(self): print('After epoch')
learn = synth_learner(cbs=PrintCallback())
learn.fit(2)
Now, the really cool thing about callbacks is that they have access to the Learner object itself. In the Learner
training loop everything ends up saved as an attribute: the predictions, the loss, the targets, everything. This gives a callback complete power to modify anything you need.
Let's define a custom loss function that receives the standard combination of predictions and targets plus some additional stuff:
def explosive_loss(pred, targ, stuff, **kwargs):
    loss = MSELossFlat()(pred, targ, **kwargs)
    return loss + (1000 if stuff == 'explode' else 0)
And now we create a corresponding callback to inject this additional stuff:
class ExplodingCallback(Callback):
    def after_pred(self):
        stuff = 'stable'
        if random.randint(0, 1): stuff = 'explode'
        self.learn.yb = (*self.yb, stuff)
Note: You need to use self.learn.<stuff> to write an attribute, but only self.<stuff> to read it.
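The reason for this asymmetry is how Python attribute lookup works: a Callback delegates unknown attribute *reads* to its learner, while a plain assignment just creates a new attribute on the callback itself, silently shadowing the learner's value. Here is a minimal plain-Python sketch of the mechanism (fastai2 implements the delegation with its own helper; the class names below are made up):

```python
# Hypothetical minimal model of Callback -> Learner attribute delegation.
class Learn:
    pred = 'learner-pred'

class Cb:
    def __init__(self, learn): self.learn = learn
    def __getattr__(self, name):
        # __getattr__ is only called when `name` is NOT found on the
        # callback itself, so reads fall through to the learner
        return getattr(self.learn, name)

cb = Cb(Learn())
print(cb.pred)        # reads fall through to the learner: 'learner-pred'

cb.pred = 'shadow'    # writes do NOT fall through...
print(cb.pred)        # ...this now reads the callback's own attribute
print(cb.learn.pred)  # ...and the learner's value is unchanged
```

So writing `self.pred = x` inside a callback would leave `learn.pred` untouched, which is why mutations must go through `self.learn`.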
learn = synth_learner(cbs=ExplodingCallback(), loss_func=explosive_loss)
learn.fit(1)
Let's also create a callback that stops training if explosions happen:
class DefuserCallback(Callback):
    def after_loss(self):
        if self.loss > 1000:
            print('The bomb has been defused')
            raise CancelFitException
cbs = [ExplodingCallback(), DefuserCallback()]
learn = synth_learner(cbs=cbs, loss_func=explosive_loss)
learn.fit(1)
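The cancellation works because the fit loop wraps training in a try/except for these control-flow exceptions, so a callback anywhere in the stack can abort cleanly just by raising one. A toy sketch of that mechanism (simplified names, not the real fastai2 control flow):

```python
# Hypothetical sketch of exception-based loop cancellation.
class CancelFit(Exception): pass

def fit(n_epoch, on_epoch):
    done = 0
    try:
        for epoch in range(n_epoch):
            on_epoch(epoch)   # a "callback" that may raise CancelFit
            done += 1
    except CancelFit:
        pass                  # swallow the exception: training stops early
    return done

def defuser(epoch):
    if epoch == 2: raise CancelFit

print(fit(5, defuser))  # stops at epoch 2, so only 2 epochs complete
```

Because the exception is caught at a well-defined level of the loop, everything outside that level (logging, cleanup, returning the learner) still runs normally.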
Let's look at a perfect example that demonstrates the power of callbacks, copied directly from the fastai
source code.
The following callback works together with a model that returns not only its output but also two additional items: the activations of the LSTM pre-dropout and the activations of the LSTM post-dropout. You can read more about this regularization here, in the AR and TAR regularization section.
At this point it's not really important that you understand what the callback is doing, but rather how it's doing it.
Notice that in after_pred
it saves the two extra outputs of the model and keeps only the standard output. This makes the model+callback interaction transparent to the rest of our code, since nothing else expects these two additional items. In after_loss
we use the information we just saved in after_pred
to add the regularization losses to our original loss.
class RNNRegularizer(Callback):
    "`Callback` that adds AR and TAR regularization in RNN training"
    def __init__(self, alpha=0., beta=0.): self.alpha,self.beta = alpha,beta

    def after_pred(self):
        self.raw_out = self.pred[1][-1] if is_listy(self.pred[1]) else self.pred[1]
        self.out     = self.pred[2][-1] if is_listy(self.pred[2]) else self.pred[2]
        self.learn.pred = self.pred[0]

    def after_loss(self):
        if not self.training: return
        if self.alpha != 0.: self.learn.loss += self.alpha * self.out.float().pow(2).mean()
        if self.beta != 0.:
            h = self.raw_out
            if len(h)>1: self.learn.loss += self.beta * (h[:,1:] - h[:,:-1]).float().pow(2).mean()
And this is all we need to add a completely new regularization strategy to our training loop!
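To make the two penalties concrete, here is a toy numeric sketch of what the callback computes, in plain Python with made-up values (the real code operates on tensors): AR is alpha times the mean of the squared post-dropout activations, and TAR is beta times the mean of the squared differences between consecutive timesteps of the raw activations.

```python
# Toy illustration of the AR and TAR penalties on made-up activations.
# AR  = alpha * mean(h_drop ** 2)              (activation regularization)
# TAR = beta  * mean((h[:,1:] - h[:,:-1])**2)  (temporal AR on raw acts)
def mean(xs): return sum(xs) / len(xs)

alpha, beta = 2.0, 1.0
raw  = [[1.0, 3.0, 6.0]]   # pre-dropout activations, one toy sequence
drop = [[1.0, 0.0, 6.0]]   # post-dropout activations (one unit dropped)

# AR penalizes large activations after dropout
ar = alpha * mean([x * x for row in drop for x in row])

# TAR penalizes large jumps between consecutive timesteps
diffs = [row[i + 1] - row[i] for row in raw for i in range(len(row) - 1)]
tar = beta * mean([d * d for d in diffs])

print(ar, tar)
```

Both terms are simply added to the loss in after_loss, so the optimizer trades them off against the task loss exactly like any other penalty.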