"I choose this restaurant because they have vegan sandwiches" could be a typical explanation we would expect from a human. However, current Reinforcement Learning (RL) techniques are not able to provide such explanations, when trained on raw pixels. RL algorithms for state-of-the-art benchmark environments are based on neural networks, which lack interpretability, because of the very factor that makes them so versatile – they have many parameters and intermediate representations. Enforcing safety guarantees is important when deploying RL agents in the real world, and guarantees require interpretability of the agent. Humans use short explanations that capture only the essential parts and often contain few causes to explain an effect. In our thesis, we address the problem of making RL agents understandable by humans. In addition to the safety concerns, the quest to mimic human-like reasoning is of general scientific interest, as it sheds light on the easy problem of consciousness. The problem of providing interpretable and simple causal explanations of agent's behavior is connected to the problem of learning good state representations. If we lack such a representation, any reasoning algorithm's outputs would be useless for interpretability, since even the "referents" of the "thoughts" of such a system would be obscure to us. One way to define simplicity of causal explanations via the sparsity of the Causal Model that describes the environment: the causal graph has the fewest edges connecting causes to their effects. For example, a model for choosing the restaurant that only depends on the cause "vegan" is simpler and more interpretable than a model that looks at each pixel of a photo of the menu of a restaurant, and possibly relies as well on spurious correlations, such the style of the menu. In this thesis, we propose a framework "CauseOccam" for model-based Reinforcement Learning where the model is regularized for simplicity in terms of sparsity of the causal graph it corresponds to. The framework contains a learned mapping from observations to latent features, and a model predicting latent features at the next time-steps given ones from the current time-step. The latent features are regularized with the sparsity of the model, compared to a more traditional regularization on the features themselves, or via a hand-crafted interpretability loss. To achieve sparsity, we use discrete Bernoulli variables with gradient estimation, and to find the best parameters, we use the primal-dual constrained formulation to achieve a target model quality. The novelty of this work is in learning jointly a sparse causal graph and the representation taking pixels as the input on RL environments. We test this framework on benchmark environments with non-trivial high-dimensional dynamics and show that it can uncover the causal graph with the fewest edges in the latent space. We describe the implications of our work for developing priors enforcing interpretability.