Robust Representation in RL

Incomplete ...


How good are our deep RL agents when trained on high-dimensional state spaces? Images are a convenient and inexpensive way of acquiring state information, as opposed to the proprioceptive state of the underlying dynamics.
Learning from pixel frames is nothing new; it started with Atari games and DQN. But in the context of learning control and other tasks, like MuJoCo locomotion or robot manipulation, it is still hard, because even learning long-horizon trajectories from low-level control state can be challenging. And even though deep convolutional encoders can learn good representations, they still require large amounts of training data; existing RL approaches still suffer from poor sample complexity.

Some impressive progress has been made applying model-free RL to domains with simple dynamics and discrete action spaces. Extending these methods to complex continuous control environments has been challenging, for various reasons. Atari tasks are deterministic: when you press the action 'go right', you go right. This is not the case in environments with more complex dynamics, such as continuous control of a robot.
The RL signal is also much sparser than in supervised learning. On top of this, an RL policy now needs to solve two problems at once: a representation learning problem and a task learning problem. Learning from high-dimensional data for RL therefore remains hard, because we still cannot match the performance of learning from low-level state. So far the research effort has focused on getting these RL algorithms to learn useful behaviours (in a semi-robust manner at least), and cared less about the fact that these RL agents are probably just overfitting to their training environments.
Of course, one ultimate goal of RL is to understand how brains do this and take inspiration from them. There are nice articles studying the invariance of representations in the human brain, for example the paper Invariant representations of mass in the human brain. It studies how an intuitive understanding of physics develops early in the brain, where an object's weight can be inferred from the compression of the material it rests on, or its stability judged from its center of mass.

The following paper summaries try to answer some of these questions:

  • How robust are our methods to changes in representation? This can be changes in colour, texture, geometry, symmetry, etc.
  • What are some good practical techniques to make representation robust in RL?
When we talk about generalization in RL, we need to be specific about which characteristic we are referring to. We can generalize to 1) visual changes, 2) different dynamics, 3) different problem structure, for instance two tasks that are similar enough that in theory the agent should be able to solve both, having been trained on only one. Let's deal with number 1, the visual changes setting. There are various strategies that come up a lot in deep learning, and in representation learning and domain adaptation specifically:

  • Regularization
  • Data Augmentation such as visual domain randomization.
  • Random networks
The literature here is huge, so let's focus on some of the latest deep RL papers.

Network Randomization

The main idea is to use a random convolutional network to generate randomized inputs and train the RL agent on these randomly generated observations. By re-initializing the parameters of the random network at every iteration, the agents are constrained to be trained under a broad range of perturbed low-level features.
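A minimal sketch of the idea, assuming a PyTorch-style setup and an image observation tensor `obs`; the layer shape and the re-initialization schedule here are illustrative, not the paper's exact configuration:

```python
import torch
import torch.nn as nn

# A single random convolution that perturbs low-level features while keeping
# the spatial layout of the observation intact (stride 1, same padding).
random_layer = nn.Conv2d(3, 3, kernel_size=3, padding=1, bias=False)

def randomize(obs):
    """Pass observations through the (frozen) random convolution."""
    with torch.no_grad():
        return random_layer(obs)

def reinitialize():
    """Re-draw the random weights, e.g. at the start of every training iteration,
    so the agent sees a broad range of perturbed low-level features."""
    nn.init.xavier_normal_(random_layer.weight)

# Schematic training loop (replay_buffer, agent and rl_loss are placeholders):
# for iteration in range(num_iterations):
#     reinitialize()
#     batch = replay_buffer.sample()
#     loss = rl_loss(agent, randomize(batch.obs))
```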

Random networks for deep RL

This one is interesting as it's slightly different from the general mode of regularization, data augmentation and noise injection. The paper Exploration by Random Network Distillation explores the idea of using a randomly initialized network to define an intrinsic reward for visiting unexplored parts of the state space: a predictor network is trained to match the output of a fixed random network, and since the prediction error stays high on states the agent has rarely seen, it serves as an exploration bonus. Ensemble-based approaches can sometimes also help to improve uncertainty estimation in the context of exploration and recognizing unseen states. Back to randomized inputs: let's say we have a random network $f$ with parameters $\phi$. We take the original input $s$, randomize it using this network, $\hat{s} = f(s;\phi)$, then train the agent on this observation. We expect the agent to learn to ignore the noise and learn a robust representation of what matters. To enforce this, they use a feature matching (FM) loss between hidden features from clean and randomized observations, defined as an $L_2$ loss, $\mathcal{L}^{\text{random}}_{FM} = \mathbb{E}\left[ \left\| h(f(s_t;\phi);\theta) - h(s_t;\theta) \right\|^2 \right]$.
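In code, the feature matching term might look like the sketch below. The encoder `h` and the `randomize` helper from the previous sketch are assumptions about how the pieces fit together, not the paper's exact implementation:

```python
import torch.nn.functional as F

def feature_matching_loss(h, obs, randomize):
    """L2 distance between hidden features of clean and randomized observations.

    h:   feature encoder with parameters theta
    obs: batch of clean observations s_t
    """
    feat_clean = h(obs)              # h(s_t; theta)
    feat_random = h(randomize(obs))  # h(f(s_t; phi); theta)
    # Whether to stop gradients through the clean branch is a design choice.
    return F.mse_loss(feat_random, feat_clean)
```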

Removing visual bias

Usually in deep learning, people study the robustness of CNNs to changes of colour, texture, shape, etc. To run experiments on these phenomena, they take an image classification dataset such as cats and dogs where, in the training set, dogs are bright and cats are dark, and in the test set it is the other way round: dogs are dark and cats are bright. It has been shown that CNNs are biased towards texture and colour rather than shape and other structure (see ImageNet-trained CNNs are biased towards texture). Methods to mitigate these issues include:
  • Grayout (GR)
  • Cutout
  • Inversion (IV)
  • Color jitter (CJ)
But again these tricks only take you so far.
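As a rough illustration, each of these augmentations is only a few lines on a batch of image tensors; the parameter choices below are assumptions, not the exact settings from any particular paper:

```python
import torch

def grayout(x):
    """Replace RGB channels by their mean: (B, 3, H, W) -> grayscale copied to 3 channels."""
    return x.mean(dim=1, keepdim=True).expand_as(x)

def inversion(x):
    """Invert pixel intensities, assuming inputs normalized to [0, 1]."""
    return 1.0 - x

def color_jitter(x, strength=0.2):
    """Randomly rescale each channel's intensity per image."""
    scale = 1.0 + strength * (2 * torch.rand(x.size(0), 3, 1, 1, device=x.device) - 1)
    return (x * scale).clamp(0.0, 1.0)

def cutout(x, size=16):
    """Zero out a random square patch per image (assumes H, W > size)."""
    b, _, h, w = x.shape
    out = x.clone()
    for i in range(b):
        top = torch.randint(0, h - size, (1,)).item()
        left = torch.randint(0, w - size, (1,)).item()
        out[i, :, top:top + size, left:left + size] = 0.0
    return out
```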

D4PG

This is an off-policy actor-critic algorithm that learns directly from raw images. It uses a distributional version of the critic update, where the learned return distributions model the randomness due to intrinsic factors of the environment.
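A minimal sketch of what a distributional critic head looks like, using C51-style categorical atoms over a flat feature vector (e.g. the output of a conv encoder). The atom count, value range and network shape are assumptions, and D4PG's n-step returns and distributional Bellman target projection are omitted:

```python
import torch
import torch.nn as nn

class DistributionalCritic(nn.Module):
    """Critic that outputs a categorical distribution over returns instead of a scalar Q."""
    def __init__(self, obs_dim, act_dim, n_atoms=51, v_min=-10.0, v_max=10.0):
        super().__init__()
        # Fixed support of the return distribution.
        self.register_buffer("atoms", torch.linspace(v_min, v_max, n_atoms))
        self.net = nn.Sequential(
            nn.Linear(obs_dim + act_dim, 256), nn.ReLU(),
            nn.Linear(256, n_atoms),
        )

    def forward(self, obs, act):
        logits = self.net(torch.cat([obs, act], dim=-1))
        probs = logits.softmax(dim=-1)           # probability mass on each return atom
        q_value = (probs * self.atoms).sum(-1)   # expected return, used for the actor update
        return probs, q_value
```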

PlaNet

A model-based RL method which uses a mixed deterministic/stochastic sequential latent variable model, but without explicit policy learning. Instead, the model is used for planning with model predictive control (MPC), where each plan is optimized with the cross-entropy method (CEM). (Deep Planning Networks, or "PlaNet".)
  • Model-Based
  • Model predicts future images and rewards using a sequence of compact latent states, trained as a sequential VAE.
  • Reconstruction error gives a training signal.

They learn an encoder $q(s_t \mid o_{\leq t}, a_{< t})$ to infer an approximate belief over the current hidden state from the history using filtering. The environment's dynamics are learned directly on the target task, i.e. from the agent's own experience on the task it has to solve: a transition model is learned and then used for MPC. In MPC, at every time step the planner optimizes a sequence of actions over a fixed horizon using the learned model, executes only the first action, and re-plans at the next step. PlaNet consists of 3 networks:

  • (a) a Variational Autoencoder (VAE) that encodes an image observation into a latent state;
  • (b) a reward estimator that learns the reward that is associated with each latent state;
  • (c) a recurrent neural network (RNN) which learns to predict the next latent state given the previous one and an action.
These components work together to encode a single observation and roll out multiple alternative trajectories in latent space without further interaction with the environment. The method is also off-policy, because data is stored in a replay buffer.
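A simplified version of the CEM planning loop in latent space is sketched below. The model interfaces `transition(z, a)` and `reward(z)` are hypothetical placeholders for PlaNet's learned RNN and reward estimator, and the horizon, population size and iteration counts are illustrative:

```python
import torch

def cem_plan(z0, transition, reward, act_dim, horizon=12, pop=1000, elites=100, iters=10):
    """Plan from latent state z0 (shape (1, latent_dim)) by iteratively refitting a
    Gaussian over action sequences to the top-scoring candidates, then return only
    the first action (MPC: re-plan at every step)."""
    mean = torch.zeros(horizon, act_dim)
    std = torch.ones(horizon, act_dim)
    for _ in range(iters):
        # Sample candidate action sequences: (pop, horizon, act_dim).
        actions = mean + std * torch.randn(pop, horizon, act_dim)
        # Roll out each candidate through the learned latent dynamics and sum rewards.
        returns = torch.zeros(pop)
        z = z0.expand(pop, -1)
        for t in range(horizon):
            z = transition(z, actions[:, t])
            returns += reward(z)          # reward(z) is assumed to return a (pop,) tensor
        # Refit the sampling distribution to the elite sequences.
        elite = actions[returns.topk(elites).indices]
        mean, std = elite.mean(0), elite.std(0)
    return mean[0]  # execute only the first action, then re-plan at the next step
```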

SAC-AE

This paper is based on a simple encoder-decoder architecture with a reconstruction loss as the auxiliary training signal. Is this robust? How different is the learned representation from our object detection architectures? Let's take the Reacher environment as an example. If I place the goal on the left during training, then place the goal on the right at evaluation, can the agent still solve the task? If yes, then it is presumably looking for the red thing that looks like a goal. If not, then why not? Is it the representation, or has the behaviour policy simply never experienced that part of the state space, so it is not realistic to expect it to reach there? And what would a high-fidelity model that encapsulates strong inductive biases look like here?
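The representation learning part is essentially just the following sketch, with hypothetical `encoder`/`decoder` modules; SAC-AE's additional latent and weight regularization, and the careful handling of gradient flow between actor, critic and encoder, are omitted:

```python
import torch.nn.functional as F

def reconstruction_loss(encoder, decoder, obs):
    """Deterministic autoencoder training signal used alongside the SAC losses."""
    z = encoder(obs)       # compress the image observation to a latent vector
    recon = decoder(z)     # reconstruct the observation from the latent
    return F.mse_loss(recon, obs)
```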

SLAC

Their architecture learns a latent representation using a stochastic sequential latent variable model, but the latent model is used only for representation learning: the actor and critic are trained model-free and off-policy on top of it.