If your metric consists of computing a value from the predictions and the labels, and then averaging over all points, use AvgMetric. You only need to define a function that receives two arguments, the predictions and the targets, and returns a single scalar.
def crazy_metric(pred,targ): return (pred>targ).float().mean()
CrazyMetric = AvgMetric(crazy_metric)
learn = synth_learner(metrics=CrazyMetric)
learn.fit(2)
Actually, you don't even need to use AvgMetric explicitly: passing a plain function as a metric is so common that fastai automatically wraps it in AvgMetric for you.
learn = synth_learner(metrics=crazy_metric)
learn.fit(2)
To have full control over all the steps of calculating a metric, inherit from Metric and override three methods: reset, accumulate and value.

- reset is called at the beginning of validation; initialize all the required variables here.
- accumulate is called after every batch; this is where you do the actual calculation of your metric and decide how to accumulate the values across batches.
- value is called at the end of validation; it should return the final value of your already accumulated metric.
class EvenCrazierMetric(Metric):
    def reset(self): self.count,self.total = 0,0
    def accumulate(self, learn):
        bs = find_bs(learn.yb)
        pred,yb = learn.pred, detuplify(learn.yb)
        self.count = self.count*0.2 + 0.8*(pred-yb).float().sum()
        self.total += bs
    @property
    def value(self): return self.count*self.total
learn = synth_learner(metrics=EvenCrazierMetric())
learn.fit(2)
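The metric above is deliberately nonsensical. In practice, accumulate should usually keep a running sum weighted by batch size, so that the final value is a true average over all validation samples even when the last batch is smaller. Here is a minimal standalone sketch of the same reset/accumulate/value pattern using plain Python lists instead of fastai tensors; note that in fastai the real accumulate receives the learn object (as above), so taking preds and targs directly here is a simplification for illustration:

```python
# A standalone sketch of the reset/accumulate/value pattern, using plain
# Python lists so the accumulation logic is easy to follow without fastai.
class MeanAbsError:
    def reset(self):
        # called once at the start of validation
        self.total_err, self.count = 0.0, 0

    def accumulate(self, preds, targs):
        # called once per batch; accumulate the *sum*, not the mean,
        # so batches of unequal size are weighted correctly
        self.total_err += sum(abs(p - t) for p, t in zip(preds, targs))
        self.count += len(preds)

    @property
    def value(self):
        # called once at the end of validation
        return self.total_err / self.count

m = MeanAbsError()
m.reset()
m.accumulate([1.0, 2.0, 3.0], [1.0, 1.0, 1.0])  # errors: 0, 1, 2
m.accumulate([5.0], [3.0])                      # error: 2
print(m.value)  # (0 + 1 + 2 + 2) / 4 = 1.25
```

Accumulating the sum and dividing once at the end is what makes the result independent of how the validation set happens to be split into batches.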
scikit-learn already contains a bunch of useful metrics. With fastai you don't need to re-write all of them: there's a handy function called skm_to_fastai that does the conversion for you. Let's take a look at one that's already defined in the source code:
def HammingLoss(axis=-1, sample_weight=None):
    "Hamming loss for single-label classification problems"
    return skm_to_fastai(skm.hamming_loss, axis=axis, sample_weight=sample_weight)
That's it: just wrap the scikit-learn function and pass the required parameters.
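The wrapper mostly takes care of plumbing (to my understanding, decoding predictions via argmax along the given axis and converting tensors before handing them off); the actual computation is the scikit-learn function itself. Calling skm.hamming_loss directly on plain arrays shows what is being computed underneath (requires scikit-learn):

```python
# The underlying scikit-learn computation: the fraction of labels
# where prediction and target disagree.
import sklearn.metrics as skm

y_true = [0, 1, 1, 2]
y_pred = [0, 1, 0, 2]  # one of the four labels is wrong
print(skm.hamming_loss(y_true, y_pred))  # 0.25
```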
This is how accuracy_multi is defined inside the library. It's just a simple function, and remember that when we pass functions as metrics to our Learner they get automatically converted to AvgMetric.
def accuracy_multi(inp, targ, thresh=0.5, sigmoid=True):
    "Compute accuracy when `inp` and `targ` are the same size."
    inp,targ = flatten_check(inp,targ)
    if sigmoid: inp = inp.sigmoid()
    return ((inp>thresh)==targ.bool()).float().mean()
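To see the thresholding logic on concrete numbers, here is a pure-Python re-trace of the same steps (no fastai or torch; the tensor math is replaced with loops over nested lists, which is an illustration rather than the library's implementation):

```python
# Re-trace of accuracy_multi's logic on a tiny multi-label example:
# apply sigmoid to the raw outputs, threshold at 0.5, compare each
# entry with the binary target, and average over all entries.
import math

def sigmoid(x): return 1 / (1 + math.exp(-x))

# raw model outputs (logits) for 2 samples x 3 labels, and binary targets
inp  = [[2.0, -1.0, 0.5], [-2.0, 3.0, -0.5]]
targ = [[1,   0,    1  ], [0,    1,   1   ]]

thresh = 0.5
hits = [(sigmoid(i) > thresh) == bool(t)
        for row_i, row_t in zip(inp, targ)
        for i, t in zip(row_i, row_t)]
print(sum(hits) / len(hits))  # 5 of 6 entries match: 0.8333...
```

The last entry misses because sigmoid(-0.5) is about 0.38, below the threshold, while its target is 1; every prediction-target pair counts equally in the final average.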