In this post, we will try to connect Energy-Based Models (EBMs) with classical optimal control frameworks such as Model Predictive Control (MPC), from the perspective of Lagrangian optimization.
This post is part of a series on energy-based learning and optimal control. A recommended reading order is:
Notes on “The Energy-Based Learning Model” by Yann LeCun, 2021
Learning Data Distribution Via Gradient Estimation
[From MPC to Energy-Based Policy]
How Would Diffusion Model Help Robot Imitation
Causality hidden in EBM
Review of EKF and MPC

Consider a state-space model with process noise and measurement noise:
$$
\left\{
\begin{aligned}
x[k] &= f(x[k-1], u[k-1], w[k-1]), \quad & w \sim \mathcal{N}(0, Q_w) \\
z[k] &= h(x[k], v[k]), \quad & v \sim \mathcal{N}(0, Q_v)
\end{aligned}
\right.
$$
with state $x$, observation $z$, and control input $u$.
To perform optimal control on such a system, we first estimate the state with an Extended Kalman Filter (EKF), which alternates between two steps:
Prediction Step
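One standard form of the prediction step for the model above (the symbols $\hat{x}^-$, $P$, $F_k$, $L_k$ are introduced here only for illustration) linearizes $f$ around the previous estimate:

$$
\begin{aligned}
\hat{x}^-[k] &= f(\hat{x}[k-1], u[k-1], 0) \\
P^-[k] &= F_k P[k-1] F_k^\top + L_k Q_w L_k^\top
\end{aligned}
$$

where $F_k = \partial f / \partial x$ and $L_k = \partial f / \partial w$ are evaluated at $(\hat{x}[k-1], u[k-1], 0)$, and $P$ denotes the state covariance estimate.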
Update Step
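The corresponding standard update step corrects the prediction with the new measurement $z[k]$:

$$
\begin{aligned}
K_k &= P^-[k] H_k^\top \left( H_k P^-[k] H_k^\top + M_k Q_v M_k^\top \right)^{-1} \\
\hat{x}[k] &= \hat{x}^-[k] + K_k \left( z[k] - h(\hat{x}^-[k], 0) \right) \\
P[k] &= (I - K_k H_k) P^-[k]
\end{aligned}
$$

where $H_k = \partial h / \partial x$ and $M_k = \partial h / \partial v$ are evaluated at $(\hat{x}^-[k], 0)$, and $K_k$ is the Kalman gain.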
With the estimated state $\hat{x}[k]$ we can perform Model Predictive Control by solving the following optimization problem (a minimal solver sketch follows the notation list below):
$$
\begin{aligned}
\min_{\mathbf{u}} \quad J &= \sum_{i=0}^{N-1} \ell(x[k+i], u[k+i]) + \ell_f(x[k+N]) \\
\ell(x, u) &= (x - x_{\text{ref}})^{\top} Q (x - x_{\text{ref}}) + u^{\top} R u \\
\text{s.t.} \quad & x[k+i+1] = f(x[k+i], u[k+i]) \quad \forall i = 0, \dots, N-1 \\
& x[k] = \hat{x}[k] \\
& x_{\text{min}} \leq x[k+i] \leq x_{\text{max}} \quad \forall i \\
& u_{\text{min}} \leq u[k+i] \leq u_{\text{max}} \quad \forall i
\end{aligned}
$$
$\mathbf{u} = \{ u[k], u[k+1], \dots, u[k+N-1] \}$: Control input sequence.
$N$: Prediction horizon.
$\ell(x, u)$: Stage cost function.
$\ell_f(x)$: Terminal cost function.
$x_{\text{min}}, x_{\text{max}}$: State constraints.
$u_{\text{min}}, u_{\text{max}}$: Control constraints.
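To make the receding-horizon structure concrete, here is a minimal single-shooting sketch in Python using `scipy.optimize.minimize`. The nominal model `f`, the dimensions, and reusing the stage state cost as the terminal cost are assumptions made only for illustration; state constraints are omitted for brevity.

```python
import numpy as np
from scipy.optimize import minimize

def mpc_control(f, x_hat, x_ref, Q, R, N, u_min, u_max, u_dim):
    """Single-shooting MPC: optimize the control sequence over the horizon.

    f(x, u) -> next state is the nominal model (noise set to zero).
    For brevity the terminal cost reuses Q and state constraints are omitted.
    """
    def cost(u_flat):
        u_seq = u_flat.reshape(N, u_dim)
        x, J = x_hat, 0.0
        for i in range(N):
            J += (x - x_ref) @ Q @ (x - x_ref) + u_seq[i] @ R @ u_seq[i]
            x = f(x, u_seq[i])                     # roll the model forward
        J += (x - x_ref) @ Q @ (x - x_ref)         # terminal cost l_f
        return J

    bounds = [(u_min, u_max)] * (N * u_dim)        # box constraints on u
    res = minimize(cost, np.zeros(N * u_dim), bounds=bounds, method="L-BFGS-B")
    return res.x.reshape(N, u_dim)[0]              # apply only the first control
```

In the receding-horizon loop, only the first control $u[k]$ is applied, the EKF produces a new estimate $\hat{x}[k+1]$, and the problem is re-solved at the next step.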
Introducing Energy Based Model

We can replace the state transition function $f$ in the formulated state-space model with an EBM:
$$
p(x[k] \mid x[k-1], u[k-1]) = \frac{e^{-E(x[k],\, x[k-1],\, u[k-1])}}{Z(x[k-1], u[k-1])}
$$
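As a concrete (purely hypothetical) instantiation, the energy function can be a small neural network over the transition tuple; the following PyTorch sketch is one possible architecture, not a prescription:

```python
import torch
import torch.nn as nn

class TransitionEnergy(nn.Module):
    """A hypothetical MLP energy E_theta(x[k], x[k-1], u[k-1]) -> scalar per sample."""

    def __init__(self, state_dim, action_dim, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * state_dim + action_dim, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x_next, x_prev, u_prev):
        # Concatenate the transition tuple and return one energy value per row.
        return self.net(torch.cat([x_next, x_prev, u_prev], dim=-1)).squeeze(-1)
```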
Supervised Training

Since $Z$ is often intractable, we will use score matching to learn the EBM in what follows. The score function is:
$$
s_\theta(x[k], x[k-1], u[k-1]) = s_\theta(\mathbf x) = \nabla_{\mathbf x} \log p_\theta(\mathbf x) = -\nabla_{\mathbf x} E_\theta(\mathbf x)
$$

where $\mathbf x$ collects the arguments $(x[k], x[k-1], u[k-1])$; in the conditional setting the derivatives are taken with respect to the predicted state $x[k]$, treating $(x[k-1], u[k-1])$ as conditioning (cf. $J_k$ below).
Dataset: a set of transitions $\{(x_i[k-1], u_i[k-1], x_i[k])\}_{i=1}^{N}$ (here $N$ is the dataset size, not the prediction horizon), assumed to be independently and identically distributed (i.i.d.).
The training objective is to minimize the difference between the data score landscape $s_{\text{data}}(\mathbf x)$ and the model score landscape $s_\theta(\mathbf x)$. The objective function is defined as follows, where the loss $\mathcal L$ is commonly the mean squared error:
$$
J(\theta) = \frac{1}{2}\mathbb{E}_{p_{\text{data}}}\left[\mathcal L\left(s_{\text{data}}(\mathbf x), s_{\theta}(\mathbf x)\right)\right] = \frac{1}{2}\mathbb{E}_{p_{\text{data}}}\left[\left\| s_{\text{data}}(\mathbf x) - s_{\theta}(\mathbf x) \right\|^2\right]
$$
However, we cannot access the full data distribution, so $s_{\text{data}}(\mathbf x)$ is unknown. Following (Hyvärinen, 2005)¹, the objective can be rewritten so that it no longer depends on $s_{\text{data}}$ (see Appendix A of the original paper for details).
$$
\begin{aligned}
J(\theta) &= \frac{1}{2}\mathbb{E}_{p_{\text{data}}}\left[\left\| s_\theta(\mathbf x) \right\|^2 - 2 s_\theta(\mathbf x)^\top s_{\text{data}}(\mathbf x) + \left\| s_{\text{data}}(\mathbf x) \right\|^2\right] \\
&= \frac{1}{2} \mathbb{E}_{p_{\text{data}}} \left[ \left\| s_\theta(\mathbf x) \right\|^2 \right] - \mathbb{E}_{p_{\text{data}}} \left[ s_\theta(\mathbf x)^\top s_{\text{data}}(\mathbf x) \right] + \overbrace{\frac{1}{2} \mathbb{E}_{p_{\text{data}}} \left[ \left\| s_{\text{data}}(\mathbf x) \right\|^2 \right]}^{\text{constant}} \\
J'(\theta) &= \frac{1}{2} \mathbb{E}_{p_{\text{data}}} \left[ \left\| s_\theta(\mathbf x) \right\|^2 \right] - \mathbb{E}_{p_{\text{data}}} \left[ s_\theta(\mathbf x)^\top s_{\text{data}}(\mathbf x) \right]
\end{aligned}
$$
For the cross term, using $s_{\text{data}}(\mathbf x) = \nabla_{\mathbf x} \log p_{\text{data}}(\mathbf x) = \nabla_{\mathbf x} p_{\text{data}}(\mathbf x) / p_{\text{data}}(\mathbf x)$:

$$
\mathbb{E}_{p_{\text{data}}} \left[ s_\theta(\mathbf x)^\top s_{\text{data}}(\mathbf x) \right] = \int p_{\text{data}}(\mathbf x)\, s_\theta(\mathbf x)^\top s_{\text{data}}(\mathbf x)\, d\mathbf x = \int s_\theta(\mathbf x)^\top \nabla_{\mathbf x} p_{\text{data}}(\mathbf x)\, d\mathbf x
$$
By integrating by parts (assuming $p_{\text{data}}(\mathbf x)\, s_\theta(\mathbf x) \rightarrow 0$ at the boundary), we move the derivative from $p_{\text{data}}(\mathbf x)$ onto $s_\theta(\mathbf x)$:
$$
\begin{aligned}
\int s_\theta(\mathbf x)^\top \nabla_{\mathbf x} p_{\text{data}}(\mathbf x)\, d\mathbf x &= -\int p_{\text{data}}(\mathbf x)\, \text{div}_{\mathbf x}\, s_\theta(\mathbf x)\, d\mathbf x \\
\text{div}_{\mathbf x}\, s_\theta(\mathbf x) &= \text{div}_{\mathbf x} \left( -\nabla_{\mathbf x} E_\theta(\mathbf x) \right) = -\text{div}_{\mathbf x} \nabla_{\mathbf x} E_\theta(\mathbf x) = -\Delta_{\mathbf x} E_\theta(\mathbf x) \\
\mathbb{E}_{p_{\text{data}}} \left[ s_\theta(\mathbf x)^\top s_{\text{data}}(\mathbf x) \right] &= -\mathbb{E}_{p_{\text{data}}} \left[ \text{div}_{\mathbf x}\, s_\theta(\mathbf x) \right]
\end{aligned}
$$
Substituting back, we eventually obtain the updated objective function, together with its per-sample form $J_k$:
$$
\begin{aligned}
J(\theta) &= \frac{1}{2} \mathbb{E}_{p_{\text{data}}} \left[ \left\| \nabla_{\mathbf x} E_\theta(\mathbf x) \right\|^2 \right] + \mathbb{E}_{p_{\text{data}}} \left[ \Delta_{\mathbf x} E_\theta(\mathbf x) \right] \\
J_k(\theta) &= \text{Tr} \left( \nabla_{\mathbf x[k]}^2 E(\mathbf x[k]) \right) + \frac{1}{2} \left\| \nabla_{\mathbf x[k]} E(\mathbf x[k]) \right\|^2
\end{aligned}
$$
$\text{Tr} \left( \nabla_{\mathbf x[k]}^2 E(\mathbf x[k]) \right)$ denotes the trace of the Hessian of the energy with respect to $\mathbf x[k]$, i.e., the negative divergence (Jacobian trace) of the score function.
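As an illustration, here is a minimal PyTorch sketch of this per-sample objective, reusing the hypothetical `TransitionEnergy` interface sketched above. The Hessian trace is computed exactly, one dimension at a time, which is only practical for low-dimensional states:

```python
import torch

def score_matching_loss(energy_net, x_next, x_prev, u_prev):
    """Per-sample objective J_k = Tr(Hessian_{x[k]} E) + 0.5 * ||grad_{x[k]} E||^2,
    averaged over the batch."""
    x_next = x_next.detach().requires_grad_(True)
    energy = energy_net(x_next, x_prev, u_prev).sum()                 # scalar
    grad = torch.autograd.grad(energy, x_next, create_graph=True)[0]  # (B, D)

    # Exact Hessian trace, accumulated dimension by dimension.
    hess_trace = torch.zeros(x_next.shape[0], device=x_next.device)
    for d in range(x_next.shape[1]):
        hess_trace += torch.autograd.grad(
            grad[:, d].sum(), x_next, create_graph=True
        )[0][:, d]

    return (hess_trace + 0.5 * grad.pow(2).sum(dim=1)).mean()
```

In practice the trace term is often replaced by a stochastic (Hutchinson-style) estimator when the state dimension is large.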
If you have not seen this kind of formulation in diffusion models and it feels unfamiliar, recall:
The training objective there is to learn the distribution $q(x_{t-1} \mid x_t, x_0)$, which is a known Gaussian distribution since the noise level is provided. This is also why q-sampling requires $x_0$.
Optimization-Based Inference

Langevin dynamics can produce samples from a probability density $p(\mathbf x)$ using only the score function $\nabla_{\mathbf x}\log p(\mathbf x)$. Given a fixed step size $\epsilon > 0$ and an initial value $\tilde{\mathbf x}_0 \sim \pi(\mathbf x)$, with $\pi$ a prior distribution, the Langevin method recursively computes:
$$
\tilde{\mathbf x}_t = \tilde{\mathbf x}_{t-1} + \frac{\epsilon}{2}\nabla_{\mathbf x}\log p(\tilde{\mathbf x}_{t-1}) + \sqrt{\epsilon}\,\mathbf z_t
$$
where $\mathbf z_t \sim \mathcal N(0, I)$. The distribution of $\tilde{\mathbf x}_T$ converges to $p(\mathbf x)$ as $\epsilon \rightarrow 0$ and $T \rightarrow \infty$; otherwise, a Metropolis-Hastings correction is needed to remove the discretization error.
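In our setting the learned transition EBM supplies the score: given $(x[k-1], u[k-1])$, a next state can be sampled by running Langevin dynamics on $x[k]$ with $\nabla_{x[k]}\log p \approx -\nabla_{x[k]} E_\theta$. A minimal unadjusted sketch, again reusing the hypothetical `energy_net` interface from above:

```python
import torch

def langevin_sample(energy_net, x_prev, u_prev, n_steps=100, step_size=1e-3):
    """Sample x[k] ~ p(x[k] | x[k-1], u[k-1]) with unadjusted Langevin dynamics."""
    x = torch.randn_like(x_prev)                     # initialize from a Gaussian prior
    for _ in range(n_steps):
        x = x.detach().requires_grad_(True)
        energy = energy_net(x, x_prev, u_prev).sum()
        score = -torch.autograd.grad(energy, x)[0]   # grad of log p is -grad of E
        noise = torch.randn_like(x)
        x = x + 0.5 * step_size * score + (step_size ** 0.5) * noise
    return x.detach()
```

With a fixed step size and a finite number of steps this is only an approximation; adding the Metropolis-Hastings accept/reject step mentioned above would correct the bias.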