Bayes by Backprop (BBB): From Robust Neural Networks to a Unified Theory of the Brain?
In this post, we review a key algorithm, Bayes by Backprop (BBB), which allows us to perform free energy minimization in practice. It demonstrates how to do variational inference in large models and is a key step towards building "safe" and "robust" mission-critical machine learning systems. Beyond machine learning, this method can also be applied to empirically evaluate computational models of the brain based on the free energy framework, which has been applied in many different areas, including the notion of false inference.
This blog post was inspired by a talk from Prof Karl Friston, who presented a very compelling slide bridging pioneering work in machine learning and generative models of the brain, with a focus on free energy minimization.
Rather than Friston's theory-driven, top-down approach, we start with a practical problem.
Neural networks (NNs) are highly susceptible to adversarial attacks. Let's say we have trained the best convolutional NN to classify images. If we add a few noisy pixels to one of the training examples and pass it through the network, the trained network produces a completely different output with a high degree of confidence. The following picture demonstrates the issue: both images look like pandas, yet with the addition of noise our network classifies the image as a "gibbon". Clearly, this isn't an issue with our vision or perception of the image; the panda is still a panda. The same holds for other variants of deep NNs, as demonstrated here. If you would like to read more about the subject, please see this blog.
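To make the attack concrete, here is a minimal sketch of the Fast Gradient Sign Method (FGSM) from Goodfellow et al., applied to a toy logistic "classifier". The weights and input below are made up purely for illustration; real attacks operate on deep networks, but the principle is identical: nudge each input feature by ε in the direction that increases the loss.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def predict(weights, x):
    # Toy "classifier": logistic regression over raw pixel values.
    return sigmoid(sum(w * xi for w, xi in zip(weights, x)))

def fgsm_perturb(weights, x, y, epsilon):
    # FGSM: shift every input feature by epsilon in the direction that
    # increases the loss. For logistic loss, dL/dx_i = (p - y) * w_i,
    # and only the *sign* of that gradient is used.
    p = predict(weights, x)
    return [xi + epsilon * math.copysign(1.0, (p - y) * w)
            for w, xi in zip(weights, x)]

# Hypothetical weights and "image", chosen purely for illustration.
weights = [2.0, -1.0, 0.5, 1.5]
x = [0.4, 0.2, 0.9, 0.3]
x_adv = fgsm_perturb(weights, x, y=1.0, epsilon=0.25)
```

Even this tiny per-pixel shift measurably lowers the model's confidence in the true class; in deep networks the effect is strong enough to flip a confident "panda" into a "gibbon".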
So why is this an issue? In our learning process, we learn the weights (w) of our network. We start with a random set of weights. During training, the output of our NN, f(w), is updated constantly by feeding data to the network and back-propagating the error/loss to adjust the weights w until we reach a satisfactory level of performance. Data splitting and regularisation techniques are typically employed to maintain model performance on new examples. At this point, we obtain a set of weights w, also known as point estimates of w, which are the best estimates for our current dataset.
There are two situations which lead to uncertainty in this training process. First, we may be uncertain about the model parameters, i.e. the weights (w) given the data D, leading to epistemic uncertainty. Employing more data generally helps us reduce epistemic uncertainty. This is the focus of this discussion.
The second source of uncertainty is aleatoric uncertainty: the noise inherent in the data itself, which cannot be reduced by collecting more data.
Solutions: There are two ways of approaching issues related to adversarial attacks.
GAN: One of the proposed solutions is adversarial training. We start by sampling from a noise distribution; the noise samples are fed through a generator NN (G) to generate synthetic data. A discriminator NN (D) is employed to discriminate between the generated synthetic data and real (training) data. The adversarial loss function (V) for this training is
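In the original formulation of Goodfellow et al. (2014), V is a minimax objective over real data x drawn from the data distribution and noise z drawn from a prior:

```latex
\min_G \max_D V(D, G) =
\mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
+ \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```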
We define the adversarial loss V via a zero-sum game: G wants to minimize V, i.e. increase its ability to fool D, while D wants to maximize V, i.e. its ability to differentiate between real and synthetic images.
This training process is designed to ensure the distribution of the generated samples and that of our training data are as close as possible. This closeness can be quantified using the Kullback-Leibler (KL) divergence. In practice, training such a network is much harder: one needs to account for issues such as mode collapse, gradient instability, model scaling, etc. Jonathan Hui provides an excellent overview of GANs in his blog series. However, it is important to note that the theoretical underpinnings of GANs are also not well understood. So let's discuss a more theoretically grounded approach, variational inference, which dates back to Feynman.
Variational inference: In addition to being difficult to train, adversarial methods provide point estimates of the weights w. Whilst adding "noise" was a rather simplistic discussion, what happens if our data changes over time, e.g. in healthcare, finance, or engineering? Such changes are common in the real world. Do we train the network again? What if the data is changing constantly, as Thomas Wiecki explains with an elegant example in his blog post? What we need is the ability to select/sample a set of weights that best describes our stochastic data.
Thus, instead of using a point estimate, or a single set of w, we need to obtain the probability distribution of the weights w given our data, i.e. P(w|D). We can then draw w from the distribution P(w|D), i.e. the set of weights which best describes our data. This is where we can apply the classical Bayes theorem; in the Bayesian framework the conditional probability P(w|D) is known as the posterior.
Bayes theorem:
Posterior P(H|E) = (Prior P(H) × Likelihood P(E|H)) / Evidence P(E)
where H is some hypothesis and E is the evidence. In our case, H is the weights w and the evidence is our data D.
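Written out for our setting, with a prior P(w) over the weights and data D, this reads:

```latex
P(w \mid D) = \frac{P(D \mid w)\, P(w)}{P(D)},
\qquad
P(D) = \int P(D \mid w)\, P(w)\, dw
```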
Here the denominator is the integral of the likelihood over all possible values of the weights under the prior. With millions of weights in a typical deep NN, this integral is intractable. It is also known as the marginal likelihood, P(D). The classical Bayesian approach is therefore impractical in the real world, as it is difficult to compute the posterior efficiently.
Therefore the posterior P(w|D) needs to be estimated. Sampling methods can be employed for small models, but variational inference (VI) is a tool of growing importance. In VI we employ another distribution q(w|θ), known as the variational posterior, parameterized by θ. The parameters θ are obtained by minimizing the KL divergence between q(w|θ) and the true posterior P(w|D) using an optimization process. This yields θ*, and hence q(w|θ*), our best estimate of the posterior P(w|D).
The resulting cost/loss function is known as the variational free energy, the negative of the evidence lower bound (ELBO).
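In the standard form (e.g. Blundell et al. 2015), the variational free energy combines a complexity (KL) term and an expected negative log-likelihood:

```latex
\mathcal{F}(D, \theta) =
\mathrm{KL}\big[\, q(w \mid \theta) \,\|\, P(w) \,\big]
- \mathbb{E}_{q(w \mid \theta)}\big[\log P(D \mid w)\big]
```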
This KL divergence between q(w|θ) and the true posterior is itself intractable, due to the integral term it contains, which is as follows.
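Concretely, the KL divergence integrates over all weights and contains the intractable true posterior P(w|D):

```latex
\mathrm{KL}\big[\, q(w \mid \theta) \,\|\, P(w \mid D) \,\big] =
\int q(w \mid \theta)\, \log \frac{q(w \mid \theta)}{P(w \mid D)}\, dw
```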
So we approximate again: the expectation under q(w|θ) is itself approximated.
Bayes by Backprop (BBB): Approximating q(w|θ) is a computational bottleneck for the larger models encountered in practice. Blundell et al. (2015) addressed this issue by applying the reparameterization trick from Kingma et al. (2014) to obtain the variational posterior q(w|θ) via stochastic gradient descent and a sampling process. This is the highlight of the BBB algorithm.
Reparameterisation trick: The reparameterization trick lets us sample weights from the variational posterior while keeping the optimization differentiable. To simplify our discussion, we use the normal distribution, where θ consists of a mean (μ) and standard deviation (σ), with σ forming the diagonal elements of the covariance matrix. The "trick" is to sample from a parameter-free noise distribution ϵ and then transform the sampled ϵ into a weight with a deterministic function of θ. We can then learn μ and σ (and hence θ) using gradient updates; the key equations along with a visualization are shown below. In practice, we parameterize σ through a variable ρ, via σ = log(1 + exp(ρ)), to keep σ positive and help training, but that is an implementation detail we can skip.
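A minimal sketch of this sampling step in plain Python, assuming a fully factorized Gaussian posterior with the softplus parameterization σ = log(1 + exp(ρ)) mentioned above (the values of μ and ρ below are arbitrary, for illustration only):

```python
import math
import random

def softplus(rho):
    # sigma = log(1 + exp(rho)) keeps the standard deviation positive.
    return math.log1p(math.exp(rho))

def sample_weight(mu, rho, rng):
    # Reparameterization: w = mu + sigma * eps, with eps ~ N(0, 1).
    # The randomness lives entirely in eps, so mu and rho can be
    # updated by ordinary gradient descent.
    eps = rng.gauss(0.0, 1.0)  # parameter-free noise
    return mu + softplus(rho) * eps

rng = random.Random(0)
w = sample_weight(0.0, -3.0, rng)  # one weight drawn from q(w|theta)
```

Because the noise ϵ is drawn independently of μ and ρ, gradients of any loss evaluated at w flow back through the deterministic transform into the variational parameters.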
This enables us to compute the variational free energy using a sampling process and stochastic gradient descent, giving the following equation for the loss function, or variational free energy. Notice the integral has been replaced by a sum.
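With n weight samples w⁽ⁱ⁾ drawn from q(w|θ) via the reparameterization trick, the Monte Carlo estimate from Blundell et al. (2015) is:

```latex
\mathcal{F}(D, \theta) \approx
\sum_{i=1}^{n} \Big[ \log q(w^{(i)} \mid \theta)
- \log P(w^{(i)})
- \log P(D \mid w^{(i)}) \Big],
\qquad w^{(i)} \sim q(w \mid \theta)
```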
The above loss function is readily implemented in modern automatic-differentiation frameworks like TensorFlow. One can define a prior over the weights, p(w), and obtain the variational posterior q(w|θ) by sampling and using this loss function during training. Please read this article for LSTMs.
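As a sketch of what such a loss looks like, here is the Monte Carlo estimate in plain Python for a hypothetical one-weight model y ~ N(w·x, 1), with prior w ~ N(0, 1) and a Gaussian variational posterior (a real implementation would use TensorFlow tensors so gradients flow through μ and ρ; the averaging over samples differs from Blundell et al.'s sum only by a constant factor):

```python
import math
import random

def log_gauss(x, mu, sigma):
    # Log density of a univariate normal N(mu, sigma^2).
    return (-0.5 * math.log(2 * math.pi * sigma ** 2)
            - (x - mu) ** 2 / (2 * sigma ** 2))

def free_energy(xs, ys, mu, rho, n_samples=200, seed=0):
    # Monte Carlo estimate of F(D, theta) for a toy model y ~ N(w*x, 1),
    # prior w ~ N(0, 1), variational posterior q(w|theta) = N(mu, sigma^2).
    rng = random.Random(seed)
    sigma = math.log1p(math.exp(rho))  # softplus keeps sigma > 0
    total = 0.0
    for _ in range(n_samples):
        eps = rng.gauss(0.0, 1.0)      # reparameterization trick
        w = mu + sigma * eps           # sample w ~ q(w|theta)
        log_q = log_gauss(w, mu, sigma)        # variational posterior
        log_prior = log_gauss(w, 0.0, 1.0)     # prior P(w)
        log_lik = sum(log_gauss(y, w * x, 1.0) # likelihood P(D|w)
                      for x, y in zip(xs, ys))
        total += log_q - log_prior - log_lik
    return total / n_samples           # averaged over samples

xs = [0.0, 1.0, 2.0, 3.0]
ys = [0.5 * x for x in xs]             # data from a "true" weight of 0.5
f = free_energy(xs, ys, mu=0.5, rho=-3.0)
```

Because the samples are written as w = μ + σ·ϵ, the same expression is differentiable with respect to μ and ρ in an autodiff framework, which is exactly what BBB exploits.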
Importance of BBB: BBB provides a practical way to minimize variational free energy. Let's dig a little deeper into the variational free energy F(D, θ):
- The KL term is a complexity term, dependent on the model parameters (i.e. the weights) and the variational posterior
- The second term is the negative log-likelihood of the data under the model; minimizing it is equivalent to minimizing surprise, or maximizing data-dependent accuracy, and thereby reduces epistemic uncertainty.
By minimizing the variational free energy, the variational distribution q(w|θ) gets closer to the true posterior P(w|D); if they match, the KL term becomes zero and the variational free energy exactly equals the surprise. By further optimizing the parameters of the prior, we can reduce surprise even more. The KL term thus accounts for model complexity and acts as a regularizer.
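The equivalence between the two views follows from Bayes' theorem: up to the constant −log P(D) (the surprise), minimizing the free energy is the same as minimizing the KL to the true posterior, so F can never drop below the surprise:

```latex
\mathcal{F}(D, \theta) =
\mathrm{KL}\big[\, q(w \mid \theta) \,\|\, P(w \mid D) \,\big]
- \log P(D)
\;\ge\; -\log P(D)
```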
The theoretical foundations of variational free energy date back to the Bayesian brain hypothesis, initially proposed by Helmholtz and later formalized by Hinton. Tools like TensorFlow provide the basic building blocks (e.g. dense and convolutional layers) and make it feasible to program a range of probabilistic or Helmholtz machines in a few lines of code.
Hinton influenced Karl Friston at UCL, who developed the free energy framework, a principled, top-down approach to thinking about the brain. In this framework, the brain is considered a hierarchical dynamic model (HDM), minimizing global free energy for active inference. Friston reasons that the HDM formalism accounts for multiple models, each replicating a different brain process such as perception, action, learning, or attention. These models are (conditionally) independent, mathematically formalized using the notion of Markov blankets. Under simplistic assumptions, the brain, as an HDM, is constantly minimizing the free energies of the various models within Markov blankets.
BBB provides a means to minimize variational free energy, which is central to these models, and thus potentially permits their empirical validation. The modeling, though, is much more complex, as it attempts to encapsulate various areas of machine learning under one umbrella, i.e. vision, attention, feedback (reinforcement learning), and reconstruction, via a global loss function.
The process of active inference and its relevance to the free energy framework is explained in this blog post with limited math. I would also suggest consulting this Twitter account, which provides a humorous take on how dense and inaccessible this material can be. Models based on this framework are being employed to investigate false inference in psychiatric disorders, e.g. hallucinations. A more practical application is in AI safety, to avoid hallucinations in powerful generative models, which are the future of AI.
Figure 4: The free-energy principle: a unified brain theory? Nature Reviews Neuroscience, 2010
Closing thoughts: BBB was originally designed to leverage uncertainty to improve the performance of reinforcement learning algorithms. The algorithm provides a powerful framework to compute an important loss function, the variational free energy, at scale using stochastic gradient descent. It does so by sampling weights (w) from the variational posterior and minimizing the variational free energy against a prior over the weights. Minimizing variational free energy ensures that we are reducing epistemic uncertainty, which makes models more robust. BBB can be applied to any kind of model: convnets, RNNs, reinforcement learning, graphical models, and even adversarial approaches. BBB does not restrict the choice of prior and can be deployed by leveraging large-scale computational power (GPUs). This makes BBB a very generic tool for cutting-edge machine learning and neuroscience research.