How to train your differentiable filter
This is a short overview of the following paper (Kloss, Martius, and Bohg 2021):
Kloss, Alina, Georg Martius, and Jeannette Bohg. 2021. “How to Train Your Differentiable Filter.” Autonomous Robots 45 (4): 561–78. https://doi.org/10.1007/s10514-021-09990-9.
Bayesian filtering for state estimation
Bayesian filters are the standard method for probabilistic state estimation. Common examples are (extended, unscented) Kalman filters and particle filters. These filters require a process model predicting how the state evolves over time, and an observation model relating a sensor value to the underlying state. (Thrun 2006) contains a great explanation of Bayesian filters, including Kalman and particle filters, in the context of robotics, which is relevant for this paper; for a more complete overview of Kalman filters, see (Anderson and Moore 2005).
The objective of a filter for state estimation is to estimate a latent state \(\mathbf{x}\) of a dynamical system at any time step \(t\) given an initial belief \(\mathrm{bel}(\mathbf{x}_0) = p(\mathbf{x}_0)\), a sequence of observations \(\mathbf{z}_{1\ldots t}\), and controls \(\mathbf{u}_{0\ldots t}\).
We make the Markov assumption: conditioned on the previous state and control, the current state is independent of the earlier history, and each observation depends only on the current state. The belief \(\mathrm{bel}(\mathbf{x}_t)\) then follows the recursion
\[ \begin{align*} \mathrm{bel}(\mathbf{x}_t) &= \eta p(\mathbf{z}_t | \mathbf{x}_t) \int p(\mathbf{x}_t | \mathbf{x}_{t-1}, \mathbf{u}_{t-1}) \mathrm{bel}(\mathbf{x}_{t-1}) d\mathbf{x}_{t-1}\\ &= \eta p(\mathbf{z}_t | \mathbf{x}_t) \overline{\mathrm{bel}}(\mathbf{x}_t), \end{align*} \]
where \(\eta\) is a normalization factor. Computing \(\overline{\mathrm{bel}}(\mathbf{x}_t)\) is the prediction step, and applying \(p(\mathbf{z}_t | \mathbf{x}_t)\) is the update step (or the observation step).
We model the dynamics of the system through a process model \(f\) and an observation model \(h\):
\[ \begin{align*} \mathbf{x}_t &= f(\mathbf{x}_{t-1}, \mathbf{u}_{t-1}, \mathbf{q}_{t-1})\\ \mathbf{z}_t &= h(\mathbf{x}_t, \mathbf{r}_t), \end{align*} \] where \(\mathbf{q}\) and \(\mathbf{r}\) are random variables representing process and observation noise, respectively.
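To make this concrete, when \(f\) and \(h\) are linear and the noise is Gaussian, the recursion above reduces to the classical Kalman filter. Here is a minimal sketch of the prediction and update steps in PyTorch (the paper's code uses TensorFlow; this version and its variable names are mine):

```python
import torch

def kf_predict(mu, Sigma, F, B, u, Q):
    """Prediction step: push the Gaussian belief through a linear process model."""
    mu_bar = F @ mu + B @ u                      # x_t = F x_{t-1} + B u_{t-1} + q
    Sigma_bar = F @ Sigma @ F.T + Q              # propagate uncertainty, add process noise
    return mu_bar, Sigma_bar

def kf_update(mu_bar, Sigma_bar, H, z, R):
    """Update step: correct the predicted belief with the observation z."""
    S = H @ Sigma_bar @ H.T + R                  # innovation covariance
    K = Sigma_bar @ H.T @ torch.linalg.inv(S)    # Kalman gain
    mu = mu_bar + K @ (z - H @ mu_bar)           # corrected mean
    Sigma = (torch.eye(mu_bar.shape[0]) - K @ H) @ Sigma_bar
    return mu, Sigma
```

Every operation above is differentiable, so a loss on the filtered means can be backpropagated through a whole sequence of predict/update calls; this is exactly what makes a filter "differentiable".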
Differentiable Bayesian filters
These models are often difficult to formulate and specify, especially when the application has complex dynamics, complicated noise, nonlinearities, or high-dimensional states or observations.
To improve this situation, the key idea is to learn these complex dynamics and noise models from data. Instead of spending hours in front of a blackboard deriving the equations, we could give a simple model a lot of data and learn the equations from it!
In the case of Bayesian filters, we have to define the process, observation, and noise models as parameterized functions (e.g. neural networks) and learn their parameters end-to-end, through the entire apparatus of the filter. To learn these parameters, we will use the simplest method: gradient descent. Our filter has to become differentiable.
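In code, "end-to-end" roughly means unrolling the filter over a sequence and backpropagating a loss on the state estimates through every step. A minimal sketch, assuming a hypothetical `filter_step` callable that bundles one prediction and one update with learnable modules inside (the data layout and the squared-error loss are placeholders):

```python
import torch

def train_epoch(filter_step, params, trajectories, lr=1e-3):
    """filter_step(mu, Sigma, u, z) -> (mu, Sigma) is any differentiable filter step."""
    opt = torch.optim.Adam(params, lr=lr)
    for mu0, Sigma0, us, zs, xs_true in trajectories:    # one sequence per item
        mu, Sigma = mu0, Sigma0
        loss = 0.0
        for u, z, x_true in zip(us, zs, xs_true):
            mu, Sigma = filter_step(mu, Sigma, u, z)     # unroll the filter in time
            loss = loss + torch.sum((x_true - mu) ** 2)  # e.g. squared tracking error
        loss = loss / len(xs_true)
        opt.zero_grad()
        loss.backward()      # gradients flow through the entire filter recursion
        opt.step()
```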
The paper shows that such differentiable filters (trained end-to-end) outperform unstructured LSTMs, and outperform standard filters where the process and observation models are fixed in advance (i.e. analytically derived or even trained separately in isolation).
In most applications, the process and observation noises are assumed to be uncorrelated Gaussians with zero mean and constant covariance (which is a hyperparameter of the filter). With end-to-end training, we can learn these noise parameters, and we can go even further and use heteroscedastic noise models, in which the noise depends on the state of the system and the applied control.
Learnable process and observation models
The process model \(f\) can be implemented as a simple feed-forward neural network. Importantly, this NN is trained to output the difference between the next and the current state (\(\mathbf{x}_{t+1} - \mathbf{x}_t\)). This ensures stable gradients and an easier initialization near the identity.
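A minimal sketch of such a residual process model (the layer sizes are arbitrary, not the paper's architecture):

```python
import torch
import torch.nn as nn

class ProcessModel(nn.Module):
    """Learned process model f: predicts the state *change*, not the next state."""
    def __init__(self, state_dim, control_dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim + control_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, state_dim),
        )

    def forward(self, x, u):
        delta = self.net(torch.cat([x, u], dim=-1))  # predicted x_{t+1} - x_t
        return x + delta                             # next state = current + residual
```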
For the observation model, we could do the same and model \(h\) as a generative neural network predicting the output of the sensors. However, the observation space is often high-dimensional, and such a network is difficult to train. Consequently, the authors use a discriminative neural network to reduce the dimensionality of the raw sensory output.
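For example, a small convolutional encoder can compress a raw image into a low-dimensional pseudo-observation that the update step can consume (a sketch; the architecture is illustrative, not the one used in the paper):

```python
import torch
import torch.nn as nn

class SensorEncoder(nn.Module):
    """Discriminative sensor model: raw image -> low-dimensional pseudo-observation."""
    def __init__(self, obs_dim):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, 5, stride=2), nn.ReLU(),
            nn.Conv2d(16, 32, 5, stride=2), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),          # global pooling to a 32-dim feature
        )
        self.head = nn.Linear(32, obs_dim)

    def forward(self, image):                 # image: (batch, 3, H, W)
        features = self.conv(image).flatten(1)
        return self.head(features)            # pseudo-observation z_t
```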
Learnable noise models
In the Gaussian case, we use neural networks to predict the covariance matrix of the noise processes. To ensure positive-definiteness, the network predicts an upper-triangular matrix \(\mathbf{L}_t\) and the noise covariance matrix is set to \(\mathbf{Q}_t = \mathbf{L}_t \mathbf{L}_t^T\).
In the heteroscedastic case, the noise covariance is predicted from the state and the control input.
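A sketch of such a noise model, here for heteroscedastic process noise; the constant-noise variant stores the raw parameters directly instead of predicting them from \((\mathbf{x}_t, \mathbf{u}_t)\). Keeping the diagonal of \(\mathbf{L}_t\) positive via a softplus is a common trick and an assumption on my part, not necessarily the paper's exact parameterization:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class HeteroscedasticNoise(nn.Module):
    """Predicts the process-noise covariance Q_t = L_t L_t^T from (x_t, u_t)."""
    def __init__(self, state_dim, control_dim, hidden=64):
        super().__init__()
        self.dim = state_dim
        n_off = state_dim * (state_dim - 1) // 2             # strictly upper-triangular entries
        self.net = nn.Sequential(
            nn.Linear(state_dim + control_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, state_dim + n_off),             # diagonal + off-diagonal of L_t
        )
        self.off_idx = torch.triu_indices(state_dim, state_dim, offset=1)

    def forward(self, x, u):
        raw = self.net(torch.cat([x, u], dim=-1))
        diag_raw, off_raw = raw[..., :self.dim], raw[..., self.dim:]
        L = x.new_zeros(*x.shape[:-1], self.dim, self.dim)
        L[..., self.off_idx[0], self.off_idx[1]] = off_raw    # fill the upper triangle
        # softplus keeps the diagonal positive, so Q_t is positive definite
        L = L + torch.diag_embed(F.softplus(diag_raw) + 1e-4)
        return L @ L.transpose(-1, -2)                        # Q_t
```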
Loss function
We assume that we have access to the ground-truth trajectory \(\mathbf{x}_{1\ldots T}\).
We can then use the mean squared error (MSE) between the ground truth and the mean of the belief:
\[ L_\mathrm{MSE} = \frac{1}{T} \sum_{t=1}^T (\mathbf{x}_t - \mathbf{\mu}_t)^T (\mathbf{x}_t - \mathbf{\mu}_t). \]
Alternatively, we can compute the negative log-likelihood of the true state under the belief distribution (represented by a Gaussian of mean \(\mathbf{\mu}_t\) and covariance \(\mathbf{\Sigma}_t\)):
\[ L_\mathrm{NLL} = \frac{1}{2T} \sum_{t=1}^T \left( \log(|\mathbf{\Sigma}_t|) + (\mathbf{x}_t - \mathbf{\mu}_t)^T \mathbf{\Sigma}_t^{-1} (\mathbf{x}_t - \mathbf{\mu}_t) \right). \]
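Both losses are straightforward to implement. A sketch, assuming the filter returns stacked means and covariances over a trajectory (the tensor layout is my assumption):

```python
import torch

def mse_loss(x_true, mu):
    """MSE between ground-truth states (T, d) and belief means (T, d)."""
    return ((x_true - mu) ** 2).sum(dim=-1).mean()

def nll_loss(x_true, mu, Sigma):
    """Gaussian negative log-likelihood of the true states under the beliefs.
    Sigma has shape (T, d, d); constant terms are dropped, as in the formula above."""
    err = (x_true - mu).unsqueeze(-1)                               # (T, d, 1)
    maha = (err.transpose(-1, -2) @ torch.linalg.solve(Sigma, err)).reshape(-1)
    return 0.5 * (torch.logdet(Sigma) + maha).mean()
```

Equivalently, one can take the negative mean of `torch.distributions.MultivariateNormal(mu, covariance_matrix=Sigma).log_prob(x_true)`, which differs only by the constant normalization term.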
Implementation issues
We need to implement the filters (EKF, UKF, PF) in a differentiable programming framework. The authors use TensorFlow. Their code is available on GitHub.
Some filters are easy to implement in such a framework because they only use differentiable operations (mostly simple linear algebra). For the EKF, we also need to compute Jacobians of the process and observation models. This can be done automatically via automatic differentiation, but the authors encountered technical difficulties with this (memory consumption and slow computation), so they recommend computing the Jacobians manually. It is not clear whether this is a limitation of automatic differentiation or of their specific implementation in TensorFlow: other projects have successfully computed Jacobians for EKFs with autodiff libraries, such as GaussianFilters.jl in Julia.
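For reference, this is roughly what autodiff Jacobians for the EKF linearization look like in PyTorch (again an illustration, not the authors' TensorFlow code):

```python
import torch
from torch.autograd.functional import jacobian

def ekf_jacobians(process_model, obs_model, x, u):
    """Linearize the learned models around the current state for the EKF.
    process_model: (x, u) -> next state; obs_model: x -> expected observation."""
    F = jacobian(lambda x_: process_model(x_, u), x)   # (d, d) Jacobian of f w.r.t. x
    H = jacobian(obs_model, x)                         # (m, d) Jacobian of h w.r.t. x
    return F, H
```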
The particle filter has a resampling step that is not differentiable: the gradient cannot be propagated to particles that are not selected during resampling. There are, however, specific resampling schemes that reportedly help mitigate this issue in practice during training.
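One such scheme, used in prior work on differentiable particle filters, is "soft resampling": sample indices from a mixture of the particle weights and a uniform distribution, then correct with importance weights, so that the new weights (and hence the gradient) still depend on the old ones. A sketch, not necessarily the exact variant considered in the paper:

```python
import torch

def soft_resample(particles, weights, alpha=0.5):
    """Soft resampling: mix the weights with a uniform distribution (alpha < 1)
    and apply an importance correction, keeping gradients w.r.t. the old weights."""
    n = weights.shape[0]
    q = alpha * weights + (1 - alpha) / n          # proposal distribution over indices
    idx = torch.multinomial(q, n, replacement=True)
    new_particles = particles[idx]
    new_weights = weights[idx] / q[idx]            # importance correction
    return new_particles, new_weights / new_weights.sum()
```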
Conclusions
Differentiable filters achieve better results with fewer parameters than unstructured models like LSTMs, especially on complex tasks. The paper runs extensive experiments on toy models of varying complexity, although unfortunately no real-world application is shown.
Noise models with full covariance improve the tracking accuracy. Heteroscedastic noise models improve it even more.
The main difficulty is keeping the training stable. The authors recommend the differentiable extended Kalman filter for getting started, as it is the simplest filter and is less sensitive to hyperparameter choices. If the task is strongly nonlinear, one should use a differentiable unscented Kalman filter or a differentiable particle filter.