Analysis and metrics for continual learning and continual meta-learning

PI: Richard Zemel
Co-PIs: Kathy McKeown, Columbia; Andreas Tolias, Stanford; Kim Stachenfeld, Columbia; Toni Pitassi, Columbia; David Schwab, CUNY; Mohammadreza Davoodi, Memphis

Abstract

Project 1. Analysis and metrics for continual learning: As we develop new objectives and models for continual learning over temporal and sequential tasks, we need ways to evaluate how well they work in real-world contexts. For example, if we are learning over temporal streams of data (e.g., news), is the model able to determine when a contradiction in its memory/knowledge occurs? Is it able to distinguish between contradictions that are slight variations on its previous knowledge and do not require forgetting a previously learned fact (e.g., when the number of victims in a hurricane goes up by one) and contradictions that require more drastic changes in the knowledge store (e.g., a new congressman is elected and the old one resigns)? Is it able to recognize changes that are trends over time (e.g., rising floodwaters) and record the trend? In sequential learning across domains, does the model learn what humans might learn in the same situation? For example, after reading a series of scientific articles, does the model learn about experimental methods? Does it learn probabilistic theory? Does it learn very basic facts about scientific articles (e.g., that they begin with an abstract, have a methods section, and end with a conclusion)? And, finally, when we discover gaps in the models’ ability to avoid forgetting and to identify when knowledge has changed, what changes to the objectives and models that we develop can help improve performance? More elaborate benchmarks and analyses are needed to quantify these phenomena, either by measuring model responses or by examining latent representations, as in mechanistic interpretability.
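
As a point of reference for what quantifying forgetting from model responses can look like, the following is a minimal sketch of two metrics commonly used in the continual learning literature: average final accuracy and average forgetting. It is not the project's own metric. The function names, the accuracy-matrix layout (row i is the evaluation after training on task i, column j is the accuracy on task j), and all numbers are hypothetical and chosen only for illustration.

```python
from typing import List


def average_accuracy(acc: List[List[float]]) -> float:
    """Mean accuracy over all tasks, measured after training on the last task."""
    final = acc[-1]
    return sum(final) / len(final)


def average_forgetting(acc: List[List[float]]) -> float:
    """Mean drop from each earlier task's best past accuracy to its final accuracy.

    acc[i][j] is the accuracy on task j after training on tasks 0..i.
    The last task is excluded because nothing has been trained after it.
    """
    n = len(acc)
    if n < 2:
        return 0.0
    drops = []
    for j in range(n - 1):
        # Best accuracy on task j at any point from when it was learned
        # up to (but not including) the final training stage.
        best_before = max(acc[i][j] for i in range(j, n - 1))
        drops.append(best_before - acc[-1][j])
    return sum(drops) / len(drops)


# Hypothetical accuracy matrix for three sequential tasks (illustrative only).
acc = [
    [0.90, 0.10, 0.10],  # after training on task 0
    [0.70, 0.85, 0.10],  # after training on task 1
    [0.60, 0.80, 0.88],  # after training on task 2
]

print(f"average accuracy:   {average_accuracy(acc):.3f}")
print(f"average forgetting: {average_forgetting(acc):.3f}")
```

Aggregate scores of this kind do not separate the cases raised above, such as a minor update to a fact versus a change that should overwrite earlier knowledge, which is one reason more elaborate benchmarks are needed.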

Project 2. Continual meta-learning. Currently, the LLM development lifecycle consists of three stages: pre-training on a vast, general dataset condensed from the internet; fine-tuning on specialized, task-specific data; and aligning responses to human preferences via reinforcement learning from human feedback (RLHF). However, fine-tuning causes pre-trained LLMs to lose their few-shot generalization ability, that is, the ability to perform new tasks from only a few examples. Additionally, fine-tuned LLMs catastrophically forget their specialized task and alignment when they are subsequently fine-tuned on data from new tasks. This makes it difficult to update and repurpose LLMs without considerable effort and expenditure. These problems are often treated as independent of one another. Continual learning aims to tackle the catastrophic forgetting that occurs when a model is trained on a sequence of tasks, whereas meta-learning procedures explicitly optimize for the generalization ability of models in a setup that is akin to LLM pre-training. To tackle these challenges holistically, this project aims to develop a continual meta-learning process that trains LLMs on a sequence of tasks without catastrophic forgetting of previous tasks or loss of few-shot generalization capabilities.

Publications

In progress

Resources

In progress