In this post, we connect Energy-Based Models with classical optimal control frameworks such as Model Predictive Control, from the perspective of Lagrangian optimization.
This post is part of a series on energy-based learning and optimal control. A recommended reading order is:
- Notes on “The Energy-Based Learning Model” by Yann LeCun, 2021
- Learning Data Distribution Via Gradient Estimation
- From MPC to Energy-Based Policy (this post)
- How Would Diffusion Model Help Robot Imitation
- Causality hidden in EBM
Review of EKF and MPC
Consider a state-space model with process noise and measurement noise:
$$
\left\{
\begin{aligned}
x[k+1] &= f(x[k], u[k]) + w[k] \\
y[k] &= h(x[k]) + v[k]
\end{aligned}
\right.
$$
with state $x[k]$, control input $u[k]$, measurement $y[k]$, process noise $w[k] \sim \mathcal N(0, Q_k)$, and measurement noise $v[k] \sim \mathcal N(0, R_k)$.
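As a concrete instance (an illustrative example, not from the original post), a forward-Euler pendulum fits this template; the names `f`, `h`, and `dt` below are assumptions:

```python
import numpy as np

dt = 0.05  # sampling period (illustrative)

def f(x, u):
    """State transition: x = (angle theta, angular velocity omega)."""
    theta, omega = x
    return np.array([theta + dt * omega,
                     omega + dt * (-9.81 * np.sin(theta) + u)])

def h(x):
    """Measurement: only the angle is observed."""
    return np.array([x[0]])
```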
To perform optimal control on such a system, we first predict and update our state estimate with the Extended Kalman Filter (a code sketch follows the list):
- Prediction Step
  - Priori estimation: $\hat{x}^-[k] = f(\hat{x}[k-1], u[k-1])$
  - Jacobian of the state transition function: $F_k = \left.\frac{\partial f}{\partial x}\right|_{\hat{x}[k-1],\, u[k-1]}$
  - Error covariance prediction: $P^-[k] = F_k P[k-1] F_k^{\top} + Q_k$
- Update Step
  - Jacobian of the measurement function: $H_k = \left.\frac{\partial h}{\partial x}\right|_{\hat{x}^-[k]}$
  - Kalman gain: $K_k = P^-[k] H_k^{\top} \left( H_k P^-[k] H_k^{\top} + R_k \right)^{-1}$
  - Posterior estimation: $\hat{x}[k] = \hat{x}^-[k] + K_k \left( y[k] - h(\hat{x}^-[k]) \right)$
  - Error covariance update: $P[k] = (I - K_k H_k) P^-[k]$
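A minimal NumPy sketch of one predict/update cycle, assuming the user supplies `f`, `h`, and their Jacobians (`ekf_step`, `F_jac`, and `H_jac` are hypothetical names):

```python
import numpy as np

def ekf_step(x_hat, P, u, y, f, h, F_jac, H_jac, Q_k, R_k):
    """One EKF predict/update cycle; f, h, F_jac, H_jac are user-supplied callables."""
    # Prediction step
    x_prior = f(x_hat, u)                               # priori estimation
    F = F_jac(x_hat, u)                                 # Jacobian of f at the last estimate
    P_prior = F @ P @ F.T + Q_k                         # error covariance prediction

    # Update step
    H = H_jac(x_prior)                                  # Jacobian of h at the prior
    K = P_prior @ H.T @ np.linalg.inv(H @ P_prior @ H.T + R_k)  # Kalman gain
    x_post = x_prior + K @ (y - h(x_prior))             # posterior estimation
    P_post = (np.eye(len(x_hat)) - K @ H) @ P_prior     # error covariance update
    return x_post, P_post
```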
With the estimated state $\hat{x}[k]$ we can perform Model Predictive Control by solving the following optimization problem at each time step:
$$
\begin{aligned}
\min_{\mathbf{u}} \quad J &= \sum_{i=0}^{N-1} \ell(x[k+i], u[k+i]) + \ell_f(x[k+N]) \\
\ell(x, u) &= (x - x_{\text{ref}})^{\top} Q (x - x_{\text{ref}}) + u^{\top} R u \\
\text{s.t.} \quad & x[k+i+1] = f(x[k+i], u[k+i]) \quad \forall i = 0, \dots, N-1 \\
& x[k] = \hat{x}[k] \\
& x_{\text{min}} \leq x[k+i] \leq x_{\text{max}} \quad \forall i \\
& u_{\text{min}} \leq u[k+i] \leq u_{\text{max}} \quad \forall i
\end{aligned}
$$
$\mathbf{u}$: control input sequence. $N$: prediction horizon. $\ell$: stage cost function. $\ell_f$: terminal cost function. $x_{\text{min}}, x_{\text{max}}$: state constraints. $u_{\text{min}}, u_{\text{max}}$: control constraints. A single-shooting sketch of this problem follows.
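The sketch below optimizes the control sequence with `scipy.optimize.minimize`; it enforces only the control bounds and reuses $Q$ as the terminal weight, leaving state constraints to a penalty term if needed (`mpc_action` and its argument names are assumptions):

```python
import numpy as np
from scipy.optimize import minimize

def mpc_action(x_hat, f, x_ref, Q, R, N, m, u_min, u_max):
    """Single-shooting MPC sketch: optimize u[k..k+N-1], return only u[k]."""
    def cost(u_flat):
        u_seq = u_flat.reshape(N, m)
        x, J = x_hat, 0.0
        for i in range(N):
            dx = x - x_ref
            J += dx @ Q @ dx + u_seq[i] @ R @ u_seq[i]  # stage cost ell
            x = f(x, u_seq[i])                          # roll the model forward
        dx = x - x_ref
        return J + dx @ Q @ dx                          # terminal cost (reusing Q)

    bounds = [(u_min, u_max)] * (N * m)                 # control constraints
    res = minimize(cost, np.zeros(N * m), bounds=bounds)
    return res.x.reshape(N, m)[0]                       # receding horizon: first input
```

In closed loop, only this first input is applied; the problem is re-solved at the next step from the latest EKF estimate $\hat{x}[k]$.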
Introducing the Energy-Based Model
We can replace the state transition function $f$ with a learned energy-based model $E_\theta(\mathbf x)$, which assigns low energy to plausible states and transitions.
Supervised Training
Since the model density is defined by the energy as $p_\theta(\mathbf x) \propto e^{-E_\theta(\mathbf x)}$, we have the score function as:
$$
s_\theta(\mathbf x) = \nabla_{\mathbf x} \log p_\theta(\mathbf x) = -\nabla_{\mathbf x} E_\theta(\mathbf x)
$$
- Dataset: a set of transitions $\{\mathbf x_i\}_{i=1}^{n}$ drawn from $p_{\text{data}}$ under the "independently and identically distributed (i.i.d.)" assumption.
- Training objective: minimize the difference between the data landscape $s_{\text{data}}(\mathbf x)$ and the model landscape $s_\theta(\mathbf x)$. The objective function is defined as follows, where the loss $\mathcal L$ is commonly the MSE loss:
$$
\begin{aligned}
J(\theta)&=\frac{1}{2}\mathbb{E}_{p_{\text{data}}}\left[\mathcal L\left(s_{\text{data}}(\mathbf x), s_{\theta}(\mathbf x)\right)\right] \\
&=\frac{1}{2}\mathbb{E}_{p_{\text{data}}}\left[\left\| s_{\text{data}}(\mathbf x)- s_{\theta}(\mathbf x)\right\|^2\right]
\end{aligned}
$$
However, we cannot access the full data distribution, so $s_{\text{data}}$ is not directly available. Following Hyvärinen (2005)[1], we can rewrite the objective so that it no longer depends on $s_{\text{data}}$ (see Appendix A of the original paper):
$$
\begin{aligned}
J(\theta)&=\frac{1}{2}\mathbb{E}_{p_{\text{data}}}\left[ \left\| s_\theta(\mathbf x) \right\|^2 - 2 s_\theta(\mathbf x)^\top s_{\text{data}}(\mathbf x) + \left\| s_{\text{data}}(\mathbf x) \right\|^2 \right] \\
&=\frac{1}{2} \mathbb{E}_{p_{\text{data}}} \left[ \left\| s_\theta(\mathbf x) \right\|^2 \right] - \mathbb{E}_{p_{\text{data}}} \left[ s_\theta(\mathbf x)^\top s_{\text{data}}(\mathbf x) \right] + \overbrace{\frac{1}{2} \mathbb{E}_{p_{\text{data}}} \left[ \left\| s_{\text{data}}(\mathbf x) \right\|^2 \right]}^{\text{constant}} \\
J'(\theta)&=\frac{1}{2} \mathbb{E}_{p_{\text{data}}} \left[ \left\| s_\theta(\mathbf x) \right\|^2 \right] - \mathbb{E}_{p_{\text{data}}} \left[ s_\theta(\mathbf x)^\top s_{\text{data}}(\mathbf x) \right]
\end{aligned}
$$
The only term involving $s_{\text{data}}$ is the cross term; since $p_{\text{data}}(\mathbf x)\, s_{\text{data}}(\mathbf x) = \nabla_{\mathbf x} p_{\text{data}}(\mathbf x)$, it expands as
$$
\begin{aligned}
\mathbb{E}_{p_{\text{data}}} \left[ s_\theta(\mathbf x)^\top s_{\text{data}}(\mathbf x) \right] &= \int p_{\text{data}}(\mathbf x)\, s_\theta(\mathbf x)^\top s_{\text{data}}(\mathbf x)\, d\mathbf x \\
&= \int s_\theta(\mathbf x)^\top \nabla_{\mathbf x} p_{\text{data}}(\mathbf x)\, d\mathbf x
\end{aligned}
$$
Integrating by parts (and assuming $p_{\text{data}}(\mathbf x)\, s_\theta(\mathbf x)$ vanishes at infinity), we move the derivative from $p_{\text{data}}$ onto $s_\theta$:
$$
\begin{aligned}
\int s_\theta(\mathbf x)^\top \nabla_{\mathbf x} p_{\text{data}}(\mathbf x)\, d\mathbf x &= -\int p_{\text{data}}(\mathbf x)\, \text{div}_{\mathbf x}\, s_\theta(\mathbf x)\, d\mathbf x \\
\text{div}_{\mathbf x}\, s_\theta(\mathbf x) &= \text{div}_{\mathbf x} \left( -\nabla_{\mathbf x} E_\theta(\mathbf x) \right) = -\Delta_{\mathbf x} E_\theta(\mathbf x) \\
\mathbb{E}_{p_{\text{data}}} \left[ s_\theta(\mathbf x)^\top s_{\text{data}}(\mathbf x) \right] &= -\mathbb{E}_{p_{\text{data}}} \left[ \text{div}_{\mathbf x}\, s_\theta(\mathbf x) \right] = \mathbb{E}_{p_{\text{data}}} \left[ \Delta_{\mathbf x} E_\theta(\mathbf x) \right]
\end{aligned}
$$
Note that the cross term enters $J'(\theta)$ with a negative sign, which is why the Laplacian of the energy is subtracted. Substituting back, and writing the per-sample objective for a state $\mathbf x[k]$:
$$
\begin{aligned}
J(\theta) &= \frac{1}{2} \mathbb{E}_{p_{\text{data}}} \left[ \left\| \nabla_{\mathbf x} E_\theta(\mathbf x) \right\|^2 \right] - \mathbb{E}_{p_{\text{data}}} \left[ \Delta_{\mathbf x} E_\theta(\mathbf x) \right] \\
J_k(\theta) &= \frac{1}{2} \left\| \nabla_{\mathbf x[k]} E_\theta(\mathbf x[k]) \right\|^2 - \text{Tr} \left( \nabla_{\mathbf x[k]}^2 E_\theta(\mathbf x[k]) \right)
\end{aligned}
$$
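As a sketch of how $J_k$ can be implemented, the trace term can be estimated with Hutchinson's trick instead of building the full Hessian. Here `E` is any scalar-output energy network and `x` a batch of flattened samples (both names are assumptions):

```python
import torch

def ism_loss(E, x):
    """Implicit score matching loss J = 0.5*||grad E||^2 - Tr(Hessian of E),
    with the trace estimated by Hutchinson's trick."""
    x = x.requires_grad_(True)
    grad = torch.autograd.grad(E(x).sum(), x, create_graph=True)[0]      # grad_x E
    v = torch.randn_like(x)                                              # probe with E[v v^T] = I
    Hv = torch.autograd.grad((grad * v).sum(), x, create_graph=True)[0]  # Hessian-vector product
    trace_est = (v * Hv).flatten(1).sum(dim=1)                           # approx Laplacian of E
    sq_norm = 0.5 * grad.flatten(1).pow(2).sum(dim=1)
    return (sq_norm - trace_est).mean()
```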
If you haven’t seen this formulation in diffusion models and it feels strange:
- There, the training objective is to learn the distribution of the perturbed sample $\tilde{\mathbf x}$ given $\mathbf x$, which is a known Gaussian since the noise level is provided, so the target score is available in closed form. This is also why q-sampling requires the noise level as an input. A contrasting sketch follows.
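For contrast with the implicit objective above, a denoising score matching sketch, where the regression target is the closed-form Gaussian score (the names `score_net` and `sigma` are assumptions):

```python
import torch

def dsm_loss(score_net, x, sigma):
    """Denoising score matching: perturb x with known Gaussian noise and
    regress the network onto the analytic score of q(x_tilde | x)."""
    noise = torch.randn_like(x)
    x_tilde = x + sigma * noise
    target = -noise / sigma           # grad log q(x_tilde|x) = -(x_tilde - x)/sigma^2
    pred = score_net(x_tilde, sigma)  # network is conditioned on the noise level
    return 0.5 * ((pred - target) ** 2).flatten(1).sum(dim=1).mean()
```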
Optimization-Based Inference
Langevin dynamics can produce samples from a probability density $p(\mathbf x)$ using only its score function $\nabla_{\mathbf x}\log p(\mathbf x)$:
$$
\tilde{\mathbf x}_t=\tilde{\mathbf x}_{t-1}+\frac{\epsilon}{2}\nabla_{\mathbf x}\log p(\tilde{\mathbf x}_{t-1})+\sqrt{\epsilon}\,\mathbf z_t
$$
where $\epsilon$ is the step size and $\mathbf z_t \sim \mathcal N(0, I)$. As $\epsilon \to 0$ and $t \to \infty$, $\tilde{\mathbf x}_t$ converges to a sample from $p(\mathbf x)$ under some regularity conditions. For our model, $\nabla_{\mathbf x}\log p_\theta(\mathbf x) = -\nabla_{\mathbf x}E_\theta(\mathbf x)$, so inference amounts to noisy gradient descent on the learned energy.
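A minimal sketch of this sampler on the learned energy, where the score is computed as $-\nabla_{\mathbf x}E_\theta$ by autograd (`langevin_sample` is a hypothetical helper):

```python
import torch

def langevin_sample(E, x0, eps=1e-2, n_steps=1000):
    """Unadjusted Langevin dynamics on p proportional to exp(-E)."""
    x = x0.clone()
    for _ in range(n_steps):
        x = x.detach().requires_grad_(True)
        score = -torch.autograd.grad(E(x).sum(), x)[0]  # grad log p = -grad E
        x = x + 0.5 * eps * score + eps ** 0.5 * torch.randn_like(x)
    return x.detach()
```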
Hyvärinen, A. (2005). Estimation of non-normalized statistical models by score matching. Journal of Machine Learning Research, 6, 695–709. https://jmlr.org/papers/volume6/hyvarinen05a/hyvarinen05a.pdf ↩︎