Note: This blog post is still a rough draft. Read on with caution.
In the previous post, we showed that training the last layer of an infinitely wide 1-hidden layer neural network is equivalent to kernel regression with the NNGP kernel. We will generalize this to training all layers of an infinitely wide neural network by introducing the neural tangent kernel.
We proceed in 3 steps:
- Approximate a neural network via linearization.
- Show that training the linearization is equivalent to kernel regression using the neural tangent kernel (NTK).
- The previous two steps work with the linearization and hold without taking any width limit. We will then show that as the number of hidden units \(k\to\infty\), training the linearization approximates training all layers of the original neural network.
The upshot of all this will be our main theorem: training an infinitely wide Bayesian neural network via gradient descent is equivalent to kernel regression with the NTK.
linearization
In the previous post, we typically considered our neural networks as functions \(f:\mathbf{R}^d\to\mathbf{R}\) in some function space, sending input vectors to outputs. One example is the 1-hidden layer neural network
\[ f(x) = \frac{1}{\sqrt{k}}A\phi(Bx) \]
where \(k\) is the number of hidden units in the hidden layer, \(A\in\mathbf{R}^{1\times k}, B\in\mathbf{R}^{k\times d}\) are weight matrices, and \(\phi\) is an elementwise (nonlinear) activation function.
This was fine before, because we only trained the last-layer weights \(A\) while keeping the rest of the network fixed, so the network was simply a linear regressor in a fixed feature space. But now we want to treat the networks as trainable objects, which requires us to perform gradient descent in weight space. Hence in our representation of \(f\) we include the parameters explicitly, \(f:\mathbf{R}^p\times\mathbf{R}^d\to\mathbf{R}\), where
\[ f(w; x) = \frac{1}{\sqrt{k}}A\phi(Bx) \]
Here \(w\) is shorthand for all the parameters of \(f\) above, which in concatenated form is
\[ w = \begin{bmatrix} A_{11} & A_{12} & \cdots & A_{1k} & B_{11} & B_{12} & \cdots & B_{1d} & B_{21} & \cdots & B_{kd} \end{bmatrix} \in \mathbf{R}^{k+kd} \]
Since we want to focus on the behavior of a neural network in a local neighborhood of its parameters (for a fixed input), we fix a dataset sample \(x\) and consider the function of the weights \(f_x:\mathbf{R}^p\to\mathbf{R}\) given by \(f_x(w)=f(w;x)\).
We will approximate \(f_x(w)\) around a neighborhood of its initialization \(w^{(0)}\) by performing a first-order Taylor series expansion:
\[ \tilde{f}_x(w) = f(w^{(0)}; x) + \nabla f_x(w^{(0)})^T(w-w^{(0)}) \]
This is called the linearization of the neural network about \(w^{(0)}\). Note that this is a linear model! As a result, training the linearization \(\tilde{f}(w):\mathbf{R}^d\to\mathbf{R}\) given by \(\tilde{f}(w)(x)=\tilde{f}_x(w)\) is equivalent to solving a linear regression problem in the feature space given by applying the feature mapping
\[ \psi: x\mapsto \nabla f_x(w^{(0)}) \]
to the inputs.
As we have seen in the previous post, this is exactly kernel regression using the induced kernel from the feature map
\[ K_\text{NTK}(x,\tilde{x}) = \left\langle \nabla f_x(w^{(0)}), \nabla f_{\tilde{x}}(w^{(0)})\right\rangle \]
This is the neural tangent kernel (NTK). This equivalence between training the linearized neural network and kernel regression with this kernel was first given by this paper by Jacot et al.
In Python, we can naively express this as
import torch

def ntk_kernel(nnet, x1, x2):
    # gradient of the (scalar) network output at x1 with respect to every parameter
    nnet.zero_grad()
    nnet(x1).backward()
    fx1_grads = [param.grad.clone() for param in nnet.parameters()]

    # gradient of the network output at x2 with respect to every parameter
    nnet.zero_grad()
    nnet(x2).backward()
    fx2_grads = [param.grad.clone() for param in nnet.parameters()]

    # K_NTK(x1, x2) = <grad f_{x1}(w), grad f_{x2}(w)>
    ker_val = sum((g1 * g2).sum() for g1, g2 in zip(fx1_grads, fx2_grads))
    return ker_val.item()
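To see it in action, here is a minimal usage sketch; the OneHiddenLayer module, the sizes, and the choice of plain ReLU are illustrative assumptions of mine, not something fixed by the derivation above:

import torch
import torch.nn as nn

class OneHiddenLayer(nn.Module):
    # a 1-hidden-layer network f(x) = (1/sqrt(k)) * A * phi(B x)
    # with all weights drawn iid from N(0, 1) and phi = ReLU
    def __init__(self, d, k):
        super().__init__()
        self.A = nn.Parameter(torch.randn(1, k))
        self.B = nn.Parameter(torch.randn(k, d))
        self.k = k

    def forward(self, x):
        return (self.A @ torch.relu(self.B @ x) / self.k ** 0.5).squeeze()

d, k = 10, 4096
nnet = OneHiddenLayer(d, k)
x1 = torch.randn(d); x1 = x1 / x1.norm()
x2 = torch.randn(d); x2 = x2 / x2.norm()
print(ntk_kernel(nnet, x1, x2))

For large \(k\), repeated random initializations give kernel values that concentrate around a fixed number; making that limit precise is the point of the next section.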
infinite width
While this is nice, the NTK looks like too general an object to be useful. In particular, we would like an effective way to compute it explicitly for different neural networks of interest. As a warmup, we will go through the derivation of the neural tangent kernel for our 1-hidden layer neural network.
Explicitly, for an input \(x\) we can write the function output \(f(w;x)\) as
\[ f(w;x)=\frac{1}{\sqrt{k}}\sum_{i=1}^k A_{1i}\phi(B_{i,:}x) \]
where \(B_{i,:}\) is the \(i^{th}\) row of \(B\in\mathbf{R}^{k\times d}\). Then we want to compute the gradient
\[ \nabla f_x(w^{(0)}) = \begin{bmatrix} \frac{\partial f_x}{\partial A_{11}} & \frac{\partial f_x}{\partial A_{12}} & \cdots & \frac{\partial f_x}{\partial A_{1k}} & \frac{\partial f_x}{\partial B_{11}} & \cdots & \frac{\partial f_x}{\partial B_{kd}} \end{bmatrix}^T \]
Taking derivatives, we see that
\[ \frac{\partial f_x}{\partial A_{1i}} = \frac{1}{\sqrt{k}}\phi(B_{i,:}x) \]
and
\[ \frac{\partial f_x}{\partial B_{ij}} = \frac{1}{\sqrt{k}}A_{1i}\phi'(B_{i,:}x)x_j \]
Hence by writing out the inner product sum, the NTK for this 1-hidden layer neural network is given by
\[ \begin{align*} K_\text{NTK}(x,\tilde{x}) &= \frac{1}{k}\sum_{i=1}^k \phi(B_{i,:}x)\phi(B_{i,:}\tilde{x}) + \frac{1}{k}\sum_{i=1}^k\sum_{j=1}^d A_{1i}^2\phi'(B_{i,:}x)\phi'(B_{i,:}\tilde{x})x_j\tilde{x}_j \\ &= \frac{1}{k}\sum_{i=1}^k \phi(B_{i,:}x)\phi(B_{i,:}\tilde{x}) + \frac{x^T\tilde{x}}{k}\sum_{i=1}^k A_{1i}^2\phi'(B_{i,:}x)\phi'(B_{i,:}\tilde{x}) \end{align*} \]
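As a quick sanity check, this hand-derived formula can be evaluated directly from the weight matrices and compared against the autograd-based ntk_kernel from earlier. A minimal sketch, assuming \(\phi\) is the ReLU and reusing the hypothetical OneHiddenLayer module from above:

import torch

def explicit_ntk(nnet, x1, x2):
    # evaluate the hand-derived 1-hidden-layer NTK formula term by term;
    # assumes phi = ReLU and the OneHiddenLayer parameter layout from above
    A = nnet.A.detach().squeeze()   # shape (k,)
    B = nnet.B.detach()             # shape (k, d)
    k = nnet.k
    u, v = B @ x1, B @ x2           # pre-activations B_{i,:} x and B_{i,:} x~
    first = (torch.relu(u) * torch.relu(v)).sum() / k
    second = (x1 @ x2) * (A ** 2 * (u > 0).float() * (v > 0).float()).sum() / k
    return (first + second).item()

For the same network and inputs, explicit_ntk(nnet, x1, x2) and ntk_kernel(nnet, x1, x2) should agree up to floating-point error.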
Looking back at the expression for \(K_\text{NTK}\), we note an interesting compositional structure: the first term looks like the finite-width NNGP from the previous post! The second term is a correction, coming from the fact that we’re training the first layer weights as well.
This suggests that in the infinite-width limit, we can write the NTK in terms of the dual activations we introduced previously. Recall that if we work with a Bayesian 1-hidden layer neural network whose weights are sampled iid from a unit Gaussian, \(A_{1i}, B_{ij}\sim N(0,1)\), and whose inputs \(x,\tilde{x}\in S^{d-1}\) lie on the unit sphere (i.e. have unit norm), then the NNGP term approaches the dual activation value
\[ \frac{1}{k}\sum_{i=1}^k \phi(B_{i,:}x)\phi(B_{i,:}\tilde{x}) \to \check{\phi}(x^T\tilde{x}) \]
In particular, we can apply the same limiting argument to the NTK above as \(k\to\infty\) (each average over the hidden units converges to its expectation) to get the infinite-width NTK
\[ K_{\text{NTK},\infty}(x, \tilde{x}) = \check{\phi}(x^T\tilde{x}) + \check{\phi}'(x^T\tilde{x})x^T\tilde{x} \]
where \(\check{\phi}\) is the dual activation of \(\phi\) and \(\check{\phi}'\) denotes the dual activation of \(\phi'\). Note that the \(A_{1i}\)’s drop out of the second term because \(\mathbf{E}[A_{1i}^2]=1\). For example, if \(\phi(x) = \sqrt{2}\max(0, x)\) is the normalized ReLU, the NTK of the \(\infty\)-width 1-hidden layer neural network with this activation is given by
\[ K_{\text{NTK},\infty}(x, \tilde{x}) = \frac{1}{\pi}\left(\xi(\pi-\arccos(\xi)) + \sqrt{1-\xi^2}\right) + \frac{\xi}{\pi}(\pi-\arccos(\xi)) \]
where \(\xi=x^T\tilde{x}\) when \(x,\tilde{x}\) both have unit norm.
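To make this concrete, here is a small sketch of the closed form above; the function name ntk_relu_infinite is mine, and unit-norm inputs are assumed:

import numpy as np

def ntk_relu_infinite(x, x_tilde):
    # closed-form infinite-width NTK of a 1-hidden-layer network with the
    # normalized ReLU activation; assumes x and x_tilde have unit norm
    xi = float(np.clip(np.dot(x, x_tilde), -1.0, 1.0))   # guard against round-off
    nngp = (xi * (np.pi - np.arccos(xi)) + np.sqrt(1.0 - xi ** 2)) / np.pi
    correction = xi * (np.pi - np.arccos(xi)) / np.pi
    return nngp + correction

If everything above is right, running the naive ntk_kernel on a very wide network with activation \(\sqrt{2}\max(0,x)\) and \(N(0,1)\) weights should return approximately this value.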
compositionality
We have seen in this situation that the NTK can be written compositionally from the NNGP using the dual activations of \(\phi\) and \(\phi'\). This leads us to believe that we can extend this to general deep (fully-connected) neural networks via a similar inductive strategy.
In this section we derive an expression for the neural tangent kernel of an infinitely wide neural network with \(L\) hidden layers. We write such an \(L\)-hidden layer neural network
\[ f^{(L)}_x:\mathbf{R}^p\to\mathbf{R} \]
in a similar way as the 1-hidden layer one:
\[ f^{(L)}_x(w) = w^{(L+1)}\frac{1}{\sqrt{k_L}}\phi\left(w^{(L)}\frac{1}{\sqrt{k_{L-1}}}\phi\left(\cdots w^{(2)}\frac{1}{\sqrt{k_1}}\phi\left(w^{(1)}x\right)\cdots\right)\right) \]
where \(w^{(i)}\in\mathbf{R}^{k_i\times k_{i-1}}\), \(k_0=d, k_{L+1}=1\) and \(\phi\) is a fixed elementwise (nonlinear) activation function. We can recursively write this as
\[ \begin{cases} h^{(j)}(x)=\frac{1}{\sqrt{k_j}}\phi(w^{(j)}h^{(j-1)}(x)) & \text{for } j > 0\\ h^{(0)}(x) = x \end{cases} \]
with final output \(f^{(L)}_x(w)=w^{(L+1)}h^{(L)}(x)\).
The NTK, given by
\[ K^{(L)}_\text{NTK}(x, \tilde{x}) = \left\langle \nabla f^{(L)}_x(w^{(0)}), \nabla f^{(L)}_{\tilde{x}}(w^{(0)})\right\rangle \]
simplifies dramatically in the \(\infty\)-width limit. The main theorem is as follows:
Suppose the entries of every weight matrix are drawn iid from \(N(0,1)\), \(x,\tilde{x}\) are inputs of unit norm, and the activation \(\phi\) is normalized with respect to the Gaussian 2-norm (so that \(\check{\phi}(1)=1\)). Then as \(k_1,k_2,...,k_L\to\infty\) (in that order), the NNGP \(\Sigma^{(L)}\) and NTK \(K^{(L)}_{\text{NTK},\infty}\) are given recursively by
\[ \Sigma^{(L)}(x,\tilde{x})=\check{\phi}(\Sigma^{(L-1)}(x,\tilde{x})) \]
and
\[ K^{(L)}_{\text{NTK},\infty}(x,\tilde{x})=\Sigma^{(L)}(x,\tilde{x})+\check{\phi}'\left(\Sigma^{(L-1)}(x,\tilde{x})\right)K^{(L-1)}_{\text{NTK},\infty}(x,\tilde{x}) \]
with base cases \(\Sigma^{(0)}(x,\tilde{x}) = x^T\tilde{x}\) and \(K^{(0)}_{\text{NTK},\infty}(x,\tilde{x})=x^T\tilde{x}\).
We prove this by induction: for the base case \((L=1)\), we have a 1-hidden layer neural network and the formulas read
\[ \Sigma^{(1)}(x,\tilde{x})=\check{\phi}(x^T\tilde{x}) \\ K^{(1)}_{\text{NTK},\infty}(x,\tilde{x}) = \check{\phi}(x^T\tilde{x}) + \check{\phi}'(x^T\tilde{x})x^T\tilde{x} \]
But these are exactly the formulas for the \(\infty\)-width NTK and NNGP from before. For the inductive step, we assume the formulas hold for networks with up to \(L-1\) hidden layers and prove them for the \(L\)-hidden layer case.
Starting with the NNGP, by definition
\[ \begin{align*} \Sigma^{(L)}(x,\tilde{x}) &= \lim_{k_L\to\infty}\left\langle h^{(L)}(x), h^{(L)}(\tilde{x})\right\rangle \\ &= \lim_{k_L\to\infty}\frac{1}{k_L}\left\langle \phi\left(w^{(L)}h^{(L-1)}(x)\right), \phi\left(w^{(L)}h^{(L-1)}(\tilde{x})\right)\right\rangle \\ &= \lim_{k_L\to\infty}\frac{1}{k_L}\sum_{i=1}^{k_L} \phi\left(w_{i,:}^{(L)}h^{(L-1)}(x)\right)\phi\left(w_{i,:}^{(L)}h^{(L-1)}(\tilde{x})\right) \\ &= \mathbf{E}_{w\sim N(0,I)}\left[\phi\left(w^Th^{(L-1)}(x)\right)\phi\left(w^Th^{(L-1)}(\tilde{x})\right)\right] \\ &= \mathbf{E}_{(u,v)\sim N(0,\Lambda^{(L-1)})}\left[\phi(u)\phi(v)\right] \end{align*} \]
where
\[ \Lambda^{(L-1)} = \begin{pmatrix} \Sigma^{(L-1)}(x,x) & \Sigma^{(L-1)}(x,\tilde{x}) \\ \Sigma^{(L-1)}(\tilde{x},x) & \Sigma^{(L-1)}(\tilde{x},\tilde{x}) \end{pmatrix} \]
since, conditional on \(h^{(L-1)}\), the pair \((u,v)=\left(w^Th^{(L-1)}(x),\, w^Th^{(L-1)}(\tilde{x})\right)\) is Gaussian with covariance \(\left\langle h^{(L-1)}(x), h^{(L-1)}(\tilde{x})\right\rangle\), which converges to \(\Sigma^{(L-1)}(x,\tilde{x})\) by the inductive hypothesis. As \(\check{\phi}(1)=1\), we have \(\Sigma^{(L-1)}(x,x)=1\) for \(x\in S^{d-1}\), so the last expectation is exactly the dual activation evaluated at \(\Sigma^{(L-1)}(x,\tilde{x})\). Hence
\[ \Sigma^{(L)}(x,\tilde{x}) = \check{\phi}\left(\Sigma^{(L-1)}(x,\tilde{x})\right) \]
as desired. The NTK recursion follows from a similar argument: the gradient with respect to the output weights \(w^{(L+1)}\) contributes the \(\Sigma^{(L)}\) term, while the gradients with respect to the first \(L\) layers pick up a factor of \(\phi'\) at the last hidden layer through the chain rule and contribute the \(\check{\phi}'\left(\Sigma^{(L-1)}\right)K^{(L-1)}_{\text{NTK},\infty}\) term. \(\square\)
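To make the recursion concrete, here is a short sketch that iterates the formulas above for the normalized ReLU, using its closed-form dual activations from the previous section; the function names are mine and unit-norm inputs are assumed:

import numpy as np

def relu_dual(xi):
    # dual activation of the normalized ReLU sqrt(2)*max(0, x)
    return (np.sqrt(1.0 - xi ** 2) + xi * (np.pi - np.arccos(xi))) / np.pi

def relu_deriv_dual(xi):
    # dual activation of the derivative of the normalized ReLU
    return (np.pi - np.arccos(xi)) / np.pi

def deep_ntk(x, x_tilde, L):
    # iterate the NNGP / NTK recursion for an L-hidden-layer network;
    # assumes x and x_tilde have unit norm
    sigma = ntk = float(np.clip(np.dot(x, x_tilde), -1.0, 1.0))
    for _ in range(L):
        ntk = relu_dual(sigma) + relu_deriv_dual(sigma) * ntk   # K^(l)
        sigma = relu_dual(sigma)                                # Sigma^(l)
    return sigma, ntk

For \(L=1\) this reproduces the 1-hidden layer formula from earlier, and deeper networks simply iterate the same two maps.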
slow learning
To complete our program on the neural tangent kernel, we must show that in the infinite-width limit the original neural network stays close to its linearization throughout training. How does one show that a function is linear (at least in a small neighborhood of an initialization)? The usual strategy is to show that the quadratic term in the Taylor expansion (i.e. the Hessian) vanishes.
We won’t prove this statement, instead deferring to this paper for the details.
experiments
This section is just an excuse for me to play around with Google’s neural-tangents library.