Meta-Learning with Implicit Gradients - ShortScience.org

This paper builds on prior work in gradient-based meta-learning. The objective of meta-learning is to find meta-parameters ($\theta$) that can be adapted to each task to yield task-specific parameters ($\phi$).
Thus, $\theta$ and $\phi$ live in the same parameter space. A meta-learning problem consists of several tasks, each specified by its own training and test datasets. At inference time, gradient-based meta-learning methods obtain the task-specific parameters by running some gradient-descent (GD) steps on the task's training data, initialized from the meta-parameters. The goal is to find $\theta$ such that GD on each task's training data yields parameters that generalize well to that task's test data. Accordingly, the meta-learning objective is the average loss on the test dataset of each task, $\mathcal{L}_i(\phi_i)$, where the task parameters $\phi_i$ are obtained by running GD on the task's training loss $\hat{\mathcal{L}}_i$, initialized from the meta-parameters $\theta$: \begin{equation}
F(\theta) = \frac{1}{M}\sum_{i=1}^{M} \mathcal{L}_i\big(\phi_i(\theta)\big)
\end{equation} To backpropagate the gradient of this objective to the meta-parameters, one must differentiate through both the task-specific loss $\mathcal{L}_i$ and the GD steps (or whichever optimization algorithm was used) that produced $\phi_i$.
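To make the objective concrete, here is a minimal Python/NumPy sketch that evaluates $F(\theta)$. It assumes a toy setting in which every task is a small noiseless linear-regression problem with squared-error loss. The helpers `make_task`, `adapt`, and `meta_objective`, along with the step size and the number of inner steps, are illustrative assumptions and not the paper's setup.

```python
import numpy as np

# Toy setting (an assumption for illustration): each task is a noiseless linear
# regression problem with its own training and test split.
def make_task(rng, dim=3, n=20):
    w_true = rng.normal(size=dim)
    X_tr, X_te = rng.normal(size=(n, dim)), rng.normal(size=(n, dim))
    return (X_tr, X_tr @ w_true), (X_te, X_te @ w_true)

def mse(phi, X, y):
    """Squared-error loss 0.5 * mean((X phi - y)^2)."""
    return 0.5 * np.mean((X @ phi - y) ** 2)

def mse_grad(phi, X, y):
    """Gradient of mse with respect to phi."""
    return X.T @ (X @ phi - y) / len(y)

def adapt(theta, train, steps=5, lr=0.1):
    """Inner loop: plain GD on the task's training loss, initialized at theta."""
    phi = theta.copy()
    for _ in range(steps):
        phi = phi - lr * mse_grad(phi, *train)
    return phi

def meta_objective(theta, tasks, **inner_kwargs):
    """F(theta): average test loss of the adapted parameters over all tasks."""
    return np.mean([mse(adapt(theta, train, **inner_kwargs), *test)
                    for train, test in tasks])

rng = np.random.default_rng(0)
tasks = [make_task(rng) for _ in range(4)]
theta = np.zeros(3)
print(meta_objective(theta, tasks))
```

Differentiating `meta_objective` with respect to `theta` would require backpropagating through every iteration of the loop in `adapt`.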
Because GD is a sequence of steps, backpropagation must account for the whole sequence of updates applied to $\theta$. Past approaches have therefore backpropagated through the full unrolled optimization, much like backpropagation through time (BPTT) in RNNs, or through a truncated portion of it (truncated BPTT). The authors show, however, that adding a proximal term to the task-specific optimization (also called the inner optimization) makes it possible to compute the meta-gradient from the final $\phi$ alone, without considering the entire trajectory of the parameters. The authors call these implicit gradients.
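To see why the trajectory matters, suppose (as an illustrative assumption) that the inner loop runs $K$ plain GD steps $\phi^{(k+1)} = \phi^{(k)} - \alpha \nabla \hat{\mathcal{L}}_i(\phi^{(k)})$ with step size $\alpha$, starting from $\phi^{(0)} = \theta$. Here $\hat{\mathcal{L}}_i$ denotes task $i$'s training loss. By the chain rule, the Jacobian of the final iterate is a product with one factor per step: \begin{equation}
\frac{d\phi^{(K)}}{d\theta} = \big( \mathbf{I} - \alpha \nabla^{2} \hat{\mathcal{L}}_i(\phi^{(K-1)}) \big) \cdots \big( \mathbf{I} - \alpha \nabla^{2} \hat{\mathcal{L}}_i(\phi^{(1)}) \big)\big( \mathbf{I} - \alpha \nabla^{2} \hat{\mathcal{L}}_i(\phi^{(0)}) \big)
\end{equation} Each factor uses the Hessian at a different intermediate iterate. Exact backpropagation therefore has to store (or recompute) all $K$ iterates, and its memory grows linearly with $K$.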
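The idea is to keep $\phi$ close to $\theta$ with a proximal term. This term acts like an L2-regularization penalty centered at $\theta$ rather than at the origin. The inner problem for task $i$ becomes \begin{equation}
\phi_i = \arg\min_{\phi} \; \hat{\mathcal{L}}_i(\phi) + \frac{\lambda}{2}\|\phi - \theta\|^2,
\end{equation} where $\hat{\mathcal{L}}_i$ is the task's training loss and $\lambda > 0$ controls the strength of the constraint. Setting the gradient of this objective to zero at the minimizer gives an implicit equation for $\phi_i$ in terms of $\theta$: \begin{equation}
\phi_i = \theta - \frac{1}{\lambda}\nabla\hat{\mathcal{L}}_i(\phi_i)
\end{equation}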
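The following Python/NumPy sketch checks the implicit equation numerically under simplifying assumptions. The training loss is taken to be least squares, $\hat{\mathcal{L}}_i(\phi) = \frac{1}{2n}\|X\phi - y\|^2$, and the proximal inner problem is solved by plain GD from $\theta$. The function names, $\lambda = 2$, the step size, and the iteration count are hypothetical choices for illustration. Once the inner loop has converged, the residual of the fixed-point equation should be close to zero.

```python
import numpy as np

def train_grad(phi, X, y):
    """Gradient of the assumed training loss 0.5 * mean((X phi - y)^2)."""
    return X.T @ (X @ phi - y) / len(y)

def proximal_adapt(theta, X, y, lam=2.0, lr=0.1, steps=500):
    """Plain GD on L_hat(phi) + (lam / 2) * ||phi - theta||^2, initialized at theta."""
    phi = theta.copy()
    for _ in range(steps):
        phi = phi - lr * (train_grad(phi, X, y) + lam * (phi - theta))
    return phi

rng = np.random.default_rng(1)
X = rng.normal(size=(30, 4))
y = X @ rng.normal(size=4)
theta = rng.normal(size=4)
lam = 2.0

phi = proximal_adapt(theta, X, y, lam=lam)
# Residual of the implicit equation phi = theta - (1 / lam) * grad L_hat(phi).
residual = phi - (theta - train_grad(phi, X, y) / lam)
print(np.linalg.norm(residual))
```

Because this toy problem is quadratic, the minimizer also has the closed form $(H + \lambda \mathbf{I})^{-1}(X^\intercal y / n + \lambda\theta)$ with $H = X^\intercal X / n$. This provides a convenient cross-check.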
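Differentiating this equation with respect to $\theta$ and solving for the Jacobian (assuming the inverse exists) gives the implicit gradient \begin{equation}
\frac{d\phi_i}{d\theta} = \big( \mathbf{I} + \frac{1}{\lambda}\nabla^{2} \hat{\mathcal{L}}_i(\phi_i) \big)^{-1}.
\end{equation} Since this matrix is symmetric, the contribution of task $i$ to the meta-gradient is \begin{equation}
\nabla_\theta \mathcal{L}_i\big(\phi_i(\theta)\big) = \big( \mathbf{I} + \frac{1}{\lambda}\nabla^{2} \hat{\mathcal{L}}_i(\phi_i) \big)^{-1} \nabla \mathcal{L}_i(\phi_i),
\end{equation} where the Hessian is of the training loss $\hat{\mathcal{L}}_i$ and the gradient is of the test loss $\mathcal{L}_i$. This expression depends only on $\phi_i$, not on the path the inner optimizer took to reach it. However, explicitly forming and inverting the Hessian costs $O(d^2)$ memory and $O(d^3)$ compute for $d$ parameters, which is infeasible for deep neural networks.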
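As a sanity check of the implicit-gradient formula, the sketch below uses the same least-squares assumption, chosen because the inner solution then has a closed form. It computes $\big(\mathbf{I} + \frac{1}{\lambda} H\big)^{-1} \nabla \mathcal{L}_i(\phi_i)$ with an explicit linear solve and compares the result against central finite differences of $\theta \mapsto \mathcal{L}_i(\phi_i(\theta))$. Here $H$ is the constant Hessian of the training loss. The data, $\lambda$, and the helper names (`inner_solution`, `outer_loss`, and so on) are illustrative assumptions. The explicit solve is only viable because $d$ is tiny.

```python
import numpy as np

# Toy task (illustrative assumption): noiseless least squares with separate
# training and test splits.
rng = np.random.default_rng(2)
d, n, lam = 4, 30, 2.0
w_true = rng.normal(size=d)
X_tr, X_te = rng.normal(size=(n, d)), rng.normal(size=(n, d))
y_tr, y_te = X_tr @ w_true, X_te @ w_true

# Hessian of the training loss 0.5 * mean((X_tr phi - y_tr)^2); constant for least squares.
H = X_tr.T @ X_tr / n

def inner_solution(theta):
    """Exact minimizer of L_hat(phi) + (lam / 2) * ||phi - theta||^2 (closed form here)."""
    return np.linalg.solve(H + lam * np.eye(d), X_tr.T @ y_tr / n + lam * theta)

def outer_loss(phi):
    """Test (outer) loss of task i."""
    return 0.5 * np.mean((X_te @ phi - y_te) ** 2)

def outer_grad(phi):
    return X_te.T @ (X_te @ phi - y_te) / n

def implicit_meta_grad(theta):
    """(I + H / lam)^{-1} grad outer_loss(phi_i), via an explicit solve (small d only)."""
    phi = inner_solution(theta)
    return np.linalg.solve(np.eye(d) + H / lam, outer_grad(phi))

def finite_diff_grad(theta, eps=1e-6):
    """Central differences of theta -> outer_loss(inner_solution(theta))."""
    g = np.zeros(d)
    for j in range(d):
        e = np.zeros(d)
        e[j] = eps
        g[j] = (outer_loss(inner_solution(theta + e))
                - outer_loss(inner_solution(theta - e))) / (2 * eps)
    return g

theta = rng.normal(size=d)
print(np.max(np.abs(implicit_meta_grad(theta) - finite_diff_grad(theta))))
```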
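Instead, the authors approximate this term by minimizing the quadratic below with the conjugate gradient (CG) method. CG only requires Hessian-vector products, which are cheap to compute with reverse-mode automatic differentiation without ever forming the Hessian.
\begin{equation}
\min_{\mathbf{w}} \frac{1}{2}\mathbf{w}^\intercal \big( \mathbf{I} + \frac{1}{\lambda}\nabla^{2} \hat{\mathcal{L}}_i(\phi_i) \big) \mathbf{w} - \mathbf{w}^\intercal \nabla \mathcal{L}_i(\phi_i)
\end{equation} Here $\hat{\mathcal{L}}_i$ and $\mathcal{L}_i$ are task $i$'s training and test losses. When the matrix is positive definite (for example, when $\lambda$ is large enough), the minimizer satisfies $\big( \mathbf{I} + \frac{1}{\lambda}\nabla^{2} \hat{\mathcal{L}}_i(\phi_i) \big)\mathbf{w} = \nabla \mathcal{L}_i(\phi_i)$, so $\mathbf{w}$ is exactly the meta-gradient contribution of task $i$. Thus, the paper introduces a computationally cheap meta-gradient computation whose memory cost does not grow with the number of inner GD steps.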
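Below is a minimal sketch of the CG step in Python/NumPy. It assumes a least-squares training loss so that the Hessian-vector product $Hv = X^\intercal (Xv)/n$ can be written by hand; in a deep-learning framework, that product would come from automatic differentiation instead. `g_test` is a random stand-in for $\nabla \mathcal{L}_i(\phi_i)$. The names `conjugate_gradient` and `hvp`, the iteration cap, and the tolerance are illustrative choices rather than the paper's settings. The solver touches $\mathbf{I} + \frac{1}{\lambda}H$ only through `matvec`, and the dense solve at the end is there purely to check the answer.

```python
import numpy as np

def conjugate_gradient(matvec, b, iters=20, tol=1e-10):
    """Approximately solve A x = b for symmetric positive-definite A,
    accessing A only through matvec(v) = A @ v."""
    x = np.zeros_like(b)
    r = b - matvec(x)          # residual
    p = r.copy()               # search direction
    rs = r @ r
    for _ in range(iters):
        if np.sqrt(rs) < tol:
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

# Illustrative setup (assumptions): least-squares training loss and a random
# stand-in for the test-loss gradient at phi_i.
rng = np.random.default_rng(3)
d, n, lam = 50, 200, 1.0
X_tr = rng.normal(size=(n, d))
g_test = rng.normal(size=d)

def hvp(v):
    """H v = X^T (X v) / n for the least-squares loss, without forming H."""
    return X_tr.T @ (X_tr @ v) / n

def matvec(v):
    """(I + H / lam) v, using one Hessian-vector product."""
    return v + hvp(v) / lam

w = conjugate_gradient(matvec, g_test)

# Dense check, only affordable because d is small.
H = X_tr.T @ X_tr / n
w_exact = np.linalg.solve(np.eye(d) + H / lam, g_test)
print(np.linalg.norm(w - w_exact) / np.linalg.norm(w_exact))
```

Because $\mathbf{I} + \frac{1}{\lambda}H$ is well conditioned in this example, CG converges within the 20-iteration cap. A larger $\lambda$ pushes the condition number closer to 1 and improves convergence further.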