ML algorithms need lots of data and are prone to catastrophic forgetting. We present a new method for continual few-shot learning that brings us closer to the way humans learn: sample-efficient, with long-term retention. Paper: https://arxiv.org/abs/2301.04584
Thread below:
Consider the problem of learning from a sequence of T tasks, each described by only a few labeled samples, such that you don’t forget previous tasks while learning new ones.
To solve this problem, we used the recently proposed HyperTransformer (HT, https://arxiv.org/abs/2201.04182), a Transformer-based hypernetwork that generates CNN weights directly from the few-shot task description.
Our idea is simple: recursively reuse the weights generated for the previously learned task as input to the HT for the next task. The CNN weights themselves then act as a representation of all previously learned tasks.
Continual HyperTransformer is trained to update these weights so that the new task is learned without forgetting past tasks. Unlike many other continual learning methods, we do not rely on replay buffers, special regularization, or task-dependent architectural changes.
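The recursive reuse described above can be sketched in a few lines. This is a minimal toy illustration, not the paper's implementation: `generate_weights` is a hypothetical stand-in for the real HyperTransformer, and the update rule inside it is invented purely to make the loop runnable.

```python
def generate_weights(prev_weights, support_set):
    """Toy stand-in for the HyperTransformer: produces new CNN weights
    from the previous weights and the current task's few-shot samples.
    (The real HT is a Transformer; this arithmetic is illustrative only.)"""
    task_signal = sum(support_set) / len(support_set)
    return [w + 0.1 * task_signal for w in prev_weights]

def continual_ht(tasks, init_weights):
    """The recursive idea: the weights generated for task t are fed back
    as input when generating weights for task t+1, so the weights
    themselves accumulate a representation of all tasks seen so far."""
    weights = init_weights
    for support_set in tasks:
        weights = generate_weights(weights, support_set)
    return weights
```

The key property is that nothing else is stored between tasks: no replay buffer, no per-task parameters, just the generated weights passed forward.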
We have also replaced the fixed-dimensional cross-entropy loss with a more flexible Prototypical loss. This allows us to project an increasing number of classes from different tasks to the same embedding space without changing the architecture of the generated CNN.
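A quick sketch of why the Prototypical loss helps here: each class is represented by the mean of its sample embeddings, and a query is labeled by its nearest prototype. Adding classes from a new task just adds prototypes; the embedding network's output dimension never changes. (Hypothetical helper names; the distance and averaging follow the standard prototypical-networks recipe, not code from the paper.)

```python
def prototypes(embeddings_by_class):
    """Mean embedding per class. A new class adds one more prototype,
    leaving the generated CNN's architecture untouched."""
    return {c: [sum(dim) / len(dim) for dim in zip(*embs)]
            for c, embs in embeddings_by_class.items()}

def classify(query, protos):
    """Predict the class whose prototype is nearest to the query
    embedding under squared Euclidean distance."""
    def sq_dist(p):
        return sum((q - v) ** 2 for q, v in zip(query, p))
    return min(protos, key=lambda c: sq_dist(protos[c]))
```

With a fixed-size cross-entropy head, adding classes would require changing the classifier layer; here, only the prototype dictionary grows.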
We tested the model on three different use cases. First, when classes do not change between tasks (i.e. each task acts like a mini-batch), we showed that running a 5-shot problem continually (five tasks, one example at a time) is comparable to running it as a single 5-shot task!
Second, when classes for all the tasks come from the same distribution, we observed positive backward transfer: the accuracy on past tasks improves after learning a subsequent task.
Finally, our model can also learn when each task has its own semantic meaning and comes from a separate distribution. It performs much better than the baseline that trains a single embedding for all the tasks.
This work is done with my wonderful colleagues Andrey Zhmoginov and Mark Sandler from Google Research. Please check out the paper at https://arxiv.org/abs/2301.04584 and don’t hesitate to reach out if you have any questions.