The goal of imitation learning is: given trajectories collected by an expert,
$D := \{(s_1, a_1, s_2, a_2, \ldots, s_T, a_T)\}$, learn a policy $\pi_\theta$ that mimics the expert.
The simplest approach (behavior cloning with a deterministic policy) is to minimize the squared L2 distance to the expert's actions:
$\min_\theta \frac{1}{|D|}\sum_{(s, a) \in D}\| a - \bar{a}\|^2$, where $\bar{a}=\pi_\theta(s)$. This is ordinary supervised regression over the whole dataset: for each state $s$, the learned policy outputs a single action $\bar{a}$.
However, if the expert data is multi-modal, the policy's prediction ends up somewhere between the modes, because the squared-error minimizer is the average of the expert actions at that state. For example, in a driving scene the expert either keeps the lane or merges into the left lane; the deterministic policy may then steer to a point between the two lanes, which matches neither behavior.
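A minimal numerical sketch of this averaging effect, assuming hypothetical toy data ($a = 0$ for keeping the lane, $a = 1$ for merging, chosen 50/50 regardless of the state) and a linear policy fit by least squares. The data and the linear policy class are illustrative assumptions, not part of the method.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10_000

# Hypothetical toy data: a 1-D state s and a lateral action a, where a = 0.0
# means "keep the lane" and a = 1.0 means "merge to the left lane". The expert
# picks either option with probability 0.5, independently of s (bimodal data).
s = rng.uniform(0.0, 1.0, size=n)
a = rng.integers(0, 2, size=n).astype(float)

# Deterministic linear policy pi_theta(s) = w * s + b, fit by minimizing the
# mean squared error to the expert actions (ordinary least squares).
X = np.column_stack([s, np.ones(n)])
theta, *_ = np.linalg.lstsq(X, a, rcond=None)
predicted = X @ theta

# The predictions should cluster around 0.5: halfway between the lanes,
# an action the expert never took.
print("predicted action range:", predicted.min(), predicted.max())
print("fraction of expert actions near 0.5:", np.mean(np.abs(a - 0.5) < 0.25))
```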
There are several ways to model the action distribution $\pi_\theta(a \mid s)$ better:
- Mixture of Gaussians: model the action distribution as a weighted sum of Gaussians, with the number of components as a hyperparameter. The limitation is that it can only represent distributions that are (close to) a mixture of that many Gaussians.
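- Discretize each action dimension and use an autoregressive model: each of the $K$ action dimensions is split into a finite set of bins, and the model outputs a categorical distribution (a probability for each bin) for one dimension at a time. The model factorizes $\pi_\theta(a \mid s) = \prod_{k=1}^{K} \pi_\theta(a^{(k)} \mid s, a^{(1)}, \ldots, a^{(k-1)})$, where superscripts index action dimensions, not timesteps. An action is generated as follows:
  - The input is the state $s$.
  - Generate the value of the first dimension of the action, $a^{(1)}$.
  - Given $s$ and $a^{(1)}$, generate $a^{(2)}$.
  - Continue until generating $a^{(K)}$.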
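- Diffusion model: start from a random noise vector and iteratively refine (denoise) it into an action vector, conditioned on $s$. Different starting noise can end in different modes, so samples are not pulled toward the average.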
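For the mixture-of-Gaussians option, here is a sketch of the negative log-likelihood such a model is typically trained with, for a 1-D action. The function name `gaussian_mixture_nll` and the toy data are hypothetical, and the mixture parameters are set by hand rather than predicted from $s$ by a network, as a real model would do.

```python
import numpy as np


def gaussian_mixture_nll(actions, weights, means, stds):
    """Average negative log-likelihood of 1-D actions under a Gaussian mixture.

    weights, means and stds are arrays of shape (M,) for M components;
    the weights must be positive and sum to 1.
    """
    x = np.asarray(actions, dtype=float)[:, None]                  # (N, 1)
    log_comp = (-0.5 * np.log(2.0 * np.pi) - np.log(stds)
                - 0.5 * ((x - means) / stds) ** 2)                 # (N, M)
    log_weighted = np.log(weights) + log_comp
    # log-sum-exp over the components, shifted by the max for stability
    peak = log_weighted.max(axis=1, keepdims=True)
    log_lik = peak[:, 0] + np.log(np.exp(log_weighted - peak).sum(axis=1))
    return float(-log_lik.mean())


# Hypothetical bimodal expert actions: keep the lane (~0.0) or merge (~1.0).
rng = np.random.default_rng(0)
actions = np.concatenate([rng.normal(0.0, 0.05, 500), rng.normal(1.0, 0.05, 500)])

# One Gaussian at the sample mean/std vs. two components placed on the modes.
nll_single = gaussian_mixture_nll(actions, np.array([1.0]),
                                  np.array([actions.mean()]), np.array([actions.std()]))
nll_mixture = gaussian_mixture_nll(actions, np.array([0.5, 0.5]),
                                   np.array([0.0, 1.0]), np.array([0.05, 0.05]))
print(f"single Gaussian NLL: {nll_single:.3f}, two-component NLL: {nll_mixture:.3f}")
```

The single Gaussian has to spread its mass over the empty gap between the modes, so its likelihood is lower than that of the two-component mixture.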
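For the discretized autoregressive option, here is a sketch of the sampling loop with $K = 2$ action dimensions (lateral motion, then speed change). `toy_next_dim_probs` is a hypothetical hand-written stand-in for a trained network that would output the categorical distribution for the next dimension given $s$ and the earlier dimensions; the bin layout is also an assumption.

```python
import numpy as np

NUM_BINS = 11
BIN_CENTERS = np.linspace(-1.0, 1.0, NUM_BINS)   # -1.0, -0.8, ..., 1.0
KEEP, MERGE, SLOW = 5, 10, 3                     # bin indices for 0.0, 1.0, -0.4


def toy_next_dim_probs(state, previous_bins):
    """Hypothetical stand-in for a trained network p(a^(k) | s, a^(1..k-1)).

    Dimension 1 (lateral): keep the lane or merge, with P(merge) = state.
    Dimension 2 (speed change): slow down only if dimension 1 chose to merge.
    """
    probs = np.full(NUM_BINS, 1e-3)              # small mass on every bin
    if not previous_bins:
        p_merge = float(np.clip(state, 0.0, 1.0))
        probs[KEEP], probs[MERGE] = 1.0 - p_merge, p_merge
    else:
        probs[SLOW if previous_bins[0] == MERGE else KEEP] = 1.0
    return probs / probs.sum()


def sample_action(state, num_dims, rng):
    """Sample one discretized action, one dimension at a time."""
    bins = []
    for _ in range(num_dims):
        p = toy_next_dim_probs(state, bins)
        bins.append(int(rng.choice(NUM_BINS, p=p)))
    return BIN_CENTERS[bins]


rng = np.random.default_rng(0)
actions = np.array([sample_action(0.5, 2, rng) for _ in range(2000)])
merged = np.isclose(actions[:, 0], 1.0)
print("fraction of samples that merge:", merged.mean())
print("speed changes seen when merging:", np.unique(np.round(actions[merged, 1], 2)))
```

Because each dimension is conditioned on the ones already sampled, the model can keep the two modes apart and also capture dependencies between dimensions (here, slowing down only when merging).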
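For the diffusion option, here is a sketch of DDPM-style ancestral sampling for a 1-D action at a single fixed state. To keep it self-contained, the learned denoising network is replaced by the exact posterior mean $\mathbb{E}[a_0 \mid x_t]$ for an assumed two-mode expert distribution ($a = -1$ or $a = +1$, equally likely); the linear noise schedule is likewise an assumption.

```python
import numpy as np

# Linear noise schedule (an assumed choice, similar in spirit to DDPM).
T = 200
betas = np.linspace(1e-4, 0.1, T)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)


def denoise(x_t, t):
    """Exact E[a_0 | x_t] when a_0 is -1 or +1 with probability 1/2 each.

    Stand-in for the learned network, which in a diffusion policy would be
    conditioned on the state s as well as on (x_t, t).
    """
    c = np.sqrt(alpha_bars[t])
    return np.tanh(c * x_t / (1.0 - alpha_bars[t]))


def sample_actions(num_samples, rng):
    x = rng.standard_normal(num_samples)            # start from pure noise
    for t in range(T - 1, -1, -1):                  # t = T-1, ..., 0
        a0_hat = denoise(x, t)
        ab_t = alpha_bars[t]
        ab_prev = alpha_bars[t - 1] if t > 0 else 1.0
        # Gaussian posterior q(x_{t-1} | x_t, a_0), with a_0 replaced by a0_hat
        coef_a0 = np.sqrt(ab_prev) * betas[t] / (1.0 - ab_t)
        coef_xt = np.sqrt(alphas[t]) * (1.0 - ab_prev) / (1.0 - ab_t)
        var = betas[t] * (1.0 - ab_prev) / (1.0 - ab_t)
        x = coef_a0 * a0_hat + coef_xt * x + np.sqrt(var) * rng.standard_normal(num_samples)
    return x


samples = sample_actions(2000, np.random.default_rng(0))
print("fraction near a mode (+/-1):", np.mean(np.abs(np.abs(samples) - 1.0) < 0.05))
print("fraction near the average (0):", np.mean(np.abs(samples) < 0.5))
```

Unlike the least-squares policy, the samples land on the two modes (roughly half on each) rather than on their average.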
How can we address compounding errors in imitation learning?
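Compounding errors are a result of covariate shift: a small mistake moves the agent to a state slightly outside the training distribution, where the policy is less accurate, so it makes larger mistakes and drifts further away from the data distribution.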
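A back-of-the-envelope sketch of why the errors compound, under a deliberately pessimistic assumption: the policy errs with probability $\epsilon$ per step while on the expert's distribution, and after its first error it is off-distribution and errs on every remaining step. The expected number of mistakes over $T$ steps is then $\sum_{t=1}^{T} \big(1 - (1-\epsilon)^t\big)$, which grows roughly like $\epsilon T^2 / 2$ when $\epsilon T$ is small, compared with $\epsilon T$ when mistakes do not compound. The function names are hypothetical.

```python
import numpy as np


def expected_mistakes(eps, horizon):
    """Closed form: sum over t of P(first mistake happened at or before step t)."""
    t = np.arange(1, horizon + 1)
    return float(np.sum(1.0 - (1.0 - eps) ** t))


def simulate_mistakes(eps, horizon, num_rollouts, rng):
    """Monte Carlo check of the same no-recovery model."""
    # Step of the first mistake (1-indexed); geometric on {1, 2, ...}.
    first = rng.geometric(eps, size=num_rollouts)
    # Every step from the first mistake to the horizon is a mistake.
    return np.clip(horizon - first + 1, 0, None)


eps, horizon = 0.01, 100
rng = np.random.default_rng(0)
print("no compounding (eps * T):   ", eps * horizon)
print("no recovery, closed form:   ", expected_mistakes(eps, horizon))
print("no recovery, simulated mean:", simulate_mistakes(eps, horizon, 20_000, rng).mean())
print("upper bound eps*T*(T+1)/2:  ", eps * horizon * (horizon + 1) / 2)
```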
There are three main approaches to mitigating compounding errors:
- Just collect *a lot* of data, so that the dataset covers more of the states the policy might drift into.
- Collect corrective behavior data for drifted states: when the policy drifts to a state $s'$, ask an expert to label $s'$ with the corrective action that brings the next state back to normal. Aggregate these labels into the dataset, retrain, and repeat. This approach is called Dataset Aggregation (DAgger). In practice it is hard to scale, because it is not trivial to reproduce the specific drifted state and ask the expert for help there.
- A more practical way is to collect corrective behavior data while the expert takes full control: when we detect drift, the expert takes over the agent. The expert does not need to return to normal immediately and can recover over several steps. The difference from the previous method is:
  - DAgger: the expert only provides one corrective action for a given drifted state $s'$.
  - Human-gated DAgger (HG-DAgger, this approach): the expert takes full control for several timesteps. This is a more practical interface for providing corrections; one example is a safety driver in an autonomous vehicle. The limitation is that someone must be able to detect when an intervention is needed.
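A minimal sketch of the DAgger loop on a hypothetical 1-D lane-keeping toy: the state is an integer lane offset, the expert steers one unit back toward the center, and the learner is a lookup table that outputs 0 (no correction) for states it has never seen. The environment, the expert, and the tabular learner are all assumptions; the sketch shows the structure of the loop (roll out the learner, label with the expert, aggregate, refit), not a benchmark.

```python
import numpy as np


def expert_action(x):
    """Hypothetical expert: steer one unit back toward the lane center x = 0."""
    return int(-np.sign(x))


def step(x, a, rng):
    """Toy dynamics: the action shifts x, plus an occasional random push."""
    push = rng.choice([-2, -1, 0, 1, 2], p=[0.05, 0.15, 0.6, 0.15, 0.05])
    return x + a + int(push)


def learner_action(table, x):
    # Unseen states fall back to 0 (no correction): an assumed stand-in for
    # a policy that behaves poorly off its training distribution.
    return table.get(x, 0)


def rollout_states(policy, horizon, rng):
    x, states = 0, []
    for _ in range(horizon):
        states.append(x)
        x = step(x, policy(x), rng)
    return states


def dagger(num_iters, rollouts_per_iter, horizon, rng):
    # Iteration 0: plain behavior cloning data from expert rollouts.
    dataset = [(x, expert_action(x))
               for _ in range(rollouts_per_iter)
               for x in rollout_states(expert_action, horizon, rng)]
    table = dict(dataset)
    for _ in range(num_iters):
        # 1. Roll out the current learner to visit the states *it* drifts into.
        visited = [x
                   for _ in range(rollouts_per_iter)
                   for x in rollout_states(lambda s: learner_action(table, s), horizon, rng)]
        # 2. Ask the expert to label every visited state with its action.
        # 3. Aggregate with all previous data and refit the learner.
        dataset += [(x, expert_action(x)) for x in visited]
        table = dict(dataset)
    return table, dataset


table, dataset = dagger(num_iters=5, rollouts_per_iter=10, horizon=50,
                        rng=np.random.default_rng(0))
print("states covered by the learner's table:", sorted(table))
```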
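A sketch of human-gated data collection on a hypothetical 1-D lane-keeping toy (integer lane offset, an expert that steers one unit back toward the center, a lookup-table learner). The gate is an assumed threshold rule $|x| > 1$ standing in for the human's judgment; once it fires, the expert keeps full control until the agent is back at the center, and only the expert-controlled steps are recorded. All names and the environment are illustrative assumptions.

```python
import numpy as np


def expert_action(x):
    """Hypothetical expert: steer one unit back toward the lane center x = 0."""
    return int(-np.sign(x))


def step(x, a, rng):
    """Toy dynamics: the action shifts x, plus an occasional random push."""
    push = rng.choice([-2, -1, 0, 1, 2], p=[0.05, 0.15, 0.6, 0.15, 0.05])
    return x + a + int(push)


def gated_rollout(table, horizon, gate_threshold, rng):
    """Learner drives; the expert takes full control while recovering."""
    x, expert_in_control, corrections = 0, False, []
    for _ in range(horizon):
        if abs(x) > gate_threshold:
            expert_in_control = True     # gate fires: the expert takes over
        elif x == 0:
            expert_in_control = False    # recovered: hand control back
        if expert_in_control:
            a = expert_action(x)
            corrections.append((x, a))   # only expert-controlled steps are kept
        else:
            a = table.get(x, 0)          # learner acts; unseen states -> 0
        x = step(x, a, rng)
    return corrections


def hg_dagger(num_iters, horizon, gate_threshold, rng):
    table, dataset = {0: 0}, []          # assumed initial learner: go straight at x = 0
    for _ in range(num_iters):
        dataset += gated_rollout(table, horizon, gate_threshold, rng)
        table = {0: 0, **dict(dataset)}  # refit on all corrections so far
    return table, dataset


table, dataset = hg_dagger(num_iters=20, horizon=100, gate_threshold=1,
                           rng=np.random.default_rng(0))
print("corrections collected:", len(dataset))
print("learned table:", dict(sorted(table.items())))
```

The expert's recovery spans several steps (for example from $x = 2$ through $x = 1$ back to $0$), which is the multi-step correction that distinguishes this interface from labeling single drifted states.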