Professor Forcing: A New Algorithm for Training Recurrent Networks - ShortScience.org

The authors present a method similar to teacher forcing that uses generative adversarial networks to guide training on sequential tasks. The work describes a novel algorithm that encourages the dynamics of an LSTM during inference to follow its dynamics during training. The motivating example is sampling for many steps at test time after training only on shorter sequences. Experimental results are shown on PTB language modelling, MNIST, handwriting generation and music synthesis.

The paper is similar to Generative Adversarial Networks (GAN): in addition to minimizing a normal sequence model loss function, the parameters are trained to “fool” a classifier. That classifier tries to distinguish sequences generated by the sequence model from real data. Several objectives are proposed in section 2.2. The key difference from a GAN is the B in equations 1-4. B is a function that outputs some statistics of the model, such as the hidden state of the RNN, whereas a GAN discriminates the actual output sequences.

This paper proposes a method for training recurrent neural networks (RNN) in the framework of adversarial training. Since RNNs can be used to generate sequential data, the goal is to optimize the network parameters in such a way that the generated samples are hard to distinguish from real data. This is particularly interesting for RNNs because the classical training criterion only involves the prediction of the next symbol in the sequence. Given a sequence of symbols $x_1, ..., x_t$, the model is trained to output $y_t$ as close to $x_{t+1}$ as possible. Training that way does not produce models that are robust during generation, because a mistake at time $t$ is fed back into the model and can make the prediction at time $t+k$ totally unreliable. The idea is somewhat similar to computing a sentence-wide loss in the context of encoder-decoder translation models, where the loss can only be computed after a complete sequence has been generated.
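To make the contrast between the two modes concrete, here is a minimal sketch in plain Python. It is not the paper's implementation: the scalar RNN, its fixed weights, the identity readout, the squared-error next-step loss, and the logistic discriminator on the mean hidden state are all hypothetical choices made for illustration. It runs the same toy network once with ground-truth inputs (teacher forcing) and once feeding its own predictions back in (free-running), then evaluates a discriminator-style loss and a "fooling" loss on the hidden-state trajectories, which stand in for the statistic B.

```python
import math

def rnn_step(h, x, w_h=0.5, w_x=1.0):
    # One step of a toy scalar RNN (weights are arbitrary, not learned).
    return math.tanh(w_h * h + w_x * x)

def run(xs, n_steps, teacher_forced):
    # Returns (hidden states, outputs). The hidden-state trajectory plays
    # the role of the behaviour statistic B in this sketch.
    h, x = 0.0, xs[0]
    hidden, outputs = [], []
    for t in range(n_steps):
        h = rnn_step(h, x)
        y = h  # identity readout (assumption): prediction of the next input
        hidden.append(h)
        outputs.append(y)
        # Teacher forcing feeds the true next input; free-running feeds
        # the model's own prediction back in.
        x = xs[t + 1] if teacher_forced else y
    return hidden, outputs

def next_step_loss(outputs, xs):
    # Mean squared error between y_t and x_{t+1} (stand-in for the usual
    # next-symbol training criterion).
    return sum((y - xs[t + 1]) ** 2 for t, y in enumerate(outputs)) / len(outputs)

def discriminator(hidden, a=1.0, b=0.0):
    # Hypothetical logistic classifier on the mean hidden state: estimated
    # probability that the trajectory came from teacher-forced mode.
    feat = sum(hidden) / len(hidden)
    return 1.0 / (1.0 + math.exp(-(a * feat + b)))

def discriminator_loss(h_tf, h_fr):
    # The classifier wants teacher-forced -> 1 and free-running -> 0.
    return -math.log(discriminator(h_tf)) - math.log(1.0 - discriminator(h_fr))

def fooling_loss(h_fr):
    # The sequence model wants free-running behaviour classified as
    # teacher-forced.
    return -math.log(discriminator(h_fr))

xs = [math.cos(0.3 * t) for t in range(21)]  # toy "real" sequence
h_tf, y_tf = run(xs, 20, teacher_forced=True)
h_fr, y_fr = run(xs, 20, teacher_forced=False)

print("next-step loss (teacher-forced):", next_step_loss(y_tf, xs))
print("next-step loss (free-running):  ", next_step_loss(y_fr, xs))
print("discriminator loss:", discriminator_loss(h_tf, h_fr))
print("fooling loss:      ", fooling_loss(h_fr))
```

The two runs share the same first hidden state and diverge from the second step onward, which is the train/inference mismatch the summaries describe. In an actual training setup the sequence model and the classifier would both have learned parameters and would be updated in alternation; here they are fixed so that the example stays self-contained.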