Vel: PyTorch meets baselines

I'd like to introduce you to Vel, a collection of high-quality implementations of deep learning models.

But first, I owe you a few words of introduction.

Introduction

It's been quite a while since I have written anything here. Many things have happened in the meantime: my daughter was born, I got a promotion at my day job, and I got completely and thoroughly sucked into the rabbit hole of recent advances in deep learning, and especially deep reinforcement learning.

This post is the first one in a series about my latest research project, Deep Reinforcement Learning in sparse reward environments. This time, I'll be focusing mostly on the implementation and infrastructure side of the project, so that the following parts can deal with the problem setting directly.

Project: Vel

When starting this research project together with my friend, we had slightly different views on which direction it should go. He had a somewhat more pragmatic approach: "This project is about advancing reinforcement learning. Let's pick an existing algorithm implementation made by OpenAI[1] and start from there."

But I had my reservations about that path. Firstly, I never feel I fully understand an algorithm unless I have implemented it myself. Secondly, in my opinion PyTorch offers a superior developer experience, which leads to quicker development and faster debugging. That was very important to me in a research project I was to conduct during my free time. Last but not least, I was already starting to build a library of various deep learning models I had implemented before, and reinforcement learning sure felt like it would fit there too.

That's how the Vel[2] project was born. While in principle it is not limited to reinforcement learning models, they constitute the majority of the codebase at the moment.

From a high-level point of view, the library contains various parts that don't necessarily depend on each other, but are tested to work well together:

  • A simple dependency injection mechanism meant to make it easy to define models and workflows using YAML configuration files (vel.internals)
  • A flexible and versatile training loop implementation for PyTorch models with metrics tracking (vel.api.learner, vel.metrics)
  • Multiple data sources and their transformations (vel.sources, vel.augmentations)
  • Reusable model pieces (vel.modules)
  • The models themselves (vel.models)
  • Last but not least, all of the above, suited to a reinforcement learning setting (vel.rl)

At some point in the future, if there is interest, I'll write a general introduction to the framework, but this time I'll focus specifically on the new reinforcement learning algorithm implementations.

As of the moment of writing this article, the following models have been implemented and (more importantly!) debugged:

  • Deep Q-Learning based on the DeepMind Nature publication, together with some of the improvements: Double DQN, Dueling DQN, Prioritized Experience Replay
  • Advantage Actor-Critic
  • Trust Region Policy Optimization
  • Proximal Policy Optimization
  • Deep Deterministic Policy Gradient
  • Actor-Critic with Experience Replay

References to the individual research publications describing these models can be found in the bibliography[3].

These implementations are designed to work well with any OpenAI Gym environment, with many example configurations for Atari and MuJoCo environments provided. Whenever possible, I've tried to make models support both discrete and continuous action spaces.
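
For concreteness, these environments are created through the standard Gym API; a minimal sketch (the environment ids are common ones from the time of writing, not anything Vel-specific):

import gym

# Atari environments expose a discrete action space...
atari_env = gym.make("BreakoutNoFrameskip-v4")
# ...while MuJoCo environments expose a continuous (Box) action space
mujoco_env = gym.make("HalfCheetah-v2")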

Unfortunately, while I would love to discuss reinforcement learning in general and all of these models in detail, I couldn't fit it all into this article. Learning and starting to understand them was a great journey for me – if you're interested in reading about it, please leave a comment.

I hope this code will help many other researchers in their projects, as that is the purpose of open source code in the end. If you don't want to use the whole setup and are interested in just the models, that should be fairly easy – install the library, import the relevant modules and use them with your custom components. Unfortunately, reinforcement learning is really fragile[4] at the moment – models frequently work only with a very specific way of handling inputs, a specific learning loop and specific parameters. I've provided all of these in a single package for you.

There are still a few models/features that I didn't implement for this release, but that I'm considering providing at some point in the future:

  • Support for Recurrent Neural Network policies
  • Scaling environment rewards with PopArt
  • Hindsight Experience Replay
  • Generative Adversarial Imitation Learning
  • Multi-GPU support/distributed learning

I'll try to keep adding other great state-of-the-art reinforcement learning models on an ongoing basis.

There is only one model so far that I've decided not to implement, and that is Actor-Critic using Kronecker-Factored Trust Region. It requires Kronecker-Factored Approximate Curvature, which is a great idea on its own and can be implemented (relatively) easily for a fixed set of simple networks, but providing a flexible implementation of high enough quality for this project is currently way beyond my time budget.

If you have any requests for what you'd like to see implemented, open an issue on GitHub, write a comment here, or email me directly.

Testing, debugging

I treated each model I worked on as a separate project that can be outlined in three stages:

  • Research – Firstly, I downloaded and read the research publications that explain the model in detail. It is good to give each publication at least two solid read-throughs from beginning to end. Frequently, I have found important details hidden in a small table on the last page. This phase has to be conducted thoroughly before implementation can begin. Building an intuition behind the model and understanding all the whys will be crucial in all later stages. I often returned to the publication during the implementation phase to fill in the details, but I had sketched out all the high-level design beforehand.
  • Implementation - Once we know what we are going to do, we can sit down and start writing code. While reinforcement learning models are not big pieces of software by any measure, they come with their own complexity. It is beneficial to have sketched outlines and diagrams of the important concepts within the model. I was lucky not to be the first one implementing any of these models, so I was able to refer to other implementations available online. But be careful about that: not all models you can easily google will be correct, and it's important to approach unproven open source code with a bit of caution. At the time of writing this post, I think what I've gathered is the biggest collection of heavily tested PyTorch reinforcement learning models.
  • Debugging - "I've written a few hundred lines of code, I can run them, I see some numbers in the output – and now what?" – I think most practitioners implementing reinforcement learning models ask themselves this question multiple times. I sure did. While neither reading through research publications nor implementing state-of-the-art models is a simple endeavor, I think it's the debugging part that is the biggest challenge of the three. I'll devote the rest of this section to this stage.

Improving the debugging efficiency of reinforcement learning models is one of the major engineering challenges of this field. I believe experience helps a lot, but what is really needed is a structured approach. I see that as an open question.

People tend to share their success stories, rather than hardships they've encountered. It is quite unfortunate, as in my opinion, others can often learn more from one's failures than successes. Therefore, I'll share with you some of the horror stories I've lived through while debugging models.

I have many years of software development experience under my belt and have built and deployed many large, complex numerical systems into production. Yet still, during this project I encountered the hardest-to-track bug of my life to date.

Debugging story 1: A2C performance difference

It was the first hard bug I've encountered in the project. Surprisingly enough, in the end it turned out there was no bug at all.

The first piece of code I wrote for this project was Advantage Actor-Critic (commonly called A2C). It consists of the simplest, most vanilla policy gradient computation with a critic baseline. There is very little actual complexity on the theoretical side, but the environment rollouts, advantage estimation etc. need to be set up just right. When I was benchmarking my code against the OpenAI agent (they publish benchmarks of their implementations), I made the following observations:

  • My agent was initially learning much faster than the OpenAI agent
  • Learning was highly unstable - sometimes, out of the blue, the rewards would drop catastrophically

That didn't seem right. Although the agent code didn't seem overly complex, I've analyzed it line by line at least 20 times to make sure it does everything as expected. And it did. Yet the performance drops continued to occur.

I was trying every angle I could imagine to get down to the root of this issue. After some time, I started testing directly how the calculated policy gradients affect the policy network parameters. As it turned out, the problem was caused by a difference in the implementation of the RMSProp optimizer between TensorFlow and PyTorch.

The TensorFlow RMSProp update is implemented the following way:

\[ \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\mathbb{E}[g^2]_t + \varepsilon}} g_t \]

While in PyTorch the same step looks like this:

\[ \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\mathbb{E}[g^2]_t} + \varepsilon} g_t \]

The difference doesn't seem large, but it's all that was needed. I was benchmarking my models with exactly the same parameter values as OpenAI for the results to be meaningful. What I learned that day is that for certain problems \(\varepsilon\) is a very sensitive parameter – combined with the difference in where the square root is applied, it made the learning process highly unstable. Simply increasing the \(\varepsilon\) value from 1e-5 to around 1e-3 solved this issue completely.
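
For illustration, here is a minimal sketch of what that fix looks like on the PyTorch side; the stand-in network and learning rate are placeholders, not the actual Vel training setup:

import torch

# Stand-in for the policy network; in practice this would be the full actor-critic model
policy_net = torch.nn.Linear(4, 2)

# PyTorch's RMSprop adds epsilon outside the square root (the second formula above),
# so a much larger epsilon than usual is needed to get behaviour close to TensorFlow's
optimizer = torch.optim.RMSprop(
    policy_net.parameters(),
    lr=7e-4,     # illustrative value, typical for A2C setups
    alpha=0.99,  # decay rate of the squared-gradient running average
    eps=1e-3,    # raised from 1e-5; this is the change that stabilized training
)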

Debugging story 2: ACER replay buffer

This bug I'll remember for a long time. Out of all the models I've implemented, I consider ACER to be the most complex in mathematical terms. For some of the models, in the research phase I would just read the relevant publication and advance straight to the implementation phase. For more complex ones, I would read the original work and the additional works it referred to. For ACER, I had to go and read publications two levels deep to understand all the mathematical intricacies of off-policy policy-gradient calculations. I had reasonable expectations that I wouldn't get everything right the first time.

Just as expected, the model was somehow learning, but it was heavily underperforming the corresponding OpenAI implementation in my benchmarks. I began carefully analyzing all the metrics that I managed to log through the model run. All of them seemed to be in line for the first half of the learning process and then started to diverge a bit. I developed a strong suspicion that the model must be breaking down around the time the policy we are training diverges enough from the behavior policy for the bias correction to kick in (pretty technical details about the inner workings of the algorithm). I set up a series of tests, more and more intricate, to probe this behavior, went through the code multiple times and got... nothing.

My implementation seemed correct. I caught a few errors in how some of the metrics were reported, but nothing major. Whatever I was able to test was perfectly in line, yet the algorithm was still underperforming. It is hard to convey the struggle I had with this one, but it took me weeks to finally solve it.

The error turned out to be hidden in a single line my mind magically overlooked when going over it many, many times. When I was saving transitions to the replay buffer, my sequence of states was shifted by one compared to the actions and rewards. Usually one would store tuples \((s_t, a_t, r_t, s_{t+1})\); I was storing tuples \((s_{t+1}, a_t, r_t, s_{t+2})\) instead. All the calculations were correct this whole time, except I was feeding them with slightly incorrect data. Moreover, it was close enough to being correct that it passed all the manual checks of the inputs that I performed.
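
To make the off-by-one concrete, here is a minimal sketch of transition storage (not the actual Vel replay buffer); the key point is that the state stored alongside \(a_t\) and \(r_t\) must be the state the action was taken in:

def store_rollout(buffer, states, actions, rewards):
    # states has length T + 1 (it includes the final next state),
    # actions and rewards have length T
    for t in range(len(actions)):
        buffer.append((
            states[t],      # s_t      -- the buggy version stored states[t + 1] here
            actions[t],     # a_t
            rewards[t],     # r_t
            states[t + 1],  # s_{t+1}  -- and states[t + 2] here
        ))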

I'm really glad this one is behind me. Through this whole process, there were no clues that could guide me to the spot where the mistake was located. The usual approach when debugging complex software is to perform a series of tests, successively narrowing down the location of the troublemaking piece of code. This time, I was wandering completely in the dark.

And I have absolutely no guarantee that a similar situation won't happen to me again in the future.

Debugging story 3: Bad gradient flow in DDPG

This one I found rather quickly, but I include it here because it belongs to a very important class of errors I call gradient flow errors. Because of the nature of frameworks like PyTorch and TensorFlow, programmers tend to focus on calculating the forward pass of a network, and the library is supposed to figure out the backward pass on its own. Just like with all automatic things, though, there are situations where we need to nudge the library a bit to get the result we want. That was the case this time as well.

In DDPG, the model consists of two separate neural networks, unsurprisingly called the actor and the critic. There is a separate update rule for each of them, and all implementations I've seen used separate optimizer objects for each. I didn't like that, and I found it more elegant to optimize both networks using a single optimizer.

For deterministic policies, the policy gradient calculation step is very simple:

\[ \nabla J = \mathbb{E}(\nabla Q(s, a)) \]

Naively, as my first attempt, I implemented that simply as:

model_action = model.action(rollout['observations'])                      # actor picks the actions
model_action_value = model.value(rollout['observations'], model_action)   # critic scores them
policy_loss = -model_action_value.mean()
policy_loss.backward()  # propagates gradients into BOTH the actor and the critic

With the above code, the gradient of the policy loss gets propagated to both the actor network and the critic network. As we learned in the research phase, we want to optimize the way the policy chooses the best action so as to maximize the value estimate. Increasing the value estimates themselves would also lower the loss function in this case, but that would not only be against our goals – it would also add bias to our value estimates. We definitely don't want to propagate the gradient from the policy_loss to the critic network.

Corrected, the code looks like this:

model_action = model.action(rollout['observations'])
model_action_value = model.value(rollout['observations'], model_action)

policy_loss = -model_action_value.mean()

# Gradient of the policy loss with respect to the selected actions only
model_action_grad = torch.autograd.grad(policy_loss, model_action)[0]

# Backpropagate the actor loss to the actor network only
model_action.backward(gradient=model_action_grad)

First, we calculate the gradient of the loss function with respect to the selected action, and then we propagate that gradient further down into the actor network only.
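
An equivalent way to achieve the same effect, if you prefer a single backward() call, is to temporarily freeze the critic's parameters while computing the actor loss. A minimal sketch, assuming a hypothetical critic_parameters() accessor on the model (not part of the code above):

# Freeze the critic so that backward() only accumulates gradients in the actor;
# gradients still flow *through* the critic to reach the selected actions.
for p in model.critic_parameters():
    p.requires_grad_(False)

model_action = model.action(rollout['observations'])
policy_loss = -model.value(rollout['observations'], model_action).mean()
policy_loss.backward()

# Unfreeze the critic before computing its own value loss
for p in model.critic_parameters():
    p.requires_grad_(True)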

If I were to give you three words of advice before debugging a reinforcement learning model, these would be:

  • Understand – do your research thoroughly and understand the model you're trying to implement
  • Measure – log all the metrics that you can
  • Benchmark – Benchmark your new models against your previous models and (if possible) against an independent implementation from someone else

Benchmark results

Even though I've devoted a lot of time and a lot of words in this article to the topic of debugging reinforcement learning models, I'm still not 100% certain my implementations are flawless. Bugs are hard to find, the whole problem setting is noisy, and there is just no definitive test to tell whether what you have is correct or not. There is a lot of gray area between completely broken and perfectly good.

The closest thing I found were the benchmark results published by OpenAI for their baseline implementations on various Atari[6] and MuJoCo[7] environments. I decided to evaluate my models with hyperparameters chosen to be as close as possible to theirs and compare the results.

Benchmarks - Atari

Below you can find a table with the average rewards of the last 100 episodes of training on a selected set of Atari environments, using 10M frames. The original OpenAI benchmark values are in the rows marked OAI, with my results in the rows marked Vel below them.

As far as I know, the OpenAI benchmark results were averaged from six independent runs with different seeds. I did the same, but discarded the highest and the lowest result as potential outliers and calculated the mean from the remaining four runs.
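
In code, this averaging scheme amounts to a simple trimmed mean; a small sketch (the numbers are made up):

def trimmed_mean(run_results):
    # drop the best and the worst of the six runs, average the remaining four
    trimmed = sorted(run_results)[1:-1]
    return sum(trimmed) / len(trimmed)

print(trimmed_mean([310.0, 455.0, 401.0, 395.0, 512.0, 250.0]))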

Model        Breakout   Seaquest   Enduro   Space Invaders      Qbert    Pong   Beam Rider
A2C – OAI      289.90   1,737.20     0.00           727.32   4,461.29   18.65     2,469.15
A2C – Vel      430.00   1,760.00     0.00         1,000.00  12,000.00   19.50     4,091.00
PPO – OAI      236.90   1,505.40   686.19           959.50  14,234.75   20.39     1,832.95
PPO – Vel      262.00   1,540.00   730.00           956.00  15,012.00   20.56     2,147.00
TRPO – OAI      18.00     834.00    37.00           548.00   3,285.62   16.95       766.36
TRPO – Vel      53.15   1,031.00     0.00           494.00   7,504.00   17.44     1,881.00
ACER – OAI     439.33   1,733.13     0.00         1,382.00  16,234.75   20.03     5,040.98
ACER – Vel     422.00   1,721.00     0.00         1,308.00  16,652.00   20.65     4,222.00
DQN – OAI        1.93   1,139.20    22.20           483.35   1,010.79   -7.21     2,566.60
DQN – Vel      213.00   2,544.00   662.00           862.25   5,117.00   16.62     6,482.00

Result commentary:

  • My A2C implementation seems to be doing slightly better, which as far as I can tell is due to the differences in the RMSProp implementation between TensorFlow and PyTorch
  • PPO results I consider to be in line
  • The OpenAI baselines TRPO implementation has a particularly poor parametrization chosen for the Atari environments. I decided to use a parametrization closer to what the other algorithms use. Because of that, my results for TRPO on Atari are generally better. Enduro seems to be an outlier there, but I was not able to replicate the score they reported when running their code locally. It may be that they got one lucky seed out of six runs, which gave them 37 points on average in the benchmark.
  • The ACER results are very close. My implementation underperforms slightly on the Beam Rider environment. Again, when I ran the OpenAI baselines code locally, I was getting lower results than what is shown in the table.
  • DQN, in the same way as TRPO, has a very poor parameter set chosen in the OpenAI baselines repository. I chose a parameter set much closer to the DeepMind Nature publication, which gave my implementation much better results.

Result charts:

Benchmarks - MuJoCo

Below you can find a table with the average rewards of the last 100 episodes of training on a selected set of MuJoCo environments, using 1M frames. The original OpenAI benchmark values are in the rows marked OAI, with my results in the rows marked Vel below them. Averaging was done in the same way as for the Atari environments.

Model        HalfCheetah    Hopper   InvertedPendulum   Swimmer   InvertedDoublePendulum   Reacher   Walker2d
TRPO – OAI      1,289.70  1,912.90             905.10     94.96                 6,731.63     -4.82   2,342.63
TRPO – Vel      1,094.00  1,821.00             887.00    110.00                 7,301.00     -5.00   2,407.00
PPO – OAI       1,668.58  2,316.16             809.43    111.19                 7,102.91     -6.71   3,424.95
PPO – Vel       2,139.00  2,172.00             858.00     80.00                 6,425.00     -6.60   2,952.00

My results are quite in line with the OpenAI results on the MuJoCo environments as well. In a few cases the results reported by OpenAI seem to be slightly higher than what I was able to replicate locally, but it is entirely possible that they had a lucky set of seeds during the evaluation rounds.

Result charts:

After all this hard work we can watch our agents play:

Infrastructure setup

As I try to conduct my research in the open, I've open-sourced all of my code in a GitHub repository[2]. Since that's only half of a research setup, I'm also open sourcing the infrastructure setup that I quickly hacked together to run experiments in the cloud[8]. Maybe it'll be useful to someone else trying to do similar things some day.

Out of the many things that I'm not, I'm for sure not a DevOps developer, so I went for something simple and reliable. To minimize deployment issues, I dockerized my library[9]. I downloaded a few Ansible roles to set up Docker and Python on a remote machine, and I was ready to go.

To run the experiments, as a first step I copied the hyperparameter config files for all the models to the remote machine. Next, I mounted the experiment directory into a Docker container and had a small Python script spawn the experiments one by one, as sketched below. Each experiment stored its results in a MongoDB database I'm running on MongoDB Atlas.
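
A hypothetical runner sketch (the entry point and paths are placeholders, not the actual script):

import pathlib
import subprocess

EXPERIMENT_DIR = pathlib.Path("/experiments")  # directory mounted into the container

for config_path in sorted(EXPERIMENT_DIR.glob("*.yaml")):
    print(f"Running experiment: {config_path.name}")
    # 'train.py' stands in for whatever command launches a single training run
    subprocess.run(["python", "train.py", str(config_path)], check=True)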

Each cloud machine was created with specific tags, which were read by the Ansible playbooks to decide which experiments that particular machine should run. Here is a sample script I used to spawn worker machines on GCP:

gcloud compute instances create rl-worker-0015 --source-instance-template rl-04cpu --zone=europe-west1-b --labels experiment=breakout_acer,algo=acer
gcloud compute instances create rl-worker-0016 --source-instance-template rl-04cpu --zone=europe-west1-b --labels experiment=qbert_acer,algo=acer
gcloud compute instances create rl-worker-0017 --source-instance-template rl-04cpu --zone=europe-west1-b --labels experiment=beamrider_acer,algo=acer
gcloud compute instances create rl-worker-0018 --source-instance-template rl-04cpu --zone=europe-west1-b --labels experiment=enduro_acer,algo=acer
gcloud compute instances create rl-worker-0019 --source-instance-template rl-04cpu --zone=europe-west1-b --labels experiment=pong_acer,algo=acer
gcloud compute instances create rl-worker-0020 --source-instance-template rl-04cpu --zone=europe-west1-b --labels experiment=seaquest_acer,algo=acer
gcloud compute instances create rl-worker-0021 --source-instance-template rl-04cpu --zone=europe-west1-b --labels experiment=spaceinvaders_acer,algo=acer

I was using a dynamic Ansible inventory plugin to load the list of live machines directly from GCP.

Closing remarks

Working on Vel[2] has been a great experience for me and the first step of a very exciting research project. My work on Deep Reinforcement Learning in sparse reward environments has been sponsored by the AI Grant[10], which gave me access to a much larger pool of compute resources than I would have had otherwise. I'm very grateful for that, as I'm only starting to understand how central a role compute plays in modern AI research.

If you are in a position where you're able to sponsor independent, open source AI research with compute resources, I would be very grateful if you got in touch.

In the meantime, please stay tuned for the follow-up articles about solving sparse-reward environments using reinforcement learning.

References