
Wide neural networks of any depth evolve as linear models under gradient descent*


Published 21 December 2020 © 2020 IOP Publishing Ltd and SISSA Medialab srl
Citation: Jaehoon Lee et al J. Stat. Mech. (2020) 124002, DOI 10.1088/1742-5468/abc62b

1742-5468/2020/12/124002

Abstract

A longstanding goal in deep learning research has been to precisely characterize training and generalization. However, the often complex loss landscapes of neural networks (NNs) have made a theory of learning dynamics elusive. In this work, we show that for wide NNs the learning dynamics simplify considerably and that, in the infinite width limit, they are governed by a linear model obtained from the first-order Taylor expansion of the network around its initial parameters. Furthermore, mirroring the correspondence between wide Bayesian NNs and Gaussian processes (GPs), gradient-based training of wide NNs with a squared loss produces test set predictions drawn from a GP with a particular compositional kernel. While these theoretical results are only exact in the infinite width limit, we nevertheless find excellent empirical agreement between the predictions of the original network and those of the linearized version even for finite practically-sized networks. This agreement is robust across different architectures, optimization methods, and loss functions.


1. Introduction

Machine learning models based on deep neural networks (NNs) have achieved unprecedented performance across a wide range of tasks (Krizhevsky et al 2012, He et al 2016, Devlin et al 2018). Typically, these models are regarded as complex systems for which many types of theoretical analyses are intractable. Moreover, characterizing the gradient-based training dynamics of these models is challenging owing to the typically high-dimensional non-convex loss surfaces governing the optimization. As is common in the physical sciences, investigating the extreme limits of such systems can often shed light on these hard problems. For NNs, one such limit is that of infinite width, which refers either to the number of hidden units in a fully-connected layer or to the number of channels in a convolutional layer. Under this limit, the output of the network at initialization is a draw from a Gaussian process (GP); moreover, the network output remains governed by a GP after exact Bayesian training using squared loss (Neal 1994, Lee et al 2018, Matthews et al 2018a, Novak et al 2019a, Garriga-Alonso et al 2019). Aside from its theoretical simplicity, the infinite-width limit is also of practical interest as wider networks have been found to generalize better (Lee et al 2018, Novak et al 2019a, Neyshabur et al 2015, Novak et al 2018, Neyshabur et al 2019).

In this work, we explore the learning dynamics of wide NNs under gradient descent and find that the weight-space description of the dynamics becomes surprisingly simple: as the width becomes large, the neural network can be effectively replaced by its first-order Taylor expansion with respect to its parameters at initialization. For this linear model, the dynamics of gradient descent become analytically tractable. While the linearization is only exact in the infinite width limit, we nevertheless find excellent agreement between the predictions of the original network and those of the linearized version even for finite width configurations. The agreement persists across different architectures, optimization methods, and loss functions.

For squared loss, the exact learning dynamics admit a closed-form solution that allows us to characterize the evolution of the predictive distribution in terms of a GP. This result can be thought of as an extension of 'sample-then-optimize' posterior sampling (Matthews et al 2017) to the training of deep NNs. Our empirical simulations confirm that the result accurately models the variation in predictions across an ensemble of finite-width models with different random initializations.

Here we summarize our contributions:

  • Parameter space dynamics: we show that wide network training dynamics in parameter space are equivalent to the training dynamics of a model which is affine in the collection of all network parameters, the weights and biases. This result holds regardless of the choice of loss function. For squared loss, the dynamics admit a closed-form solution as a function of time.
  • Sufficient conditions for linearization: we formally prove that there exists a threshold learning rate ηcritical (see theorem 2.1), such that gradient descent training trajectories with learning rate smaller than ηcritical stay in an $\mathcal{O}\left({n}^{-1/2}\right)$-neighborhood of the trajectory of the linearized network when n, the width of the hidden layers, is sufficiently large.
  • Output distribution dynamics: we formally show that the predictions of a neural network throughout gradient descent training are described by a GP as the width goes to infinity (see theorem 2.2), extending results from Jacot et al (2018). We further derive explicit time-dependent expressions for the evolution of this GP during training. Finally, we provide a novel interpretation of the result. In particular, it offers a quantitative understanding of the mechanism by which gradient descent differs from Bayesian posterior sampling of the parameters: while both methods generate draws from a GP, gradient descent does not generate samples from the posterior of any probabilistic model.
  • Large scale experimental support: we empirically investigate the applicability of the theory in the finite-width setting and find that it gives an accurate characterization of both learning dynamics and posterior function distributions across a variety of conditions, including some practical network architectures such as the wide residual network (Zagoruyko and Komodakis 2016).
  • Parameterization independence: we note that the linearization result holds in both the standard and the NTK parameterization (defined in section 2.1), while previous work assumed the latter, emphasizing that the effect is due to the increase in width rather than to the particular parameterization.
  • Analytic ReLU and erf neural tangent kernels (NTKs): we compute the analytic NTK corresponding to fully-connected networks with ReLU or erf nonlinearities.
  • Source code: example code investigating both function space and parameter space linearized learning dynamics described in this work is released as open source code within Novak et al (2019b) 2 . We also provide accompanying interactive Colab notebooks for both parameter space 3 and function space 4 linearization.

1.1. Related work

We build on recent work by Jacot et al (2018) that characterizes the exact dynamics of network outputs throughout gradient descent training in the infinite width limit. Their results establish that full batch gradient descent in parameter space corresponds to kernel gradient descent in function space with respect to a new kernel, the NTK. We examine what this implies about dynamics in parameter space, where training updates are actually made.

Daniely et al (2016) study the relationship between NNs and kernels at initialization. They bound the difference between the infinite width kernel and the empirical kernel at finite width n, which diminishes as $\mathcal{O}\left(1/\sqrt{n}\right)$. Daniely (2017) uses the same kernel perspective to study stochastic gradient descent (SGD) training of NNs.

Saxe et al (2014) study the training dynamics of deep linear networks, in which the nonlinearities are treated as identity functions. Deep linear networks are linear in their inputs, but not in their parameters. In contrast, we show that the outputs of sufficiently wide NNs are linear in the updates to their parameters during gradient descent, but usually not in their inputs.

Du et al (2019), Allen-Zhu et al (2019, 2018), Zou et al (2019) study the convergence of gradient descent to global minima. They proved that for i.i.d. Gaussian initialization, the parameters of sufficiently wide networks move little from their initial values during SGD. This small motion of the parameters is crucial to the effect we present, where wide NNs behave linearly in terms of their parameters throughout training.

Mei et al (2018), Chizat and Bach (2018), Rotskoff and Vanden-Eijnden (2018), Sirignano and Spiliopoulos (2018) analyze the mean field SGD dynamics of training NNs in the large-width limit. Their mean field analysis describes distributional dynamics of network parameters via a PDE. However, their analysis is restricted to one hidden layer networks with a scaling limit $\left(1/n\right)$ different from ours $\left(1/\sqrt{n}\right)$, which is commonly used in modern networks (He et al 2016, Glorot and Bengio 2010).

Chizat et al (2018) 5 argued that infinite width networks are in the 'lazy training' regime and may be too simple to be applicable to realistic NNs. Nonetheless, our experiments show that the theory gives an accurate characterization of both the learning dynamics and the posterior function distributions of finite-width networks across a variety of conditions, including practical architectures such as the wide residual network (Zagoruyko and Komodakis 2016).

2. Theoretical results

2.1. Notation and setup for architecture and training dynamics

Let $\mathcal{D}\subseteq {\mathbb{R}}^{{n}_{0}}{\times}{\mathbb{R}}^{k}$ denote the training set and $\mathcal{X}=\left\{x:\left(x,y\right)\in \mathcal{D}\right\}$ and $\mathcal{Y}=\left\{y:\left(x,y\right)\in \mathcal{D}\right\}$ denote the inputs and labels, respectively. Consider a fully-connected feed-forward network with L hidden layers with widths nl , for l = 1,..., L and a readout layer with nL+1 = k. For each $x\in {\mathbb{R}}^{{n}_{0}}$, we use ${h}^{l}\left(x\right),{x}^{l}\left(x\right)\in {\mathbb{R}}^{{n}_{l}}$ to represent the pre- and post-activation functions at layer l with input x. The recurrence relation for a feed-forward network is defined as

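$$\begin{cases} h^{l+1} = x^{l}W^{l+1} + b^{l+1} \\ x^{l+1} = \phi\left(h^{l+1}\right) \end{cases} \quad \text{and} \quad \begin{cases} W^{l+1}_{i,j} = \dfrac{\sigma_\omega}{\sqrt{n_l}}\,\omega^{l+1}_{ij} \\ b^{l+1}_{j} = \sigma_b\,\beta^{l+1}_{j} \end{cases} \tag{1}$$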

where ϕ is a point-wise activation function, ${W}^{l+1}\in {\mathbb{R}}^{{n}_{l}{\times}{n}_{l+1}}$ and ${b}^{l+1}\in {\mathbb{R}}^{{n}_{l+1}}$ are the weights and biases, ${\omega }_{ij}^{l}$ and ${\beta }_{j}^{l}$ are the trainable variables, drawn i.i.d. from a standard Gaussian ${\omega }_{ij}^{l},{\beta }_{j}^{l}\sim \mathcal{N}\left(0,1\right)$ at initialization, and ${\sigma }_{\omega }^{2}$ and ${\sigma }_{b}^{2}$ are the weight and bias variances. Note that this parametrization is non-standard, and we will refer to it as the NTK parameterization. It has already been adopted in several recent works (van Laarhoven 2017, Karras et al 2018, Jacot et al 2018, Du et al 2019, Park et al 2019). Unlike the standard parameterization that only normalizes the forward dynamics of the network, the NTK-parameterization also normalizes its backward dynamics. We note that the predictions and training dynamics of NTK-parameterized networks are identical to those of standard networks, up to a width-dependent scaling factor in the learning rate for each parameter tensor. As we derive, and support experimentally, in the supplementary material (SM) (https://stacks.iop/JSTAT/2020/124002/mmedia) sections F and G, our results (linearity in weights, GP predictions) also hold for networks with a standard parameterization.
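To make equation (1) concrete, here is a minimal NumPy sketch of the forward pass in NTK parameterization. It assumes tanh activations and the illustrative values $\sigma_\omega^2 = 1.5$ and $\sigma_b^2 = 0.1$; the function names are hypothetical and not taken from the released code. It checks that folding the factor $\sigma_\omega/\sqrt{n_l}$ into the weights, i.e. switching to the standard parameterization, leaves the outputs at initialization unchanged.

```python
import numpy as np

SIGMA_W, SIGMA_B = np.sqrt(1.5), np.sqrt(0.1)  # illustrative, not prescribed by the paper


def init_params(rng, widths):
    """Draw omega^l, beta^l ~ N(0, 1); widths = [n_0, n_1, ..., n_L, k]."""
    return [(rng.standard_normal((n_in, n_out)), rng.standard_normal(n_out))
            for n_in, n_out in zip(widths[:-1], widths[1:])]


def forward_ntk(params, x, phi=np.tanh):
    """Equation (1): h^{l+1} = x^l W^{l+1} + b^{l+1} with W = sigma_w / sqrt(n_l) * omega
    and b = sigma_b * beta; the readout h^{L+1} is returned without a nonlinearity."""
    act = x
    for depth, (omega, beta) in enumerate(params):
        h = act @ (SIGMA_W / np.sqrt(omega.shape[0]) * omega) + SIGMA_B * beta
        act = h if depth == len(params) - 1 else phi(h)
    return act


def to_standard(params):
    """Fold the width-dependent scale into the weights (standard parameterization)."""
    return [(SIGMA_W / np.sqrt(omega.shape[0]) * omega, SIGMA_B * beta)
            for omega, beta in params]


def forward_standard(params, x, phi=np.tanh):
    act = x
    for depth, (w, b) in enumerate(params):
        h = act @ w + b
        act = h if depth == len(params) - 1 else phi(h)
    return act


rng = np.random.default_rng(0)
params = init_params(rng, [10, 512, 512, 1])   # L = 2 hidden layers, k = 1
x = rng.standard_normal((5, 10))
f_ntk = forward_ntk(params, x)
f_std = forward_standard(to_standard(params), x)
```

The two parameterizations differ only in their gradients. Because $W^{l+1} = \sigma_\omega\,\omega^{l+1}/\sqrt{n_l}$, a gradient step of size η on ω changes W by $\eta\,\sigma_\omega^2/n_l$ times the standard-parameterization gradient (and a step on β changes b by $\eta\,\sigma_b^2$ times its gradient), which is the per-tensor, width-dependent learning-rate factor mentioned above.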

We define ${\theta }^{l}\equiv \mathrm{vec}\left(\left\{{W}^{l},{b}^{l}\right\}\right)$, the $\left(\left({n}_{l-1}+1\right){n}_{l}\right){\times}1$ vector of all parameters for layer l. $\theta =\mathrm{vec}\left({\cup }_{l=1}^{L+1}{\theta }^{l}\right)$ then indicates the vector of all network parameters, with similar definitions for ${\theta }^{\leqslant l}$ and ${\theta }^{ > l}$. Denote by θt the time-dependence of the parameters and by θ0 their initial values. We use ${f}_{t}\left(x\right)\equiv {h}^{L+1}\left(x\right)\in {\mathbb{R}}^{k}$ to denote the output (or logits) of the neural network at time t. Let $\ell \left(\hat{y},y\right):{\mathbb{R}}^{k}{\times}{\mathbb{R}}^{k}\to \mathbb{R}$ denote the loss function where the first argument is the prediction and the second argument the true label. In supervised learning, one is interested in learning a θ that minimizes the empirical loss 6 , $\mathcal{L}={\sum }_{\left(x,y\right)\in \mathcal{D}}\ell \left({f}_{t}\left(x,\theta \right),y\right)$.

Let η be the learning rate 7 . Via continuous time gradient descent, the evolution of the parameters θ and the logits f can be written as

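$$\dot{\theta}_t = -\eta\,\nabla_\theta f_t(\mathcal{X})^{T}\,\nabla_{f_t(\mathcal{X})}\mathcal{L} \tag{2}$$

$$\dot{f}_t(\mathcal{X}) = \nabla_\theta f_t(\mathcal{X})\,\dot{\theta}_t = -\eta\,\hat{\Theta}_t(\mathcal{X},\mathcal{X})\,\nabla_{f_t(\mathcal{X})}\mathcal{L} \tag{3}$$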

where ${f}_{t}\left(\mathcal{X}\right)=\mathrm{v}\mathrm{e}\mathrm{c}\left({\left[{f}_{t}\left(x\right)\right]}_{x\in \mathcal{X}}\right)$, the $k\vert \mathcal{D}\vert {\times}1$ vector of concatenated logits for all examples, and ${\nabla }_{{f}_{t}\left(\mathcal{X}\right)}\mathcal{L}$ is the gradient of the loss with respect to the model's output, ${f}_{t}\left(\mathcal{X}\right)$. ${\hat{{\Theta}}}_{t}\equiv {\hat{{\Theta}}}_{t}\left(\mathcal{X},\mathcal{X}\right)$ is the tangent kernel at time t, which is a $k\vert \mathcal{D}\vert {\times}k\vert \mathcal{D}\vert $ matrix

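$$\hat{\Theta}_t = \nabla_\theta f_t(\mathcal{X})\,\nabla_\theta f_t(\mathcal{X})^{T} = \sum_{l=1}^{L+1}\nabla_{\theta^l} f_t(\mathcal{X})\,\nabla_{\theta^l} f_t(\mathcal{X})^{T} \tag{4}$$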

One can define the tangent kernel for general arguments, e.g. ${\hat{{\Theta}}}_{t}\left(x,\mathcal{X}\right)$ where x is a test input. At finite width, $\hat{{\Theta}}$ depends on the specific random draw of the parameters, and in this context we refer to it as the empirical tangent kernel.
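The empirical tangent kernel of equation (4) can be computed directly from the Jacobian of the outputs. The NumPy sketch below does this for k = 1 and a one-hidden-layer tanh network in NTK parameterization, with the Jacobian written out by hand. The architecture, hyperparameters, and flat parameter layout are illustrative assumptions; in practice an automatic-differentiation library would supply the Jacobian.

```python
import numpy as np

SIGMA_W, SIGMA_B = np.sqrt(1.5), np.sqrt(0.1)  # illustrative hyperparameters


def unpack(theta, n0, n):
    """Split flat theta into (omega1, beta1, omega2, beta2); the layout is an assumption."""
    omega1 = theta[:n0 * n].reshape(n0, n)
    beta1 = theta[n0 * n:n0 * n + n]
    omega2 = theta[n0 * n + n:n0 * n + 2 * n]
    return omega1, beta1, omega2, theta[-1]


def f(theta, x, n):
    """One-hidden-layer tanh network in NTK parameterization with scalar output."""
    n0 = x.shape[1]
    omega1, beta1, omega2, beta2 = unpack(theta, n0, n)
    h1 = SIGMA_W / np.sqrt(n0) * x @ omega1 + SIGMA_B * beta1
    return SIGMA_W / np.sqrt(n) * np.tanh(h1) @ omega2 + SIGMA_B * beta2


def jacobian(theta, x, n):
    """Analytic grad_theta f(x) for every row of x, shape (len(x), P)."""
    num, n0 = x.shape
    omega1, beta1, omega2, _ = unpack(theta, n0, n)
    act = np.tanh(SIGMA_W / np.sqrt(n0) * x @ omega1 + SIGMA_B * beta1)
    delta = SIGMA_W / np.sqrt(n) * omega2 * (1.0 - act ** 2)   # df / dh1
    d_omega1 = SIGMA_W / np.sqrt(n0) * x[:, :, None] * delta[:, None, :]
    return np.concatenate([
        d_omega1.reshape(num, -1),        # first-layer weights
        SIGMA_B * delta,                  # first-layer biases
        SIGMA_W / np.sqrt(n) * act,       # readout weights
        SIGMA_B * np.ones((num, 1)),      # readout bias
    ], axis=1)


def empirical_ntk(theta, x1, x2, n):
    """Equation (4) with k = 1: Theta_hat(x1, x2) = J(x1) J(x2)^T."""
    return jacobian(theta, x1, n) @ jacobian(theta, x2, n).T


rng = np.random.default_rng(0)
n0, n = 3, 64
num_params = n0 * n + 2 * n + 1
theta0 = rng.standard_normal(num_params)
x = rng.standard_normal((4, n0))
ntk = empirical_ntk(theta0, x, x, n)   # 4 x 4 empirical tangent kernel at initialization
```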

The dynamics of discrete gradient descent can be obtained by replacing ${\dot {\theta }}_{t}$ and ${\dot {f}}_{t}\left(\mathcal{X}\right)$ with $\left({\theta }_{i+1}-{\theta }_{i}\right)$ and $\left({f}_{i+1}\left(\mathcal{X}\right)-{f}_{i}\left(\mathcal{X}\right)\right)$ above, and replacing ${\text{e}}^{-\eta {\hat{{\Theta}}}_{0}t}$ with ${\left(1-\eta {\hat{{\Theta}}}_{0}\right)}^{i}$ below, so that $\left(1-{\text{e}}^{-\eta {\hat{{\Theta}}}_{0}t}\right)$ becomes $\left(1-{\left(1-\eta {\hat{{\Theta}}}_{0}\right)}^{i}\right)$.
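For MSE loss, iterating the discrete update $f_{i+1}(\mathcal{X}) = f_i(\mathcal{X}) - \eta\hat{\Theta}_0\,(f_i(\mathcal{X}) - \mathcal{Y})$ gives $f_i(\mathcal{X}) = \mathcal{Y} + (I - \eta\hat{\Theta}_0)^i\,(f_0(\mathcal{X}) - \mathcal{Y})$, so after i steps the continuous-time factor $e^{-\eta\hat{\Theta}_0 t}$ corresponds to $(I - \eta\hat{\Theta}_0)^i$. The NumPy sketch below checks this numerically, assuming k = 1 and a random positive-definite matrix as a stand-in for $\hat{\Theta}_0$.

```python
import numpy as np

rng = np.random.default_rng(1)
m = 6                                            # |D| with k = 1 (illustrative)
a = rng.standard_normal((m, m))
theta_hat = a @ a.T / m + 0.1 * np.eye(m)        # stand-in for a positive-definite Theta_hat_0
eta = 1.0 / np.linalg.eigvalsh(theta_hat).max()
y = rng.standard_normal(m)
f0 = rng.standard_normal(m)


def iterate(f_init, steps):
    """Discrete gradient descent on the linearized outputs; MSE gives grad = f - Y."""
    f = f_init.copy()
    for _ in range(steps):
        f = f - eta * theta_hat @ (f - y)
    return f


def closed_form(f_init, steps):
    """Equation (9) with exp(-eta Theta_hat_0 t) replaced by (I - eta Theta_hat_0)^i."""
    decay = np.linalg.matrix_power(np.eye(m) - eta * theta_hat, steps)
    return (np.eye(m) - decay) @ y + decay @ f_init
```

Both functions agree for any number of steps, and for η below $2/\lambda_{\max}$ the outputs converge to $\mathcal{Y}$.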

2.2. Linearized networks have closed form training dynamics for parameters and outputs

In this section, we consider the training dynamics of the linearized network. Specifically, we replace the outputs of the neural network by their first order Taylor expansion,

$$f_t^{\text{lin}}(x) \equiv f_0(x) + \nabla_\theta f_0(x)\big|_{\theta=\theta_0}\,\omega_t \tag{5}$$

where ${\omega }_{t}\equiv {\theta }_{t}-{\theta }_{0}$ is the change in the parameters from their initial values. Note that ${f}_{t}^{\text{lin}}$ is the sum of two terms: the first term is the initial output of the network, which remains unchanged during training, and the second term captures the change to the initial value during training. The dynamics of gradient flow using this linearized function are governed by,

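$$\dot{\omega}_t = -\eta\,\nabla_\theta f_0(\mathcal{X})^{T}\,\nabla_{f_t^{\text{lin}}(\mathcal{X})}\mathcal{L} \tag{6}$$

$$\dot{f}_t^{\text{lin}}(x) = -\eta\,\hat{\Theta}_0(x,\mathcal{X})\,\nabla_{f_t^{\text{lin}}(\mathcal{X})}\mathcal{L} \tag{7}$$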

As ∇θ f0(x) remains constant throughout training, these dynamics are often quite simple. In the case of an MSE loss, i.e. $\ell \left(\hat{y},y\right)=\frac{1}{2}{\Vert}\hat{y}-y{{\Vert}}_{2}^{2}$, the ODEs have closed form solutions

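$$\omega_t = -\nabla_\theta f_0(\mathcal{X})^{T}\,\hat{\Theta}_0^{-1}\left(I - e^{-\eta\hat{\Theta}_0 t}\right)\left(f_0(\mathcal{X}) - \mathcal{Y}\right) \tag{8}$$

$$f_t^{\text{lin}}(\mathcal{X}) = \left(I - e^{-\eta\hat{\Theta}_0 t}\right)\mathcal{Y} + e^{-\eta\hat{\Theta}_0 t}\,f_0(\mathcal{X}) \tag{9}$$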

For an arbitrary point x, ${f}_{t}^{\text{lin}}\left(x\right)={\mu }_{t}\left(x\right)+{\gamma }_{t}\left(x\right)$, where

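$$\mu_t(x) = \hat{\Theta}_0(x,\mathcal{X})\,\hat{\Theta}_0^{-1}\left(I - e^{-\eta\hat{\Theta}_0 t}\right)\mathcal{Y} \tag{10}$$

$$\gamma_t(x) = f_0(x) - \hat{\Theta}_0(x,\mathcal{X})\,\hat{\Theta}_0^{-1}\left(I - e^{-\eta\hat{\Theta}_0 t}\right)f_0(\mathcal{X}) \tag{11}$$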

Therefore, we can obtain the time evolution of the linearized neural network without running gradient descent. We only need to compute the tangent kernel ${\hat{{\Theta}}}_{0}$ and the outputs f0 at initialization and use equations (8), (10) and (11) to compute the dynamics of the weights and the outputs.
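A sketch of this recipe in NumPy, assuming MSE loss and k = 1, and assuming the Jacobians $\nabla_\theta f_0$ on the training and test inputs are already available as dense arrays (random arrays stand in for them below). The matrix exponential is evaluated through an eigendecomposition of the symmetric kernel; the helper names are hypothetical.

```python
import numpy as np


def expm_neg(theta, eta, t):
    """exp(-eta * theta * t) for a symmetric positive-definite kernel matrix."""
    lam, v = np.linalg.eigh(theta)
    return (v * np.exp(-eta * lam * t)) @ v.T


def linearized_dynamics(j_train, j_test, f0_train, f0_test, y, eta, t):
    """Closed-form gradient flow of the linearized network, MSE loss, k = 1.

    j_train = grad_theta f_0(X), shape (m, P); j_test = grad_theta f_0(x), shape (q, P).
    Returns omega_t (equation (8)) and f_t^lin on the test inputs (equations (10)-(11)).
    """
    ntk_train = j_train @ j_train.T                        # Theta_hat_0(X, X)
    ntk_test = j_test @ j_train.T                          # Theta_hat_0(x, X)
    decay = np.eye(len(y)) - expm_neg(ntk_train, eta, t)   # I - exp(-eta Theta_hat_0 t)
    gain = np.linalg.solve(ntk_train, decay)               # Theta_hat_0^{-1} (I - exp(...))
    omega_t = -j_train.T @ gain @ (f0_train - y)           # equation (8)
    mu_t = ntk_test @ gain @ y                             # equation (10)
    gamma_t = f0_test - ntk_test @ gain @ f0_train         # equation (11)
    return omega_t, mu_t + gamma_t


rng = np.random.default_rng(2)
m, q, num_params = 5, 3, 40                  # random stand-ins for real Jacobians
j_train = rng.standard_normal((m, num_params))
j_test = rng.standard_normal((q, num_params))
f0_train, f0_test = rng.standard_normal(m), rng.standard_normal(q)
y = rng.standard_normal(m)
eta, t = 0.05, 3.0
omega_t, f_test = linearized_dynamics(j_train, j_test, f0_train, f0_test, y, eta, t)
```

Since the kernel is fixed, the state at any time t comes from one eigendecomposition and one linear solve, with no iteration over time steps; the result agrees with numerically integrating equation (6).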

2.3. Infinite width limit yields a GP

As the width of the hidden layers approaches infinity, the central limit theorem (CLT) implies that the outputs at initialization ${\left\{{f}_{0}\left(x\right)\right\}}_{x\in \mathcal{X}}$ converge to a multivariate Gaussian in distribution. Informally, this occurs because the pre-activations at each layer are a sum of Gaussian random variables (the weights and bias), and thus become a Gaussian random variable themselves. See Poole et al (2016), Schoenholz et al (2017), Lee et al (2018), Xiao et al (2018), Yang and Schoenholz (2017) for more details, and Matthews et al (2018b), Novak et al (2019a) for a formal treatment.

Therefore, randomly initialized NNs are in correspondence with a certain class of GPs (hereinafter referred to as NNGPs), which facilitates a fully Bayesian treatment of NNs (Lee et al 2018, Matthews et al 2018a). More precisely, let ${f}_{t}^{i}$ denote the ith output dimension and $\mathcal{K}$ denote the sample-to-sample kernel function (of the pre-activation) of the outputs in the infinite width setting,

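$$\mathcal{K}^{i,j}(x,x') = \lim_{\min(n_1,\ldots,n_L)\to\infty}\mathbb{E}\left[f_0^{i}(x)\,f_0^{j}(x')\right] \tag{12}$$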

then ${f}_{0}\left(\mathcal{X}\right)\sim \mathcal{N}\left(0,\mathcal{K}\left(\mathcal{X},\mathcal{X}\right)\right)$, where ${\mathcal{K}}^{i,j}\left(x,{x}^{\prime }\right)$ denotes the covariance between the ith output of x and jth output of x', which can be computed recursively (see Lee et al 2018, section 2.3 and SM section E). For a test input $x\in {\mathcal{X}}_{\mathrm{T}}$, the joint output distribution $f\left(\left[x,\mathcal{X}\right]\right)$ is also multivariate Gaussian. Conditioning on the training samples 8 , $f\left(\mathcal{X}\right)=\mathcal{Y}$, the distribution of $\left.f\left(x\right)\right\vert \mathcal{X},\mathcal{Y}$ is also a Gaussian $\mathcal{N}\left(\mu \left(x\right),{\Sigma}\left(x\right)\right)$,

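$$\mu(x) = \mathcal{K}(x,\mathcal{X})\,\mathcal{K}^{-1}\mathcal{Y}, \qquad \Sigma(x) = \mathcal{K}(x,x) - \mathcal{K}(x,\mathcal{X})\,\mathcal{K}^{-1}\mathcal{K}(x,\mathcal{X})^{T} \tag{13}$$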

where $\mathcal{K}=\mathcal{K}\left(\mathcal{X},\mathcal{X}\right)$. This is the posterior predictive distribution resulting from exact Bayesian inference in an infinitely wide neural network.

2.3.1. GPs from gradient descent training

If we freeze the variables ${\theta }^{\leqslant L}$ after initialization and only optimize ${\theta }^{L+1}$, the original network and its linearization are identical. Letting the width approach infinity, this particular tangent kernel ${\hat{{\Theta}}}_{0}$ will converge to $\mathcal{K}$ in probability and equation (10) will converge to the posterior equation (13) as $t\to \infty $ (for further details see SM section D). This is a realization of the 'sample-then-optimize' approach for evaluating the posterior of a GP proposed in Matthews et al (2017).
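The 'sample-then-optimize' correspondence can be checked by Monte Carlo on a small random-features model. The NumPy sketch below assumes a frozen random ReLU hidden layer with n = 200 units, a readout $f(x) = \phi(x)\cdot w/\sqrt{n}$ with $w_0 \sim \mathcal{N}(0, I)$ and no bias, and uses the $t\to\infty$ limit of equation (8) instead of running an optimizer. For these fixed features the readout tangent kernel equals the empirical $\hat{\mathcal{K}}$, so the converged predictions should be exact draws from the posterior of equation (13) with $\mathcal{K}$ replaced by $\hat{\mathcal{K}}$; this particular check does not need the infinite-width limit. All names and sizes are illustrative.

```python
import numpy as np


def gp_posterior(k_xx, k_tx, k_tt, y):
    """Equation (13): GP posterior mean and covariance at test inputs given f(X) = Y."""
    mean = k_tx @ np.linalg.solve(k_xx, y)
    cov = k_tt - k_tx @ np.linalg.solve(k_xx, k_tx.T)
    return mean, cov


def sample_then_optimize(phi_train, phi_test, y, num_samples, rng):
    """Fit only the readout w of f(x) = phi(x) . w / sqrt(n), from many draws w0 ~ N(0, I).

    Uses the t -> infinity limit of equation (8): w = w0 - J^T K_hat^{-1} (f0(X) - Y).
    """
    n = phi_train.shape[1]
    j_train, j_test = phi_train / np.sqrt(n), phi_test / np.sqrt(n)  # Jacobians w.r.t. w
    k_hat = j_train @ j_train.T                    # readout tangent kernel = K_hat
    w0 = rng.standard_normal((n, num_samples))     # one column per random initialization
    w_inf = w0 - j_train.T @ np.linalg.solve(k_hat, j_train @ w0 - y[:, None])
    return j_test @ w_inf                          # converged test predictions


def kernel(a, b, n):
    """K_hat for fixed features: phi(a) . phi(b) / n."""
    return a @ b.T / n


rng = np.random.default_rng(3)
n0, n = 4, 200
proj = rng.standard_normal((n0, n))                # frozen random hidden layer
x_train, x_test = rng.standard_normal((5, n0)), rng.standard_normal((3, n0))
phi_train = np.maximum(x_train @ proj / np.sqrt(n0), 0.0)   # ReLU features
phi_test = np.maximum(x_test @ proj / np.sqrt(n0), 0.0)
y = rng.standard_normal(5)

samples = sample_then_optimize(phi_train, phi_test, y, 10_000, rng)
mean, cov = gp_posterior(kernel(phi_train, phi_train, n), kernel(phi_test, phi_train, n),
                         kernel(phi_test, phi_test, n), y)
```

The empirical mean and covariance of `samples` should match `mean` and `cov` up to Monte Carlo error.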

If none of the variables are frozen, in the infinite width setting, ${\hat{{\Theta}}}_{0}$ also converges in probability to a deterministic kernel Θ (Jacot et al 2018, Yang 2019), which we sometimes refer to as the analytic kernel, and which can also be computed recursively (see SM section E). For ReLU and erf nonlinearities, Θ can be computed exactly (SM section C), which we use in section 3. Letting the width go to infinity, for any t, the output ${f}_{t}^{\text{lin}}\left(x\right)$ of the linearized network is also Gaussian distributed because equations (10) and (11) describe an affine transform of the Gaussian $\left[{f}_{0}\left(x\right),{f}_{0}\left(\mathcal{X}\right)\right]$. Therefore

Corollary 1. For every test point $x\in {\mathcal{X}}_{\mathrm{T}}$ and t ⩾ 0, ${f}_{t}^{\text{lin}}\left(x\right)$ converges in distribution, as the width goes to infinity, to a Gaussian with mean and covariance given by 9

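$$\mu(\mathcal{X}_T) = \Theta(\mathcal{X}_T,\mathcal{X})\,\Theta^{-1}\left(I - e^{-\eta\Theta t}\right)\mathcal{Y} \tag{14}$$

$$\Sigma(\mathcal{X}_T,\mathcal{X}_T) = \mathcal{K}(\mathcal{X}_T,\mathcal{X}_T) + \Theta(\mathcal{X}_T,\mathcal{X})\,\Theta^{-1}\left(I - e^{-\eta\Theta t}\right)\mathcal{K}\left(I - e^{-\eta\Theta t}\right)\Theta^{-1}\,\Theta(\mathcal{X},\mathcal{X}_T) - \left(\Theta(\mathcal{X}_T,\mathcal{X})\,\Theta^{-1}\left(I - e^{-\eta\Theta t}\right)\mathcal{K}(\mathcal{X},\mathcal{X}_T) + \text{h.c.}\right) \tag{15}$$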

Therefore, over random initialization, ${\mathrm{lim}}_{t\to \infty }{\mathrm{lim}}_{n\to \infty }\enspace {f}_{t}^{\text{lin}}\left(x\right)$ has distribution

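$$\mathcal{N}\Big(\Theta(\mathcal{X}_T,\mathcal{X})\,\Theta^{-1}\mathcal{Y},\;\; \mathcal{K}(\mathcal{X}_T,\mathcal{X}_T) + \Theta(\mathcal{X}_T,\mathcal{X})\,\Theta^{-1}\mathcal{K}\,\Theta^{-1}\,\Theta(\mathcal{X},\mathcal{X}_T) - \left(\Theta(\mathcal{X}_T,\mathcal{X})\,\Theta^{-1}\mathcal{K}(\mathcal{X},\mathcal{X}_T) + \text{h.c.}\right)\Big) \tag{16}$$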
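Equations (14), (15) and (16) only require the kernels $\mathcal{K}$ and Θ evaluated on the training and test inputs. The NumPy sketch below computes the NTK-GP mean and covariance at any time t, with t = ∞ giving equation (16). The random positive-definite matrices standing in for $\mathcal{K}$ and Θ, and all function names, are illustrative assumptions; "h.c." is taken as the transpose since all matrices are real.

```python
import numpy as np


def expm_neg(theta, eta, t):
    """exp(-eta * theta * t) for a symmetric positive-definite matrix."""
    lam, v = np.linalg.eigh(theta)
    return (v * np.exp(-eta * lam * t)) @ v.T


def ntk_gp(k_tt, k_tx, k_xx, th_tx, th_xx, y, eta, t):
    """Mean and covariance of f_t^lin(X_T) as n -> infinity, equations (14) and (15), k = 1.

    k_* are NNGP blocks K(., .), th_* are NTK blocks Theta(., .); t may be np.inf.
    """
    m = th_xx.shape[0]
    decay = np.eye(m) if np.isinf(t) else np.eye(m) - expm_neg(th_xx, eta, t)
    a = th_tx @ np.linalg.solve(th_xx, decay)   # Theta(X_T, X) Theta^{-1} (I - e^{-eta Theta t})
    cross = a @ k_tx.T                          # a K(X, X_T); "h.c." adds its transpose
    mean = a @ y                                # equation (14)
    cov = k_tt + a @ k_xx @ a.T - (cross + cross.T)   # equation (15)
    return mean, cov


rng = np.random.default_rng(5)
m, q = 6, 2                                     # training and test set sizes
feats_k, feats_th = rng.standard_normal((m + q, 20)), rng.standard_normal((m + q, 20))
big_k = feats_k @ feats_k.T / 20                # stand-in NNGP kernel on [X_T, X]
big_th = feats_th @ feats_th.T / 20 + big_k     # stand-in NTK on [X_T, X]
y = rng.standard_normal(m)


def blocks(full):
    """Split a kernel on [X_T, X] into (test-test, test-train, train-train) blocks."""
    return full[:q, :q], full[:q, q:], full[q:, q:]


k_tt, k_tx, k_xx = blocks(big_k)
_, th_tx, th_xx = blocks(big_th)
mean_inf, cov_inf = ntk_gp(k_tt, k_tx, k_xx, th_tx, th_xx, y, eta=0.1, t=np.inf)
```

Passing $\mathcal{K}$ in place of Θ recovers the NNGP posterior of equation (13) as $t\to\infty$; for a general Θ the t = ∞ covariance differs from it.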

Unlike the case when only ${\theta }^{L+1}$ is optimized, equations (14) and (15) do not admit an interpretation corresponding to the posterior sampling of a probabilistic model 10 . We contrast the predictive distributions from the NNGP, NTK-GP (i.e. equations (14) and (15)) and ensembles of NNs in figure 2.

Infinitely-wide NNs open up ways to study deep NNs both under fully Bayesian training through the GP correspondence, and under GD training through the linearization perspective. The resulting distributions over functions are inconsistent (the distribution resulting from GD training does not generally correspond to a Bayesian posterior). We believe understanding the biases over learned functions induced by different training schemes and architectures is a fascinating avenue for future work.
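The analytic kernels $\mathcal{K}$ and Θ that enter equations (14), (15) and (16) have closed forms for ReLU. For a single hidden layer they follow from the Gaussian expectations $\mathbb{E}[\mathrm{relu}(u)\,\mathrm{relu}(v)]$ and $\mathbb{E}[\mathrm{relu}'(u)\,\mathrm{relu}'(v)]$ for jointly Gaussian (u, v). The NumPy sketch below assumes k = 1, $\sigma_\omega^2 = 2$, $\sigma_b^2 = 0.1$, and one hidden layer (the deeper recursion referred to in section 2.3.1 is not reproduced), and compares the analytic kernels with the empirical kernels of a single random network of width 200 000. Function names are hypothetical.

```python
import numpy as np

SIGMA_W2, SIGMA_B2 = 2.0, 0.1     # illustrative weight and bias variances


def relu_kernels(x1, x2):
    """Analytic NNGP kernel K and NTK Theta of a one-hidden-layer ReLU network
    (NTK parameterization, k = 1), evaluated between the rows of x1 and x2."""
    n0 = x1.shape[1]
    k1 = SIGMA_W2 * x1 @ x2.T / n0 + SIGMA_B2                 # first-layer covariance K^1
    k1_11 = SIGMA_W2 * np.sum(x1 ** 2, axis=1) / n0 + SIGMA_B2
    k1_22 = SIGMA_W2 * np.sum(x2 ** 2, axis=1) / n0 + SIGMA_B2
    norms = np.sqrt(np.outer(k1_11, k1_22))
    angle = np.arccos(np.clip(k1 / norms, -1.0, 1.0))
    # Closed forms of E[relu(u) relu(v)] and E[relu'(u) relu'(v)], (u, v) ~ N(0, K^1).
    e_relu = norms * (np.sin(angle) + (np.pi - angle) * np.cos(angle)) / (2 * np.pi)
    e_step = (np.pi - angle) / (2 * np.pi)
    nngp = SIGMA_W2 * e_relu + SIGMA_B2                       # K^2
    ntk = nngp + SIGMA_W2 * e_step * k1                       # Theta^2 = K^2 + Kdot^2 * Theta^1
    return nngp, ntk


def empirical_relu_kernels(x1, x2, n, rng):
    """Readout-layer part (K_hat) and full empirical NTK of one random width-n network."""
    n0 = x1.shape[1]
    omega1, beta1 = rng.standard_normal((n0, n)), rng.standard_normal(n)
    omega2 = rng.standard_normal(n)
    pre1 = np.sqrt(SIGMA_W2 / n0) * x1 @ omega1 + np.sqrt(SIGMA_B2) * beta1
    pre2 = np.sqrt(SIGMA_W2 / n0) * x2 @ omega1 + np.sqrt(SIGMA_B2) * beta1
    # Gradients w.r.t. the readout weights and bias.
    k_hat = SIGMA_W2 / n * np.maximum(pre1, 0) @ np.maximum(pre2, 0).T + SIGMA_B2
    # Gradients w.r.t. first-layer weights and biases factor into K^1 times a sum over units.
    back1, back2 = (pre1 > 0) * omega2, (pre2 > 0) * omega2
    k1 = SIGMA_W2 * x1 @ x2.T / n0 + SIGMA_B2
    return k_hat, k_hat + SIGMA_W2 / n * (back1 @ back2.T) * k1


rng = np.random.default_rng(4)
x = rng.standard_normal((4, 5))
nngp, ntk = relu_kernels(x, x)
k_hat, ntk_hat = empirical_relu_kernels(x, x, 200_000, rng)
```

The remaining gap between the empirical and analytic kernels shrinks roughly as $1/\sqrt{n}$, consistent with the $\mathcal{O}(1/\sqrt{n})$ difference between $\hat{\Theta}_0^{(n)}$ and Θ discussed in section 3.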

2.4. Infinite width networks are linearized networks

Equations (2) and (3) of the original network are intractable in general, since ${\hat{{\Theta}}}_{t}$ evolves with time. However, for the mean squared loss, we are able to prove formally that, as long as the learning rate $\eta {< }{\eta }_{\text{critical}}{:=}2{\left({\lambda }_{\text{min}}\left({\Theta}\right)+{\lambda }_{\text{max}}\left({\Theta}\right)\right)}^{-1}$, where λmin/max(Θ) is the min/max eigenvalue of Θ, the gradient descent dynamics of the original neural network fall into its linearized dynamics regime.
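For a given kernel, $\eta_{\text{critical}}$ is a one-line eigenvalue computation. The NumPy sketch below, with a random positive-definite matrix standing in for Θ, also evaluates the per-step contraction factor $\max_i \vert 1-\eta\lambda_i\vert$ of discrete gradient descent on the linearized model with MSE loss: every η below $\eta_{\text{critical}}$ keeps this factor below 1, and $\eta_{\text{critical}}$ is the step size at which it is smallest. This is offered as intuition for the threshold, not as the argument used in the proof.

```python
import numpy as np


def eta_critical(theta):
    """eta_critical = 2 / (lambda_min(Theta) + lambda_max(Theta))."""
    lam = np.linalg.eigvalsh(theta)
    return 2.0 / (lam.min() + lam.max())


def contraction(theta, eta):
    """Spectral radius of I - eta Theta: per-step contraction of f_i(X) - Y for
    discrete gradient descent on the linearized model with MSE loss."""
    lam = np.linalg.eigvalsh(theta)
    return np.max(np.abs(1.0 - eta * lam))


rng = np.random.default_rng(6)
feats = rng.standard_normal((8, 32))
theta = feats @ feats.T / 32          # stand-in for an analytic NTK with lambda_min > 0
eta_c = eta_critical(theta)
```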

Theorem 2.1 (Informal). Let n1 = ⋯ = nL = n and assume λmin(Θ) > 0. Applying gradient descent with learning rate η < ηcritical (or gradient flow), for every $x\in {\mathbb{R}}^{{n}_{0}}$ with ||x||2 ⩽ 1, with probability arbitrarily close to 1 over random initialization,

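$$\sup_{t\geq 0}\left\Vert f_t(x) - f_t^{\text{lin}}(x)\right\Vert_2,\quad \sup_{t\geq 0}\frac{\left\Vert\theta_t - \theta_0\right\Vert_2}{\sqrt{n}},\quad \sup_{t\geq 0}\left\Vert\hat{\Theta}_t - \hat{\Theta}_0\right\Vert_F = \mathcal{O}\left(n^{-1/2}\right), \quad \text{as } n\to\infty \tag{17}$$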

Therefore, as $n\to \infty $, the distributions of ft (x) and ${f}_{t}^{\text{lin}}\left(x\right)$ become the same. Coupling this with corollary 1, we have

Theorem 2.2. If η < ηcritical, then for every $x\in {\mathbb{R}}^{{n}_{0}}$ with ||x||2 ⩽ 1, as $n\to \infty $, ft (x) converges in distribution to the Gaussian with mean and variance given by equations (14) and (15).

We refer the readers to figure 2 for empirical verification of this theorem. The proof of theorem 2.1 consists of two steps. The first step is to prove the global convergence of overparameterized NNs (Du et al 2019, Allen-Zhu et al 2019, 2018, Zou et al 2019) and stability of the NTK under gradient descent (and gradient flow); see SM section G. This stability was first observed and proved in Jacot et al (2018) in the gradient flow and sequential limit (i.e. letting ${n}_{1},\dots ,{n}_{L}\to \infty $ sequentially) setting under certain assumptions about global convergence of gradient flow. In section G, we show how to use the NTK to provide a self-contained (and cleaner) proof of such global convergence and the stability of NTK simultaneously. The second step is to couple the stability of NTK with Grönwall's type arguments (Dragomir 2003) to upper bound the discrepancy between ft and ${f}_{t}^{\text{lin}}$, i.e. the first norm in equation (17). Intuitively, the ODE of the original network (equation (3)) can be considered as a ${\Vert}{\hat{{\Theta}}}_{t}-{\hat{{\Theta}}}_{0}{{\Vert}}_{F}$-fluctuation from the linearized ODE (equation (7)). One expects the difference between the solutions of these two ODEs to be upper bounded by some functional of ${\Vert}{\hat{{\Theta}}}_{t}-{\hat{{\Theta}}}_{0}{{\Vert}}_{F}$; see section H. Therefore, for a large width network, the training dynamics can be well approximated by linearized dynamics.

Note that the updates for individual weights in equation (6) vanish in the infinite width limit, which for instance can be seen from the explicit width dependence of the gradients in the NTK parameterization. Individual weights move by a vanishingly small amount for wide networks in this regime of dynamics, as do hidden layer activations, but they collectively conspire to provide a finite change in the final output of the network, as is necessary for training. An additional insight gained from linearization of the network is that the individual instance dynamics derived in Jacot et al (2018) can be viewed as a random features method 11 where the features are the gradients of the model with respect to its weights.

2.5. Extensions to other optimizers, architectures, and losses

Our theoretical analysis thus far has focused on fully-connected single-output architectures trained by full batch gradient descent. In SM section B we derive corresponding results for: networks with multi-dimensional outputs, training against a cross entropy loss, and gradient descent with momentum.

In addition to these generalizations, there is good reason to suspect that the results extend to a much broader class of models and optimization procedures. In particular, a wealth of recent literature suggests that the mean field theory governing the wide network limit of fully-connected models (Poole et al 2016, Schoenholz et al 2017) extends naturally to residual networks (Yang and Schoenholz 2017), CNNs (Xiao et al 2018), RNNs (Chen et al 2018), batch normalization (Yang et al 2019), and to broad architectures (Yang 2019). We postpone the development of these additional theoretical extensions in favor of an empirical investigation of linearization for a variety of architectures.

3. Experiments

In this section, we provide empirical support showing that the training dynamics of wide NNs are well captured by linearized models. We consider fully-connected, convolutional, and wide ResNet architectures trained with full- and mini-batch gradient descent using learning rates sufficiently small so that the continuous time approximation holds well. We consider two-class classification on CIFAR-10 (horses and planes) as well as ten-class classification on MNIST and CIFAR-10. When using MSE loss, we treat the binary classification task as regression with one class regressing to +1 and the other to −1.

Experiments in figures 1 and 4 and S2–S6 were done in JAX (Frostig et al 2018). The remaining experiments used TensorFlow (Abadi et al 2016). An open source implementation of this work providing tools to investigate linearized learning dynamics is available at www.github.com/google/neural-tangents (Novak et al 2019b).

Figure 1.

Figure 1. Relative Frobenius norm change during training. Three hidden layer ReLU networks trained with η = 1.0 on a subset of MNIST ($\vert \mathcal{D}\vert =128$). We measure changes of (input/output/intermediary) weights, empirical $\hat{{\Theta}}$, and empirical $\hat{\mathcal{K}}$ after $T={2}^{17}$ steps of gradient descent updates for varying width. We see that the relative change in input/output weights scales as $1/\sqrt{n}$ while that in intermediate weights scales as 1/n; this is because the dimension of the input/output does not grow with n. The change in $\hat{{\Theta}}$ and $\hat{\mathcal{K}}$ is upper bounded by $\mathcal{O}\left(1/\sqrt{n}\right)$ but is closer to $\mathcal{O}\left(1/n\right)$. See figure S6 for the same experiment with 3-layer tanh and 1-layer ReLU networks. See figures S9 and S10 for additional comparisons of finite width empirical and analytic kernels.


Predictive output distribution: in the case of an MSE loss, the output distribution remains Gaussian throughout training. In figure 2, the predictive output distribution for input points interpolated between two training points is shown for an ensemble of NNs and their corresponding GPs. The interpolation is given by x(α) = αx(1) + (1 − α)x(2) where x(1,2) are two training inputs with different classes. We observe that the mean and variance dynamics of neural network outputs during gradient descent training follow the analytic dynamics from linearization well (equations (14) and (15)). Moreover, the NNGP predictive distribution, which corresponds to exact Bayesian inference, is similar to but noticeably different from the predictive distribution at the end of gradient descent training. For dynamics for individual function draws see SM figure S1.

Figure 2.

Figure 2. Dynamics of mean and variance of trained neural network outputs follow analytic dynamics from linearization. Black lines indicate the time evolution of the predictive output distribution from an ensemble of 100 trained NNs. The blue region indicates the analytic prediction of the output distribution throughout training (equations (14) and (15)). Finally, the red region indicates the prediction that would result from training only the top layer, corresponding to an NNGP (equations 22 and 23). The trained network has 3 hidden layers of width 8192, tanh activation functions, ${\sigma }_{w}^{2}=1.5$, no bias, and η = 0.5. The output is computed for inputs interpolated between two training points (denoted with black dots) x(α) = αx(1) + (1 − α)x(2). The shaded region and dotted lines denote 2 standard deviations (∼95% quantile) from the mean denoted in solid lines. Training was performed with full-batch gradient descent with dataset size $\vert \mathcal{D}\vert =128$. For dynamics for individual function initializations, see SM figure S1.


Comparison of training dynamics of linearized network to original network: for a particular realization of a finite width network, one can analytically predict the dynamics of the weights and outputs over the course of training using the empirical tangent kernel at initialization. In figures 3 and 4 (see also S2, S3), we compare these linearized dynamics (equations (8) and (9)) with the result of training the actual network. In all cases we see remarkably good agreement. We also observe that for finite networks, dynamics predicted using the empirical kernel $\hat{{\Theta}}$ better match the data than those obtained using the infinite-width, analytic, kernel Θ. To understand this we note that ${\Vert}{\hat{{\Theta}}}_{T}^{\left(n\right)}-{\hat{{\Theta}}}_{0}^{\left(n\right)}{{\Vert}}_{F}=\mathcal{O}\left(1/n\right){\leqslant}\mathcal{O}\left(1/\sqrt{n}\right)={\Vert}{\hat{{\Theta}}}_{0}^{\left(n\right)}-{\Theta}{{\Vert}}_{F}$, where ${\hat{{\Theta}}}_{0}^{\left(n\right)}$ denotes the empirical tangent kernel of width n network, as plotted in figure 1.

Figure 3.

Figure 3. Full batch gradient descent on a model behaves similarly to analytic dynamics on its linearization, both for network outputs, and also for individual weights. A binary CIFAR classification task with MSE loss and a ReLU fully-connected network with 5 hidden layers of width n = 2048, η = 0.01, $\vert \mathcal{D}\vert =256$, k = 1, ${\sigma }_{w}^{2}=2.0$, and ${\sigma }_{b}^{2}=0.1$. Left two panes show dynamics for a randomly selected subset of datapoints or parameters. Third pane shows that the dynamics of loss for training and test points agree well between the original and linearized model. The last pane shows the dynamics of RMSE between the two models on test points. We observe that the empirical kernel $\hat{{\Theta}}$ gives more accurate dynamics for finite width networks.

Figure 4.

Figure 4. A wide residual network and its linearization behave similarly when both are trained by SGD with momentum on MSE loss on CIFAR-10. We adopt the network architecture from Zagoruyko and Komodakis (2016). We use N = 1, channel size 1024, η = 1.0, β = 0.9, k = 10, ${\sigma }_{w}^{2}=1.0$, and ${\sigma }_{b}^{2}=0.0$. See table S1 for details of the architecture. Both the linearized and original model are trained directly on full CIFAR-10 ($\vert \mathcal{D}\vert =50\enspace 000$), using SGD with batch size 8. Output dynamics for a randomly selected subset of train and test points are shown in the first two panes. Last two panes show training and accuracy curves for the original and linearized networks.


One can directly optimize parameters of flin instead of solving the ODE induced by the tangent kernel $\hat{{\Theta}}$. Standard neural network optimization techniques such as mini-batching, weight decay, and data augmentation can be directly applied. In figure 4 (S2, S3), we compared the training dynamics of the linearized and original network while directly training both networks.
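A small-scale version of this comparison can be run in NumPy. The sketch below assumes a one-hidden-layer tanh network in NTK parameterization with k = 1, full-batch gradient descent on MSE loss, and a step size of $1/\lambda_{\max}(\hat{\Theta}_0)$, which lies below $2/(\lambda_{\min}+\lambda_{\max})$ of the empirical kernel. It trains the network and the parameters ω of its linearization $f^{\text{lin}}$ side by side, then reports the RMSE between their test predictions and the relative change of the weights. The data, sizes, and helper names are illustrative.

```python
import numpy as np

SIGMA_W, SIGMA_B = np.sqrt(1.5), np.sqrt(0.1)  # illustrative hyperparameters


def net_and_jacobian(theta, x, n):
    """Outputs and Jacobian of a one-hidden-layer tanh net (NTK parameterization, k = 1).

    Flat layout of theta (an assumption): [omega1 (n0*n), beta1 (n), omega2 (n), beta2 (1)].
    """
    num, n0 = x.shape
    omega1 = theta[:n0 * n].reshape(n0, n)
    beta1 = theta[n0 * n:n0 * n + n]
    omega2 = theta[n0 * n + n:n0 * n + 2 * n]
    act = np.tanh(SIGMA_W / np.sqrt(n0) * x @ omega1 + SIGMA_B * beta1)
    out = SIGMA_W / np.sqrt(n) * act @ omega2 + SIGMA_B * theta[-1]
    delta = SIGMA_W / np.sqrt(n) * omega2 * (1.0 - act ** 2)   # d out / d pre-activation
    jac = np.concatenate([
        (SIGMA_W / np.sqrt(n0) * x[:, :, None] * delta[:, None, :]).reshape(num, -1),
        SIGMA_B * delta,
        SIGMA_W / np.sqrt(n) * act,
        SIGMA_B * np.ones((num, 1)),
    ], axis=1)
    return out, jac


def train_both(n, x_train, y_train, x_test, steps=300, seed=0):
    """Full-batch GD on MSE loss for the network and for f_lin of equation (5).

    Returns (test RMSE between the two models, ||theta_T - theta_0|| / ||theta_0||).
    """
    rng = np.random.default_rng(seed)
    theta0 = rng.standard_normal(x_train.shape[1] * n + 2 * n + 1)
    f0_train, j0_train = net_and_jacobian(theta0, x_train, n)
    f0_test, j0_test = net_and_jacobian(theta0, x_test, n)
    eta = 1.0 / np.linalg.eigvalsh(j0_train @ j0_train.T).max()
    theta, omega = theta0.copy(), np.zeros_like(theta0)
    for _ in range(steps):
        out, jac = net_and_jacobian(theta, x_train, n)
        theta -= eta * jac.T @ (out - y_train)                               # original net
        omega -= eta * j0_train.T @ (f0_train + j0_train @ omega - y_train)  # linearization
    f_test, _ = net_and_jacobian(theta, x_test, n)
    rmse = np.sqrt(np.mean((f_test - (f0_test + j0_test @ omega)) ** 2))
    return rmse, np.linalg.norm(theta - theta0) / np.linalg.norm(theta0)


data_rng = np.random.default_rng(7)
x_train, x_test = data_rng.standard_normal((8, 3)), data_rng.standard_normal((4, 3))
y_train = data_rng.standard_normal(8)
results = {n: train_both(n, x_train, y_train, x_test) for n in (64, 2048)}
```

With these settings one would expect both numbers to be smaller for the wider network, in line with figures 1 and 3; the exact values depend on the random seed.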

With direct optimization of the linearized model, we tested full ($\vert \mathcal{D}\vert =50\enspace 000$) MNIST digit classification with cross-entropy loss, trained with a momentum optimizer (figure S3). For cross-entropy loss with softmax output, some logits at late times grow indefinitely, in contrast to MSE loss, where logits converge to their target values. The error between the original and linearized models for cross-entropy loss becomes much worse at late times if the two models deviate significantly before the logits enter their late-time steady-growth regime (see figure S4).

Linearized dynamics successfully describe the training of networks beyond vanilla fully-connected models. To demonstrate the generality of this procedure, we show that we can predict the learning dynamics of a subclass of wide residual networks (WRNs) (Zagoruyko and Komodakis 2016). WRNs are a class of models that are popular in computer vision and leverage convolutions, batch normalization, skip connections, and average pooling. In figure 4, we show a comparison between the linearized dynamics and the true dynamics for a wide residual network trained with MSE loss and SGD with momentum, trained on the full CIFAR-10 dataset. We slightly modified the block structure described in table S1 so that each layer has a constant number of channels (1024 in this case), and otherwise followed the original implementation. As elsewhere, we see strong agreement between the predicted dynamics and the result of training.

Effects of dataset size: the training dynamics of a neural network match those of its linearization when the width is infinite and the dataset is finite. In previous experiments, we chose networks wide enough to achieve small error between the NNs and their linearizations on smaller datasets. Overall, we observe that the error decreases as the width grows (figure S5). Additionally, we see that the error grows with the size of the dataset. Thus, the growth in error with dataset size can be counterbalanced by a corresponding increase in model size.
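One way to probe the width dependence, offered as a proxy rather than the paper's protocol (figure S5 measures the error between each network and its linearization directly), is to track how far the empirical tangent kernel $\hat{\Theta} = J J^{T}$ on the training set moves during training; for the linearized model it does not move at all. The sketch assumes a hypothetical toy setup: a one-hidden-layer tanh network in NTK parameterization with hand-written NumPy gradients, full-batch gradient descent on the mean squared loss, and a synthetic target. It prints the relative kernel change for a few widths and training-set sizes.

```python
import numpy as np

def init_params(rng, d_in, width):
    # NTK parameterization: N(0, 1) weights, 1/sqrt(fan_in) applied in the forward pass.
    return {"W": rng.standard_normal((width, d_in)),
            "v": rng.standard_normal(width)}

def forward_and_jacobian(params, X):
    """Outputs of f(x) = v . tanh(W x / sqrt(d)) / sqrt(n) and per-example gradients."""
    d, n = X.shape[1], params["v"].shape[0]
    h = np.tanh(X @ params["W"].T / np.sqrt(d))
    out = h @ params["v"] / np.sqrt(n)
    dpre = (1.0 - h ** 2) * params["v"] / np.sqrt(n)
    return out, {"W": dpre[:, :, None] * X[:, None, :] / np.sqrt(d),
                 "v": h / np.sqrt(n)}

def empirical_ntk(J):
    # Flatten per-example gradients into a (batch, num_params) matrix; kernel = J J^T.
    Jf = np.concatenate([J[k].reshape(J[k].shape[0], -1) for k in sorted(J)], axis=1)
    return Jf @ Jf.T

def relative_ntk_change(width, n_train, steps=300, lr=1.0, seed=0):
    """||Theta_T - Theta_0||_F / ||Theta_0||_F after full-batch GD on the mean squared loss."""
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n_train, 4))
    y = np.sin(X[:, 0])                                  # hypothetical target
    params = init_params(rng, 4, width)
    out, J = forward_and_jacobian(params, X)
    theta_0 = empirical_ntk(J)
    for _ in range(steps):
        # Gradient of 0.5 * mean((out - y) ** 2) with respect to each parameter array.
        params = {k: params[k] - lr * np.tensordot(out - y, J[k], axes=1) / n_train
                  for k in params}
        out, J = forward_and_jacobian(params, X)
    return np.linalg.norm(empirical_ntk(J) - theta_0) / np.linalg.norm(theta_0)

for n_train in (16, 64):
    for width in (32, 256, 2048):
        print(f"n_train={n_train:3d}  width={width:5d}  "
              f"relative NTK change={relative_ntk_change(width, n_train):.4f}")
```

The kernel change shrinks as the width grows, which is the mechanism behind the linearization becoming exact; how strongly it depends on `n_train` in such a small toy problem should be checked by running it rather than assumed.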

4. Discussion

We showed theoretically that the learning dynamics in parameter space of deep nonlinear NNs are exactly described by a linearized model in the infinite width limit. Empirical investigation revealed that this agrees well with actual training dynamics and predictive distributions across fully-connected, convolutional, and even wide residual network architectures, as well as with different optimizers (gradient descent, momentum, mini-batching) and loss functions (MSE, cross-entropy). Our results suggest that a surprising number of realistic NNs may be operating in the regime we studied. This is further consistent with recent experimental work showing that NNs are often robust to re-initialization but not re-randomization of layers (Zhang et al 2019).

In the regime we study, the learning dynamics are fully captured by the kernel $\hat{{\Theta}}$ and the target signal, so studying the properties of $\hat{{\Theta}}$ to determine trainability and generalization is an interesting future direction. Furthermore, the infinite width limit gives us a simple characterization of both gradient descent and Bayesian inference. By studying properties of the NNGP kernel $\mathcal{K}$ and the tangent kernel Θ, we may shed light on the inductive bias of gradient descent.
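As a concrete instance of this statement: for the linearized model trained by gradient flow on the total squared loss with learning rate $\eta$, the training-set outputs obey $\partial_t f_t = -\eta\,\hat{\Theta}\,(f_t - Y)$, so $f_t = (I - e^{-\eta \hat{\Theta} t})\,Y + e^{-\eta \hat{\Theta} t} f_0$, where $Y$ are the targets and $f_0$ the outputs at initialization. The NumPy sketch below evaluates this through an eigendecomposition of $\hat{\Theta}$ and, as a sanity check, compares it with a small-step Euler integration of the ODE. The random positive-definite matrix standing in for $\hat{\Theta}$ and the function names are hypothetical.

```python
import numpy as np

def closed_form_outputs(theta, y, f0, eta, t):
    """f_t = (I - exp(-eta*theta*t)) y + exp(-eta*theta*t) f0 for symmetric PSD theta."""
    evals, evecs = np.linalg.eigh(theta)
    decay = (evecs * np.exp(-eta * evals * t)) @ evecs.T   # matrix exponential exp(-eta*theta*t)
    return y + decay @ (f0 - y)                            # equals (I - decay) y + decay f0

def euler_outputs(theta, y, f0, eta, t, n_steps=20000):
    """Integrate df/dt = -eta * theta (f - y) with explicit Euler steps."""
    f, dt = f0.copy(), t / n_steps
    for _ in range(n_steps):
        f = f - dt * eta * theta @ (f - y)
    return f

rng = np.random.default_rng(0)
A = rng.standard_normal((8, 16))
theta = A @ A.T / 16            # hypothetical stand-in for the empirical NTK on 8 points
y = rng.standard_normal(8)      # targets
f0 = rng.standard_normal(8)     # outputs at initialization
gap = np.max(np.abs(closed_form_outputs(theta, y, f0, 1.0, 1.0)
                    - euler_outputs(theta, y, f0, 1.0, 1.0)))
print(f"max |closed form - Euler| at t=1: {gap:.2e}")
```

Each eigendirection of $\hat{\Theta}$ relaxes toward the target at rate $\eta\lambda_i$, so the spectrum of the kernel, together with how the targets project onto its eigenvectors, determines which components are learned quickly.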

Some layers of modern NNs may be operating far from the linearized regime. Preliminary observations in Lee et al (2018) showed that wide NNs trained with SGD perform similarly to the corresponding GPs as width increases, while GPs still outperform trained NNs for both small and large dataset sizes. Furthermore, Novak et al (2019a) showed that the comparison of performance between finite- and infinite-width networks is highly architecture-dependent. In particular, infinite-width networks were found to perform as well as or better than their finite-width counterparts for many fully-connected or locally-connected architectures, whereas the opposite was found for convolutional networks without pooling. Identifying the main factors behind these performance gaps remains an open research question. We believe that examining the behavior of infinitely wide networks provides a strong basis from which to build up a systematic understanding of finite-width networks (and/or networks trained with large learning rates).

Acknowledgments

We thank Greg Yang and Alex Alemi for useful discussions and feedback. We are grateful to Daniel Freeman, Alex Irpan and the anonymous reviewers for providing valuable feedback on the draft. We thank the JAX team for developing a language which makes model linearization and NTK computation straightforward. We would like to especially thank Matthew Johnson for support and debugging help.

Footnotes

  • This article is an updated version of a paper presented at 33rd Conference on Neural Information Processing Systems (NeurIPS 2019), Vancouver, Canada.

  • Note that the open source library has been expanded since initial submission of this work.

  • colab.sandbox.google.com/github/google/neural-tangents/blob/master/notebooks/weightspacelinearization.ipynb

  • We note that this is a concurrent work and an expanded version of this note is presented in parallel at NeurIPS 2019.

  • To simplify the notation for later equations, we use the total loss here instead of the average loss, but for all plots in section 3, we show the average loss.

  • Note that, compared to the conventional parameterization, η is larger by a factor of the width (Park et al 2019). The NTK parameterization allows usage of a universal learning rate scale irrespective of network width.

  • This requires that $h^{L+1}$ correspond directly to the network predictions. In the case of a softmax readout, variational or sampling methods are required to marginalize over $h^{L+1}$.

  • Here '+h.c.' is an abbreviation for 'plus the Hermitian conjugate'.

  • 10. One possible exception is when the NNGP kernel and NTK are the same up to a scalar multiplication. This is the case when the activation function is the identity function and there is no bias term.

  • 11. We thank Alex Alemi for pointing out a subtlety regarding the correspondence to a random features method.
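The identity-activation case in footnote 10 can be checked by induction, assuming the standard layer-wise recursions in NTK parameterization, $\mathcal{K}^{l} = \sigma_w^2\,\mathbb{E}[\phi(u)\phi(v)] + \sigma_b^2$ and $\Theta^{l} = \mathcal{K}^{l} + \sigma_w^2\,\mathbb{E}[\phi'(u)\phi'(v)]\,\Theta^{l-1}$, with $(u, v)$ Gaussian with covariance given by $\mathcal{K}^{l-1}$ and $\Theta^{0} = \mathcal{K}^{0}$ at the first layer (the layer indexing here is illustrative). For $\phi(z) = z$ the expectations reduce to $\mathcal{K}^{l-1}$ and $1$:

```latex
\begin{aligned}
&\phi(z) = z,\ \sigma_b^2 = 0:\quad
  \mathcal{K}^{l} = \sigma_w^2\,\mathcal{K}^{l-1}, \qquad
  \Theta^{l} = \mathcal{K}^{l} + \sigma_w^2\,\Theta^{l-1},\\
&\Theta^{l-1} = l\,\mathcal{K}^{l-1}
  \;\Rightarrow\;
  \Theta^{l} = \mathcal{K}^{l} + l\,\sigma_w^2\,\mathcal{K}^{l-1}
             = (l+1)\,\mathcal{K}^{l},\\
&\Theta^{0} = \mathcal{K}^{0}
  \;\Rightarrow\;
  \Theta^{l} = (l+1)\,\mathcal{K}^{l}\ \text{for all } l .
\end{aligned}
```

Under these assumptions the output-layer NTK equals the NNGP kernel multiplied by the number of weight layers. With $\sigma_b^2 \neq 0$ the added constant breaks this proportionality.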
