* We apply the **chain rule** over each layer
* $\frac{\partial\ell}{\partial f_1} = 2(f_1 - y)$
* $\frac{\partial\ell}{\partial\,\mathrm{relu}} = \frac{\partial f_1}{\partial\,\mathrm{relu}} \frac{\partial\ell}{\partial f_1}$
* $\frac{\partial\ell}{\partial f_0} = \frac{\partial\,\mathrm{relu}}{\partial f_0}\frac{\partial f_1}{\partial\,\mathrm{relu}} \frac{\partial\ell}{\partial f_1}$
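The chain-rule products above can be checked numerically. Below is a minimal sketch in plain Python, assuming a scalar network with $f_0 = w_0 x + b_0$, $\mathrm{relu} = \max(f_0, 0)$, $f_1 = w_1 \cdot \mathrm{relu} + b_1$ and $\ell = (f_1 - y)^2$; the parameter values are hypothetical. It multiplies the local derivatives by hand, then compares $\frac{\partial\ell}{\partial w_0}$ against a central finite difference.

```python
def relu(z):
    return max(z, 0.0)

def forward(w0, b0, w1, b1, x, y):
    # Keep every intermediate: the backward pass needs f0, r and f1
    f0 = w0 * x + b0
    r = relu(f0)
    f1 = w1 * r + b1
    loss = (f1 - y) ** 2
    return f0, r, f1, loss

def backward(w0, b0, w1, b1, x, y):
    f0, r, f1, _ = forward(w0, b0, w1, b1, x, y)
    dl_df1 = 2 * (f1 - y)                          # dl/df1
    dl_drelu = w1 * dl_df1                         # df1/drelu * dl/df1
    dl_df0 = (1.0 if f0 > 0 else 0.0) * dl_drelu   # drelu/df0 * dl/drelu
    dl_dw0 = x * dl_df0                            # df0/dw0 * dl/df0
    return dl_df1, dl_drelu, dl_df0, dl_dw0

# Hypothetical parameter and data values, chosen so the ReLU is active
w0, b0, w1, b1, x, y = 0.5, 0.1, -1.5, 0.2, 2.0, 1.0
grads = backward(w0, b0, w1, b1, x, y)

# Independent check: central finite difference with respect to w0
eps = 1e-6
loss_plus = forward(w0 + eps, b0, w1, b1, x, y)[3]
loss_minus = forward(w0 - eps, b0, w1, b1, x, y)[3]
numeric_dw0 = (loss_plus - loss_minus) / (2 * eps)

print("chain rule  dl/dw0:", grads[3])
print("finite diff dl/dw0:", numeric_dw0)
```

Both printed values should agree to several decimal places; the finite difference costs two extra forward passes per parameter, which is why backpropagation is used instead.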
---
## Backpropagation Costs
* We must store all intermediate outputs during the forward pass, since the backward pass needs them
* We also hope that a nonzero gradient connects the early layers to the loss
* But this doesn't always happen!
* Think of the zero-gradient region of a ReLU (all negative inputs)
* Add in numerical stability problems, and training a deep network is tough!
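A minimal sketch of both failure modes, assuming a hypothetical stack of scalar layers $h \leftarrow \mathrm{relu}(w h)$ with made-up weights. If any pre-activation is not positive, that ReLU's local derivative is 0 and the chain-rule product wipes out the gradient for everything before it. Even when every ReLU is active, many factors below 1 multiply into a vanishingly small gradient.

```python
def relu_chain_grad(x, weights):
    """Forward through h = relu(w * h) layers, then return (output, d output / d x)."""
    h = x
    local_grads = []
    for w in weights:
        pre = w * h
        # d relu(w * h) / d h is w when the ReLU is active, else 0
        # (a pre-activation of exactly 0 is treated as inactive here)
        local_grads.append(w if pre > 0 else 0.0)
        h = max(pre, 0.0)
    # Chain rule: multiply all the local derivatives together
    grad = 1.0
    for g in local_grads:
        grad *= g
    return h, grad

# Hypothetical weights: all positive, so every ReLU stays active
out_ok, grad_ok = relu_chain_grad(1.0, [0.9, 1.1, 0.8, 1.2])
# One negative weight pushes a pre-activation below 0: the gradient dies
out_dead, grad_dead = relu_chain_grad(1.0, [0.9, -1.1, 0.8, 1.2])
# Thirty active layers with small weights: the gradient shrinks to about 1e-9
out_small, grad_small = relu_chain_grad(1.0, [0.5] * 30)

print("active:   ", out_ok, grad_ok)
print("dead ReLU:", out_dead, grad_dead)
print("deep:     ", out_small, grad_small)
```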
---
## Learning Details
* To demonstrate problems (and solutions!) with training, we should use real data
* But right now, we don't really know how to train a model
* Not everything is mean squared error curve fitting
* So let's spend a lecture on loss functions
---
## Loss Functions
* The loss function determines what our network actually models
* So what is MSE doing?
* Fitting our data, right?
* Let's take an example where our training data is contradictory
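What MSE is doing becomes concrete with contradictory targets. A minimal sketch, assuming two hypothetical targets observed at the same input (for example $\sin(x)$ and $-\sin(x)$ at a point where $\sin(x) = 0.8$): scanning constant predictions $c$ shows that the squared error is smallest at the average of the targets, so an MSE-trained model lands between them.

```python
def mse(c, targets):
    # Mean squared error of a single constant prediction c
    return sum((c - t) ** 2 for t in targets) / len(targets)

# Two contradictory (hypothetical) targets for the same input
targets = [0.8, -0.8]

# Scan candidate predictions in [-1, 1] on a grid of step 0.01
candidates = [i / 100 for i in range(-100, 101)]
best_c = min(candidates, key=lambda c: mse(c, targets))

print("best constant prediction:", best_c)
print("mean of the targets:     ", sum(targets) / len(targets))
```

Here the loss is $c^2 + 0.64$, minimized at $c = 0$: a prediction that matches neither curve.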
---
## Contradictory Data
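* We'll generate data for two sine curves, and try to learn them both at once
* With MSE, we end up with predictions in between the two curves
* We apply a [softplus](https://docs.pytorch.org/docs/stable/generated/torch.nn.Softplus.html#softplus) function to the predicted variance $\sigma^2$ to keep it positive
* Basically a smooth ReLU: $\mathrm{softplus}(z) = \ln(1 + e^z)$
* We also add a tiny constant, called an epsilon, so that $\sigma^2$ never reaches 0 (which would break $\ln \sigma^2$ and the division by $\sigma^2$)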
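A pure-Python sketch of softplus plus epsilon, for intuition only (PyTorch's own softplus with its default `beta=1` computes the same $\ln(1 + e^z)$). The rewrite $\max(z, 0) + \ln(1 + e^{-|z|})$ is an assumed, common stable form that avoids overflowing `exp` for large inputs. For very negative inputs softplus underflows to exactly 0, which is where the epsilon earns its keep.

```python
import math

def softplus(z):
    # ln(1 + e^z), written as max(z, 0) + ln(1 + e^{-|z|})
    # so that math.exp never overflows for large |z|
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

EPS = 1e-6  # the small stability constant, matching the 1e-6 in the model code

for z in [-800.0, -5.0, 0.0, 5.0, 50.0]:
    var = softplus(z) + EPS
    print(f"z={z:7.1f}  softplus={softplus(z):.6g}  variance={var:.6g}")
```

At $z = -800$ the softplus value is exactly `0.0` in double precision, so without `EPS` the variance would be zero; for large positive $z$ softplus behaves like the identity, just as a ReLU would.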
---
## Code
```python
import torch
import matplotlib.pyplot as plt

my_colors = ['#2E2585', '#337538', '#5DA899', '#94CBEC']
torch.random.manual_seed(10)

# Set up x points
x = torch.linspace(-3, 3, 500)
# Reshape to (N, 1): N samples, each with a single input feature
x = x.reshape(x.size(0), 1).float()
# y values without noise
y_original = torch.sin(x)
# x-dependent noise, higher magnitude near the center
noise_std = 0.1 * (3 - torch.abs(x))
noise = torch.randn_like(x) * noise_std
# Training data
y_noisy = y_original + noise


class NormalModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(1, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 512),
            torch.nn.ReLU()
        )
        self.mean_head = torch.nn.Linear(512, 1)
        self.var_head = torch.nn.Linear(512, 1)

    def forward(self, x):
        # Predict both a mean and a variance
        features = self.layers(x)
        mu = self.mean_head(features)
        # Softplus ensures variance > 0; we add a small epsilon for stability
        sigma_sq = torch.nn.functional.softplus(self.var_head(features)) + 1e-6
        return mu, sigma_sq


def gaussian_nll_loss(mu, sigma_sq, target):
    # Use negative log likelihood for numerical reasons
    # This is from the normal PDF
    # p(x | mu, sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp(-\frac{(x - \mu)^2}{2\sigma^2})
    # Taking the negative log: -1[ -ln((2\pi\sigma^2)^{1/2}) - \frac{(x - \mu)^2}{2\sigma^2} ]
    # = -1[ -\frac{1}{2}ln(2\pi) - \frac{1}{2}ln(\sigma^2) - \frac{(x - \mu)^2}{2\sigma^2} ]
    # = \frac{1}{2}ln(2\pi) + \frac{1}{2}ln(\sigma^2) + \frac{(x - \mu)^2}{2\sigma^2}
    # The first term is a constant and doesn't matter for the loss.
    # Notice that if sigma is a constant as well, we end up with MSE loss
    # (scaled by \frac{1}{2\sigma^2}, plus a constant)
    return (torch.log(sigma_sq) / 2 + (target - mu)**2 / (2 * sigma_sq)).mean()


net = NormalModel()
learning_rate = 0.001
for epoch in range(500):
    # Ensure that there are no gradients stored
    net.zero_grad()
    # This is the forward pass
    mu, variance = net(x)
    # This computes the loss
    loss = gaussian_nll_loss(mu, variance, y_noisy)
    # This computes the gradients (derivatives of the loss w.r.t. each parameter)
    loss.backward()
    # Update the parameters with plain gradient descent (no autograd tracking)
    with torch.no_grad():
        for param in net.parameters():
            param -= param.grad * learning_rate
            # Clear the gradient before the next pass
            param.grad = None

with torch.no_grad():
    mu, variance = net(x)
    # The shaded band is mu +/- 2 standard deviations, so take sqrt of the variance
    std = variance.sqrt()

ax = plt.gca()
ax.scatter(x[:,0], y_noisy[:,0], linestyle='None', marker='o', label="target data", color=my_colors[1], linewidth=3)
ax.fill_between(x[:,0], (mu + 2*std)[:,0], (mu - 2*std)[:,0], alpha=0.5, color=my_colors[2], linewidth=0)
ax.plot(x[:,0], mu[:,0], linestyle='solid', marker=None, label="mu", color=my_colors[2], linewidth=3)
plt.savefig("../figures/05_nll.svg", dpi=2*96)
```
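The derivation in the loss comments can be checked numerically. Below is a minimal sketch in plain Python, assuming hypothetical values for $\mu$, $\sigma^2$ and the target. It compares the three-term NLL with $-\ln$ of the normal PDF from Python's `statistics.NormalDist`, then scans candidate variances for one fixed residual. Setting the derivative of $\frac{1}{2}\ln\sigma^2 + \frac{(x-\mu)^2}{2\sigma^2}$ with respect to $\sigma^2$ to zero gives $\sigma^2 = (x - \mu)^2$, so the variance head is pushed toward the squared error at each point.

```python
import math
from statistics import NormalDist

def nll_full(mu, sigma_sq, target):
    # All three terms of the derivation, constant included
    return (0.5 * math.log(2 * math.pi)
            + 0.5 * math.log(sigma_sq)
            + (target - mu) ** 2 / (2 * sigma_sq))

def nll_training(mu, sigma_sq, target):
    # Per-point version of gaussian_nll_loss: the constant term is dropped
    return math.log(sigma_sq) / 2 + (target - mu) ** 2 / (2 * sigma_sq)

mu, sigma_sq, target = 0.3, 0.25, 1.0  # hypothetical values

# Reference: -log of the normal PDF from the standard library
reference = -math.log(NormalDist(mu, math.sqrt(sigma_sq)).pdf(target))
print("derived NLL:", nll_full(mu, sigma_sq, target))
print("-log pdf:   ", reference)

# For a fixed residual, which variance minimizes the training NLL?
candidates = [i / 100 for i in range(1, 201)]
best_var = min(candidates, key=lambda v: nll_training(mu, v, target))
print("best variance:", best_var, " squared residual:", (target - mu) ** 2)
```

This is why the shaded band widens where the training data is noisier: larger residuals there make a larger $\sigma^2$ cheaper under the NLL.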
---
## Result