As mentioned, our model is inspired by the M2 model developed by Kingma et al. The architecture is based on the Variational Autoencoder, with the addition of a classification network for labelled data. The latent layer therefore contains not only the standard continuous latent distribution from the encoder, but also a latent class label. This label is taken from the raw label when one is available and from the classifier otherwise.
Consider the job of the encoder in a VAE network: it creates a dense, abstract representation of the data before mapping it into a continuous latent space. In a typical classification network, the output layer is determined by the activations of the previous layer, which is again a dense, abstract representation of the input. Whether or not a label is available, there is no reason why this abstract representation of the data should differ. The classifier can therefore share neurons with the encoder before the network splits into mu, sigma and pi layers (where pi holds the latent class label; see Fig5). The benefit of this is twofold. Firstly, the model contains fewer neurons, which improves training time. Secondly, and more importantly, all of the weights in the classifier network (apart from the pi layer) are trained directly by unlabelled data by way of the encoder.
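The following minimal NumPy sketch shows how a single forward pass through the shared weights feeds all three heads. It assumes a single shared hidden layer, ReLU activations and a log-sigma parameterization. The layer sizes and all names are hypothetical and are not taken from our implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: 784 inputs (a flattened 28x28 image), one shared
# hidden layer, a 16-dimensional latent space and 10 classes.
D_IN, D_HID, D_LAT, N_CLS = 784, 256, 16, 10


def init(n_in, n_out):
    """Small random weights and zero biases for one dense layer."""
    return rng.normal(0.0, 0.01, size=(n_in, n_out)), np.zeros(n_out)


params = {
    "trunk": init(D_IN, D_HID),       # shared by the encoder and the classifier
    "mu": init(D_HID, D_LAT),         # mean of the latent Gaussian
    "log_sigma": init(D_HID, D_LAT),  # log standard deviation of the latent Gaussian
    "pi": init(D_HID, N_CLS),         # class logits (the pi layer)
}


def dense(x, layer):
    w, b = layer
    return x @ w + b


def softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


def encode(x):
    """Run the shared trunk once, then split into the mu, sigma and pi heads."""
    h = np.maximum(dense(x, params["trunk"]), 0.0)  # ReLU
    mu = dense(h, params["mu"])
    sigma = np.exp(dense(h, params["log_sigma"]))   # exp keeps sigma positive
    r = softmax(dense(h, params["pi"]))             # predicted class probabilities
    return mu, sigma, r


x = rng.random((4, D_IN))  # a dummy batch of four inputs
mu, sigma, r = encode(x)
print(mu.shape, sigma.shape, r.shape)  # (4, 16) (4, 16) (4, 10)
```

Everything up to `h` is shared, so any loss that depends on mu and sigma also updates the weights the classifier relies on.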
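The loss function is the sum of three terms. The first is the cross entropy between the original input and the output of the decoder. The second is the KL divergence between the latent distribution defined by the mu and sigma layers and a standard Gaussian. The third, present only when a label is available, is the cross entropy between the label and the pi layer.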
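One form consistent with this description uses the symbols defined in the next paragraph. It assumes a diagonal Gaussian latent with a standard normal prior. The name L_cls for the third term and the latent index j are our own labels.

```latex
\begin{aligned}
\mathcal{L} &= \mathcal{L}_{\mathrm{ent}} + \mathcal{L}_{\mathrm{kl}} + \mathcal{L}_{\mathrm{cls}}, \\
\mathcal{L}_{\mathrm{ent}} &= -\sum_i p_i \log q_i, \\
\mathcal{L}_{\mathrm{kl}} &= D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}\sigma^2) \,\|\, \mathcal{N}(0, I)\big)
  = -\tfrac{1}{2} \sum_j \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right), \\
\mathcal{L}_{\mathrm{cls}} &= -\sum_k \pi_k \log r_k .
\end{aligned}
```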
Here, q corresponds to the output probability density function and p to the actual probability density function. In the continuous case these can be delta functions, and the sums are replaced by integrals. Similarly, pi corresponds to the actual label, and r to the probability of that label as output by the model. If the label is not known, all pi’s are set to zero for that data point, so the third term vanishes. The index k runs over all labels, while i runs over all output values. If there is more than one feature, L_ent and L_kl are averaged over the features.
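The next sketch computes the three-term loss per example in NumPy. It follows the cross-entropy form as written above. For pixel-wise Bernoulli outputs one would more typically use binary cross entropy. It also shows how an all-zero pi row switches off the class term. The function name and the toy numbers are illustrative only.

```python
import numpy as np

EPS = 1e-7  # guards against log(0)


def semi_supervised_loss(p, q, mu, sigma, pi, r):
    """Per-example loss L_ent + L_kl + L_cls for a batch.

    p, q      : target and reconstructed distributions, shape (batch, n_outputs)
    mu, sigma : latent mean and standard deviation, shape (batch, n_latent)
    pi        : one-hot labels, with an all-zero row for unlabelled examples
    r         : class probabilities from the pi layer, shape (batch, n_classes)
    """
    l_ent = -np.sum(p * np.log(q + EPS), axis=1)
    l_kl = -0.5 * np.sum(1.0 + np.log(sigma**2) - mu**2 - sigma**2, axis=1)
    l_cls = -np.sum(pi * np.log(r + EPS), axis=1)  # zero when pi is all zeros
    return l_ent + l_kl + l_cls


# Two examples that differ only in whether a label is attached.
p = np.array([[0.0, 1.0], [0.0, 1.0]])
q = np.array([[0.5, 0.5], [0.5, 0.5]])
mu, sigma = np.zeros((2, 3)), np.ones((2, 3))      # KL term is exactly zero here
pi = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]])  # second example is unlabelled
r = np.array([[0.25, 0.5, 0.25], [0.25, 0.5, 0.25]])
loss = semi_supervised_loss(p, q, mu, sigma, pi, r)  # roughly [2.079, 0.693]
```

The two losses differ by exactly the class term of the labelled example. The unlabelled example is scored on reconstruction and KL alone.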
Let’s break down the loss function and examine, from an intuitive perspective, how each of the three terms helps the model to learn. Firstly, note that the first two terms, a cross-entropy term and a KL term, make up the loss function of a standard VAE. The KL term pushes the latent distribution to stay close to a standard Gaussian. Without it, the Gaussian may degenerate into a delta function. In that case the VAE behaves just like a plain autoencoder and has no incentive to learn anything about the regions of latent space between the training points. Its decoder would then only learn a fixed mapping from individual codes to outputs. With the KL term, the decoder instead learns to reconstruct data from samples of a continuous distribution, which gives it a stronger ability to generalize.
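To make the degeneration argument concrete, the sketch below evaluates the closed-form KL term for a diagonal Gaussian against a standard normal prior as sigma shrinks. It also includes the usual reparameterization step used to sample the latent. Both are standard VAE ingredients; the function names are our own.

```python
import numpy as np


def kl_to_standard_normal(mu, sigma):
    """Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) )."""
    return -0.5 * np.sum(1.0 + np.log(sigma**2) - mu**2 - sigma**2)


def sample_latent(mu, sigma, rng):
    """Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I)."""
    return mu + sigma * rng.standard_normal(np.shape(mu))


# As sigma shrinks towards a delta function, the KL penalty grows without bound,
# which is what discourages the encoder from collapsing to a plain autoencoder.
mu = np.zeros(2)
for s in [1.0, 0.1, 0.01]:
    print(f"sigma={s}: KL={kl_to_standard_normal(mu, np.full(2, s)):.3f}")

z = sample_latent(mu, np.ones(2), np.random.default_rng(0))
```

The KL term is zero only when the encoder outputs exactly the prior. Every step towards a point estimate is paid for in the loss.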
These first two terms apply to every data point, whether a label is present or not, so they train the entire network with the exception of the pi layer. The direct effect of an unlabelled data point is thus to train the entire anomaly detector (the VAE itself) and the majority of the classification network.
The third term only contributes to the loss when a label is present. Its effect is twofold. It teaches the pi layer to translate the compressed representation of the data into a class label, and it also forces the rest of the network to learn a better representation. Suppose the first two terms are low, i.e. the VAE thinks it has found a good representation, but the predicted labels are wrong. The third term then keeps the total loss high and pushes the network to keep searching. The loss function therefore improves all parts of the network with all types of data. Excluding the pi layer, the entire classifier is trained by unlabelled data, whilst the entire network is improved by labelled data. The result is a model that benefits from any and all available data.
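The claim that only labelled examples drive the third term can be checked directly. The check uses the analytic gradient of the softmax cross entropy with respect to the pi layer’s logits. Below is a minimal NumPy sketch; the names are illustrative.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()


def class_term(logits, pi):
    """L_cls = -sum_k pi_k * log softmax(logits)_k."""
    return -np.sum(pi * np.log(softmax(logits)))


def class_term_grad(logits, pi):
    """Analytic gradient of class_term with respect to the logits.

    d/dz_j = r_j * sum(pi) - pi_j, which reduces to r - pi for a one-hot
    label and is exactly zero for an all-zero (unlabelled) pi.
    """
    return softmax(logits) * pi.sum() - pi


logits = np.array([2.0, 0.5, -1.0])
labelled = class_term_grad(logits, np.array([0.0, 1.0, 0.0]))
unlabelled = class_term_grad(logits, np.zeros(3))  # all zeros
```

Backpropagation passes this gradient down from the third term. A labelled example is therefore the only kind that sends a class signal through the pi layer into the shared layers.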
It turns out, as we shall see in more detail later on, that this architecture outperforms equivalent supervised models. In a standard classifier, the loss function is designed to achieve a single task – update the network weights such that the correct activation is triggered in the output layer. In our model, the loss is performing two distinct tasks: achieve the correct classification, but also reconstruct the original input as closely as possible.
PURE SUPERVISED APPLICATION
Consider a situation in which every data point in the training dataset is labelled, and we are interested in building a classifier. In this context, we would usually build a neural network consisting only of the red-colored parts of Fig6.
What we tried instead was to still use the full Variational Autoencoder network, including the pi layer, and train it end-to-end on the fully labelled training dataset. One can argue that the decoder (and latent layer) act as a sort of regularizer in this context (see Fig7). The model not only has to produce the correct label, but also has to reproduce the input. It therefore has to build a better representation of the input data in its deeper layers, from which both the pi layer and the decoder feed.
Suppose we train the network this way and then, for inference, cut out the regularizer piece as in Fig6. The remaining classifier outperforms an equivalent classification model trained in the standard way (i.e. only to produce the correct label outputs). This indicates that the decoder does indeed behave as a regularizer, helping the classifier find a better local minimum.
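Once training is done, dropping the regularizer amounts to keeping only the weights on the classifier path. During fully supervised training every pi row is one-hot, so all three loss terms are active for every example. The sketch below assumes a dictionary of NumPy weight matrices with hypothetical names and toy sizes, with biases omitted for brevity.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical weights standing in for a jointly trained model (biases omitted).
# The decoder takes the latent sample and the class label, as in M2.
full_model = {
    "trunk_w": rng.normal(size=(8, 5)),
    "pi_w": rng.normal(size=(5, 3)),
    "mu_w": rng.normal(size=(5, 2)),
    "log_sigma_w": rng.normal(size=(5, 2)),
    "decoder_w": rng.normal(size=(2 + 3, 8)),
}

# For inference, keep only the classifier path (shared trunk + pi layer);
# the latent layer and the decoder only acted as a regularizer during training.
classifier = {k: full_model[k] for k in ("trunk_w", "pi_w")}


def predict(params, x):
    h = np.maximum(x @ params["trunk_w"], 0.0)     # shared ReLU layer
    return np.argmax(h @ params["pi_w"], axis=-1)  # most likely class


x = rng.normal(size=(6, 8))
preds = predict(classifier, x)
```

The deployed classifier is no larger than a conventionally trained one. The extra cost is paid only at training time.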
ANOMALY DETECTION APPLICATION
In the previous sections, we saw how unlabelled data points aid the classification performance of a variational autoencoder. Perhaps the opposite is true as well: does the availability of labels aid tasks usually addressed by purely unsupervised systems, such as anomaly detection?
From a purely theoretical perspective, there is again good reason to expect this. Suppose we ask the model not only to reproduce the input but also to produce the correct label. The model is then once more forced to build better representations in its dense layers than a purely unsupervised model would. This in turn should also help the reconstruction (and hence the anomaly detection) task.
In a way, this is the opposite of what we did in the previous section. Instead of cutting out the decoder part after training, we now cut out the pi part (see Fig8).
We will see in the results section that the anomaly detection performance is indeed improved by the availability of labels in the training dataset.
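The next sketch assumes one common choice of anomaly score for a VAE: the reconstruction error plus the KL term (i.e. the negative evidence lower bound), computed once the pi layer has been cut away. The Bernoulli reconstruction loss, the function name and the toy inputs are assumptions made for illustration.

```python
import numpy as np

EPS = 1e-7


def anomaly_score(x, x_rec, mu, sigma):
    """Hypothetical anomaly score: reconstruction error plus KL, per example.

    Assumes inputs scaled to [0, 1] and uses a per-pixel Bernoulli
    (binary) cross entropy for the reconstruction error. The pi layer plays
    no part here: it was only needed while training.
    """
    rec = -np.sum(x * np.log(x_rec + EPS) + (1 - x) * np.log(1 - x_rec + EPS), axis=1)
    kl = -0.5 * np.sum(1.0 + np.log(sigma**2) - mu**2 - sigma**2, axis=1)
    return rec + kl


x = np.array([[0.0, 1.0, 1.0, 0.0]] * 2)
x_rec = np.array([[0.05, 0.95, 0.90, 0.10],   # close reconstruction: looks normal
                  [0.60, 0.30, 0.40, 0.70]])  # poor reconstruction: looks anomalous
mu, sigma = np.zeros((2, 3)), np.ones((2, 3))
scores = anomaly_score(x, x_rec, mu, sigma)
```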
We experimented with three different variations of this architecture:
- The simple dense architecture described above (in which both the encoder and decoder are built from plain fully connected layers)
- A variation with convolutional layers
- A recurrent variation with LSTM layers.
RESULTS & EXPERIMENTS
So, how do the models perform? For the static models we use the MNIST digit recognition dataset, and for the recurrent models the UCI-HAR (Human Activity Recognition) dataset. For each task, we also compare the results to those of an equivalent supervised or unsupervised model, to quantify the improvement offered by our architecture.
From here on, the models will be referred to as follows:
As we’re proposing a semi-supervised model, we’ll start by examining its performance in the area in which it is designed to excel: classification with limited labels in the training data.
Firstly, we’ll examine the static models, trained for 20 epochs on the MNIST dataset. We use all 60,000 available training images but remove the labels from all but a small subset. The equivalent supervised model therefore cannot train on the unlabelled images.
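The label removal can be sketched as follows. The sketch assumes the labelled subset is drawn uniformly at random rather than stratified by class. The function name and the random stand-in labels are illustrative only, and the value 1000 matches the second setting discussed below.

```python
import numpy as np


def mask_labels(labels, n_classes, n_labelled, rng):
    """One-hot encode the labels, then keep only n_labelled randomly chosen ones.

    Rows that lose their label become all-zero pi vectors, which switches off
    the classification term for those examples.
    """
    one_hot = np.eye(n_classes)[labels]
    keep = rng.choice(len(labels), size=n_labelled, replace=False)
    pi = np.zeros_like(one_hot)
    pi[keep] = one_hot[keep]
    return pi, keep


rng = np.random.default_rng(0)
labels = rng.integers(0, 10, size=60_000)  # stand-in for the MNIST training labels
pi, keep = mask_labels(labels, n_classes=10, n_labelled=1000, rng=rng)
# The semi-supervised model sees all 60,000 images with pi as the label input;
# the equivalent supervised baseline trains only on the images at `keep`.
```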
Unsurprisingly, both flavors of semi-supervised model drastically outperform their supervised counterparts. The simple dense model shows an impressive improvement of 6.3%, and the convolutional version an even bigger improvement of 13.9%. Notice, however, that EQ_S_D even outperforms EQ_S_C. This is unexpected, since convolutional models are generally much better suited to image tasks. The score here simply stems from the lack of labelled training data for the convolutional model, which requires many more samples to converge properly.
Increasing the number of available labels to 1000, we see the performance of EQ_S_C improve drastically, as the model now has enough images to converge. Even so, the semi-supervised SS_C still outperforms it by 2.7%. Interestingly, for the dense models, the gap between SS_D and EQ_S_D has increased to 10.1%.
With 100% of the labels available, the gap closes significantly, but the semi-supervised architectures still score slightly higher than their supervised equivalents. This is where the regularizing effect of the decoder becomes apparent. The requirement to reconstruct the original input during training forces the classification part of the model to learn a deeper, more generalized representation of the data. The effect is even more pronounced in terms of log loss.
The recurrent models were tested in the same way on the UCI-HAR data. Here the semi-supervised model showed an even larger improvement over its supervised counterpart than in the static case. For the following experiments, the models were trained for 40 epochs with varying label availability.
Again, the semi-supervised model clearly outperforms its supervised equivalent in every category. The gap at low label availability is considerably larger here than for the static models.
The anomaly detection task is set up as follows. Take a fully labelled dataset and choose one of the classes to treat as ‘anomalous’. Train an anomaly detector on a sample of the data points from the remaining ‘normal’ classes. At inference time, compare the anomaly scores the detector assigns to the ‘anomalous’ data points with those it assigns to a held-out set of ‘normal’ data points.
When we consider the VAE with pi layer, we additionally feed the label of each ‘normal’ data point into the model during training. For the dense and convolutional models, we again used the MNIST dataset for benchmarking.
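The protocol above can be sketched as follows, together with one common way of comparing the two sets of scores. That comparison uses the area under the ROC curve, which is our assumption here rather than something taken from the text. The function names, the anomalous class 7 and the stand-in labels are hypothetical.

```python
import numpy as np


def split_by_class(labels, anomalous_class, n_train, rng):
    """Indices for training (normal only), held-out normal, and anomalous points."""
    normal_idx = np.flatnonzero(labels != anomalous_class)
    anomalous_idx = np.flatnonzero(labels == anomalous_class)
    normal_idx = rng.permutation(normal_idx)
    return normal_idx[:n_train], normal_idx[n_train:], anomalous_idx


def auroc(scores_normal, scores_anomalous):
    """Probability that a random anomalous point scores above a random normal one.

    Ties count as one half. Pairwise comparison is O(n*m) in memory, which is
    fine for a few thousand points per set.
    """
    a = np.asarray(scores_anomalous, dtype=float)[:, None]
    n = np.asarray(scores_normal, dtype=float)[None, :]
    return float(np.mean((a > n) + 0.5 * (a == n)))


rng = np.random.default_rng(0)
labels = rng.integers(0, 10, size=1000)  # stand-in for real dataset labels
train_idx, heldout_idx, anomalous_idx = split_by_class(labels, 7, 600, rng)
```

An AUROC of 1.0 means every anomalous point scores above every held-out normal point, and 0.5 means the detector cannot tell the two sets apart.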
The results are as follows:
These results show that the model consuming labels outperforms the equivalent unsupervised model, which does not consume labels, in every case. This supports the value proposition.
A bonus feature of this model is data generation. As in a VAE, the data is represented as a probability distribution, so we can generate data with the decoder by randomly sampling from this distribution. As with the M2 model, the addition of the latent class label allows us to specify the class of data we want to generate. In the case of MNIST, for example, we can specify which digit we would like to generate. Another feature retained from the M2 model is the ability to separate the class of the data from its style. The latent layer is essentially a probabilistic representation of the data on which the model was trained. The region of the distribution we sample from therefore determines which styles in the training data the generated sample resembles. With MNIST, for example, sampling from the centre of the distribution generates the most common handwriting styles, which become less typical as we move towards the edges of the distribution. Put simply, it’s possible to generate any digit in any handwriting style (limited, of course, to what is available in the training set).
Fig9 shows an example of different styles of the digit two generated by the model.
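Class-conditional sampling and the class/style separation described above can be sketched with a stand-in decoder. The random weights, sizes and function names below are hypothetical, and a trained decoder would replace `decode`. The `scale` parameter is an illustrative way to sample nearer to or further from the centre of the prior.

```python
import numpy as np

rng = np.random.default_rng(0)
D_LAT, N_CLS, D_OUT = 16, 10, 784  # hypothetical sizes (28x28 images)

# Stand-in decoder weights; in practice these come from the trained model.
W_dec = rng.normal(0.0, 0.1, size=(D_LAT + N_CLS, D_OUT))
b_dec = np.zeros(D_OUT)


def decode(z, y_one_hot):
    """Decoder conditioned on both the style latent z and the class label."""
    logits = np.concatenate([z, y_one_hot], axis=-1) @ W_dec + b_dec
    return 1.0 / (1.0 + np.exp(-logits))  # pixel intensities in (0, 1)


def generate(digit, n, scale=1.0):
    """Draw n samples of one class; scale < 1 stays near the centre of the
    prior (more typical styles), scale > 1 reaches towards its edges."""
    z = scale * rng.standard_normal((n, D_LAT))
    y = np.tile(np.eye(N_CLS)[digit], (n, 1))
    return decode(z, y).reshape(n, 28, 28)


twos = generate(2, n=5)  # five different styles of the digit two

# Class/style separation: hold z (the style) fixed and change only the label.
z = rng.standard_normal((1, D_LAT))
same_style = [decode(z, np.eye(N_CLS)[[d]]) for d in range(N_CLS)]
```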