Why use F.cross_entropy?

… instead of computing it yourself?
training
Published

January 12, 2023


Apart from needing less code, of course! Another reason: it is safer!

I like to have more control over what I am doing and how, instead of using black boxes. But be warned: you can get burned when computing the negative log likelihood yourself (the same is true for softmax, for example).

See this example:

import torch

logits = torch.tensor([-100.0, -5.0, 2.0, 100.0])
exps = logits.exp()          # exp(100) overflows float32 to inf
probs = exps / exps.sum()    # inf / inf = nan, everything else / inf = 0
probs, probs.sum()
(tensor([0., 0., 0., nan]), tensor(nan))

Makes sense, right? exp(100) ≈ 2.7e43 is far larger than the largest float32 value (≈ 3.4e38), so it overflows to inf, and inf / inf is nan. So if your network misbehaves and produces extreme activations, you have a problem, but…

import torch

logits = torch.tensor([-100.0, -5.0, 2.0, 100.0])

# here we subtract the max value from the logits, so everything is in (-∞, 0]
#----------------------
logits -= logits.max()
#----------------------

exps = logits.exp()          # the largest entry is now exp(0) = 1, so no overflow
probs = exps / exps.sum()
probs, probs.sum()
(tensor([0.0000e+00, 0.0000e+00, 2.7465e-43, 1.0000e+00]), tensor(1.))

This works nicely, and the same max-subtraction trick is what F.cross_entropy relies on internally. Of course, you can always add that max subtraction yourself to safeguard against such cases (or add batchnorm layers to your architecture if you don't want to bother with such cases, at the cost of a little more complexity and state in your model).
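Even with the max subtraction, the smallest probabilities above still underflow to exactly 0, and the log of 0 is -inf. F.cross_entropy sidesteps this by staying in log space: it is equivalent to log_softmax followed by nll_loss, and log_softmax uses the same max-subtraction (log-sum-exp) trick. Here is a minimal sketch comparing the naive route, a hand-rolled log-space route, and the built-in. It assumes the true class is index 0, picked deliberately so that it hits the underflowed probability. The batch shape and the variable names are illustrative only.

```python
import torch
import torch.nn.functional as F

# Same extreme logits as above, as a batch of one sample: shape [1, 4].
logits = torch.tensor([[-100.0, -5.0, 2.0, 100.0]])
t = 0                       # hypothetical true class: the least likely one
target = torch.tensor([t])  # F.cross_entropy expects integer class indices

# Naive route, even with max subtraction: probabilities first, then -log(p).
shifted = logits - logits.max(dim=1, keepdim=True).values
probs = shifted.exp() / shifted.exp().sum(dim=1, keepdim=True)
naive_nll = -probs[0, t].log()  # probs[0, 0] underflowed to 0.0, so this is inf

# Log-space route (log-sum-exp): log p_i = (x_i - m) - log(sum_j exp(x_j - m)).
log_probs = shifted - shifted.exp().sum(dim=1, keepdim=True).log()
manual_nll = -log_probs[0, t]

# Built-in: takes the raw logits directly, no manual shifting needed.
builtin_nll = F.cross_entropy(logits, target)

print(naive_nll)    # tensor(inf)
print(manual_nll)   # ~200
print(builtin_nll)  # ~200, matches the log-space route
```

The naive loss is infinite even though the true loss is a perfectly representable 200. That is the "getting burned" case.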

Plus, I am sure there are also good computational-efficiency reasons to use torch's built-in method for this.
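The sketch below does not measure speed. It only checks two properties on a hypothetical random batch (8 samples, 5 classes, logits scaled up to be extreme). First, F.cross_entropy gives the same value as the two-step F.nll_loss(F.log_softmax(...)). Second, with the default mean reduction, its gradient with respect to the logits is the compact expression (softmax(logits) − one_hot(targets)) / N. Both are cheap to compute and numerically tame, which is at least part of why a single built-in call is attractive.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical batch: 8 samples, 5 classes, with large-magnitude logits.
logits = (torch.randn(8, 5) * 50).requires_grad_()
targets = torch.randint(0, 5, (8,))

# Built-in, one call (default reduction="mean").
loss = F.cross_entropy(logits, targets)

# The same loss spelled out as two log-space steps.
loss_two_step = F.nll_loss(F.log_softmax(logits, dim=1), targets)

# Autograd gradient of the mean cross entropy w.r.t. the logits...
loss.backward()
# ...compared with the closed form (softmax - one_hot) / batch_size.
expected_grad = (F.softmax(logits.detach(), dim=1)
                 - F.one_hot(targets, num_classes=5).float()) / logits.shape[0]

print(torch.allclose(loss.detach(), loss_two_step.detach()))  # True
print(torch.allclose(logits.grad, expected_grad, atol=1e-6))  # True
```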