Published as a conference paper at ICLR 2021

A UNIVERSAL REPRESENTATION TRANSFORMER LAYER FOR FEW-SHOT IMAGE CLASSIFICATION

Lu Liu 1,2∗, William Hamilton 1,3†, Guodong Long 2, Jing Jiang 2, Hugo Larochelle 1,4†
1 Mila, 2 Australian AI Institute, UTS, 3 McGill University, 4 Google Research, Brain Team
Correspondence to [email protected]

ABSTRACT

Few-shot classification aims to recognize unseen classes when presented with only a small number of samples. We consider the problem of multi-domain few-shot image classification, where unseen classes and examples come from diverse data sources. This problem has seen growing interest and has inspired the development of benchmarks such as Meta-Dataset. A key challenge in this multi-domain setting is to effectively integrate the feature representations from the diverse set of training domains. Here, we propose a Universal Representation Transformer (URT) layer that meta-learns to leverage universal features for few-shot classification by dynamically re-weighting and composing the most appropriate domain-specific representations. In experiments, we show that URT sets a new state-of-the-art result on Meta-Dataset. Specifically, it achieves top performance on the highest number of data sources compared to competing methods. We analyze variants of URT and present a visualization of the attention score heatmaps that sheds light on how the model performs cross-domain generalization. Our code is available at https://github.com/liulu112601/URT.

1 INTRODUCTION

Learning tasks from small data remains a challenge for machine learning systems, which show a noticeable gap compared to the ability of humans to understand new concepts from few examples. A promising direction to address this challenge is developing methods that are capable of performing transfer learning across the collective data of many tasks. Since machine learning systems generally improve with the availability of more data, a natural assumption is that few-shot learning systems should benefit from leveraging data across many different tasks and domains, even if each individual task has limited training data available.

This research direction is well captured by the problem of multi-domain few-shot classification. In this setting, training and test data span a number of different domains, each represented by a different source dataset. A successful approach in this multi-domain setting must not only address the regular challenge of few-shot classification, i.e., having only a handful of examples per class. It must also discover how to leverage (or ignore) what is learned from different domains, achieving generalization and avoiding cross-domain interference.

Recently, Triantafillou et al. (2020) proposed a benchmark for multi-domain few-shot classification, Meta-Dataset, and highlighted some of the challenges that current methods face when training data is heterogeneous. Crucially, they found that methods trained on all available domains would normally obtain improved performance on some domains at the expense of others. Following their work, progress has been made, including the design of adapted hyper-parameter optimization strategies (Saikia et al., 2020) and more flexible meta-learning algorithms (Requeima et al., 2019). Most notable is SUR (Selecting Universal Representation) (Dvornik et al., 2020), a method that relies on a so-called universal representation, extracted from a collection of pre-trained and domain-specific neural network backbones. SUR prescribes a hand-crafted feature-selection procedure to infer how to weight each backbone for the task at hand, producing an adapted representation for each task. This was shown to lead to some of the best performances on Meta-Dataset.

∗ This work was done while Lu Liu was a research intern with Mila.
† Canada CIFAR AI Chair.


In SUR, the classification procedure for each task is fixed and not learned. Thus, except for the underlying universal representation, there is no transfer learning performed with regard to how classification rules are inferred across tasks and domains. Yet, cross-domain generalization might be beneficial in that area as well, in particular when tasks have only few examples per class.

Present work. To explore this question, we propose a Universal Representation Transformer (URT) layer, which can effectively learn to transform a universal representation into task-adapted representations. The URT layer is inspired by the Transformer (Vaswani et al., 2017) and uses an attention mechanism to learn to retrieve or blend the appropriate backbones to use for each task. By training this layer across few-shot tasks from many domains, it can support transfer across these tasks.

We show that our URT layer, placed on top of a universal representation's pre-trained backbones, sets a new state-of-the-art performance on Meta-Dataset. It succeeds at outperforming SUR on 4 data sources without impairing accuracy on the others. This leads to top performance on 7 data sources when comparing to a set of competing methods. To interpret the strategy that URT learns for weighting the backbones from different domains, we visualize the attention scores for both seen and unseen domains and find that our model generates meaningful weights for the pre-trained domains. A comprehensive analysis of variants and ablations of the URT layer is provided to show the importance of its various components, notably the number of attention heads.

2 FEW-SHOT CLASSIFICATION

2.1 PROBLEM SETTING

In this section, we introduce the problem setting of few-shot classification and the formulation of meta-learning for few-shot classification. Few-shot classification aims to classify samples where only few examples are available for each class. We describe a few-shot classification task as a pair of sets: a support set S that defines the classification task and a query set Q of samples to be classified.

Meta-learning is a technique that aims to model the problem of few-shot classification as learning to learn from instances of few-shot classification tasks. The most popular way to train a meta-learning model is with episodic training. Here, tasks T = (S, Q) are sampled from a larger dataset by taking subsets of the dataset to build a support set S and a query set Q for the task. A common approach is to sample N-way-K-shot tasks, each time selecting a random subset of N classes from the original dataset and choosing only K examples for each class to add to the support set S.
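For illustration, a minimal N-way-K-shot episode sampler could look as follows; the dataset layout (a mapping from class label to a list of example indices) and the addition of per-class query examples are assumptions of this sketch, not the benchmark's exact sampler.

```python
import random

def sample_episode(examples_by_class, n_way=5, k_shot=1, m_query=15):
    """Sample one N-way-K-shot task: for each of N randomly chosen
    classes, take K support and M query examples (indices only).
    `examples_by_class` maps class label -> list of example indices."""
    classes = random.sample(sorted(examples_by_class), n_way)
    support, query = [], []
    for label, c in enumerate(classes):            # re-label classes 0..N-1
        idx = random.sample(examples_by_class[c], k_shot + m_query)
        support += [(i, label) for i in idx[:k_shot]]
        query += [(i, label) for i in idx[k_shot:]]
    return support, query
```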

The meta-learning problem can then be formulated by the following optimization:

$$\min_{\Theta}\; \mathbb{E}_{(S,Q)\sim p(\mathcal{T})}\left[L(S,Q,\Theta)\right], \qquad L(S,Q,\Theta) = \frac{1}{|Q|}\sum_{(x,y)\in Q} -\log p(y\,|\,x,S;\Theta) \;+\; \lambda\,\Omega(\Theta), \tag{1}$$

where $p(\mathcal{T})$ is the distribution of tasks, $\Theta$ are the parameters of the model, $p(y|x,S;\Theta)$ is the probability assigned by the model to label $y$ of query example $x$ (given the support set $S$), and $\Omega(\Theta)$ is an optional regularization term on the model parameters with factor $\lambda$.
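To make the objective concrete, here is a minimal sketch of the per-episode loss of Eq. (1) in PyTorch; the `model` interface (returning per-class log-probabilities conditioned on the support set) and its `regularizer()` method are hypothetical placeholders, not the paper's implementation.

```python
import torch

def episode_loss(model, support, query_x, query_y, lam=0.0):
    """Monte-Carlo estimate of Eq. (1) on one sampled task (S, Q).

    Assumes model(x, support) returns log p(y | x, S; Theta) of shape
    (|Q|, N) and model.regularizer() plays the role of Omega(Theta);
    both are hypothetical interfaces used only for illustration."""
    log_probs = model(query_x, support)                    # (|Q|, N)
    nll = -log_probs[torch.arange(len(query_y)), query_y]  # -log p(y|x,S)
    return nll.mean() + lam * model.regularizer()          # Eq. (1)
```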

Conventional few-shot classification targets the N-way-K-shot setting, where the number of classes and examples are fixed in each episode. Popular benchmarks following this approach include Omniglot (Lake et al., 2015) or benchmarks made of subsets of ImageNet, such as miniImageNet (Vinyals et al., 2016) and tieredImageNet (Ren et al., 2018). In such benchmarks, the tasks for training cover a set of classes that is disjoint from the classes in the test set of tasks. However, with the training and test tasks coming from a single dataset/domain, the distribution of tasks found in either set is similar and lacks variability, which may be unrealistic in practice.

It is in this context that Triantafillou et al. (2020) proposed Meta-Dataset, as a further step towards large-scale, multi-domain few-shot classification. Meta-Dataset includes ten datasets (domains), with eight of them available for training. Additionally, each task sampled in the benchmark varies in the number of classes N, with each class also varying in the number of shots K. As in all few-shot learning benchmarks, the classes used for training and testing do not overlap.


2.2 BACKGROUND AND RELATED WORK

Meta-Learning A promising approach for few-shot classification is to use meta-learning to more directly train a model, in an end-to-end way, to learn to perform few-shot classification. The two most popular methods are Prototypical Networks (Snell et al., 2017) and Model-Agnostic Meta-Learning (MAML) (Finn et al., 2017). Triantafillou et al. (2020) showed that Prototypical Networks and MAML can be combined by leveraging prototypes for the initialization of the output weights in the inner loop. Requeima et al. (2019) also proposed Conditional Neural Adaptive Processes (CNAPs) for few-shot classification, which can be seen as extending Prototypical Networks with a more sophisticated architecture that allows for improved task adaptation. This architecture was later improved further by Bateni et al. (2020) with Simple CNAPS, leading to one of the current best methods on Meta-Dataset. Another line of work, which leverages the idea of "transfer by fine-tuning", can be found in Appendix B.

Universal Representations In contrast, our work builds on that of Dvornik et al. (2020) and their method SUR (Selecting from Universal Representations). Bilen & Vedaldi (2017) introduced the term universal representation to refer to a representation that supports good performance in multiple domains. One proposal towards such a representation is to train different neural network backbones separately on the data of each available domain, and then simply concatenate the representations learned by each. Another is to introduce some parameter sharing between the backbones, by having a single network conditioned on the domain of provenance of each batch of training data (Rebuffi et al., 2018), e.g. using Feature-wise Linear Modulation (FiLM) (Perez et al., 2018). SUR proposes to leverage a universal representation in few-shot learning tasks with a feature selection procedure that assigns different weights to each of the domain-specific subvectors of the universal representation. The objective is to assign high weights only to the domain-specific representations that are specifically useful for each few-shot task at hand. The weights are inferred by optimizing a loss on the support set that encourages high accuracy of a nearest-centroid classifier. As such, the method does not involve any meta-learning, a choice motivated by the concern that meta-learning may struggle to generalize to domains that are dissimilar to the training domains. SUR achieved some of the best performances on Meta-Dataset. However, a contribution of our work is to provide evidence that meta-learning can actually be used to replace SUR's hand-designed inference procedure and improve performance further.

Task Adaptive Representations Another line of work tries to retrieve task-adaptive representations for each task. Task-specific representations can be conditioned on a representation of the current task (Oreshkin et al., 2018; Wang et al., 2019), projected to another space (Yoon et al., 2019), or masked based on intra-class commonality and inter-class uniqueness (Li et al., 2019). While the representation extracted by URT is also task-adaptive, it adapts a set of pretrained backbones and can be applied to more complicated multi-domain scenarios. Wang & Hebert (2016) proposed to improve a CNN by adding extra layers and training it using unsupervised data, whereas our contribution mainly lies in composing representations rather than improving a CNN. Alet et al. (2018) introduced a modular meta-learning method, which learns a repertoire of modules that serve as nodes to construct a tree structure to solve a new robotics-related task. Comparatively, URT is a one-for-all layer that does not need to construct different module structures for each task.

Transformer Networks Our meta-learning approach to leveraging universal representations is inspired directly by Transformer networks (Vaswani et al., 2017). Our model structure is inspired by the structure of the dot-product self-attention in the Transformer, which we adapt here to multi-domain few-shot learning by designing appropriate parametrizations for queries, keys and values. Self-attention was explored in the single-domain training regime by Ye et al. (2020) and Liu et al. (2019b;a; 2020), however for a different purpose, where the representation of each individual example in a task's support set is influenced by all other examples. Rather than using self-attention between individual examples in the support set, our model uses self-attention to select between different domain-specific backbones.

3 UNIVERSAL REPRESENTATION TRANSFORMER LAYER

In this section, we describe our proposed URT layer, which uses meta-learning episodic training to learn how to combine the domain-specific backbones of a universal representation for any given few-shot classification task. The URT layer can be built on top of any set of pretrained backbones, without further costly fine-tuning of them. More details on how to train multiple domain-specific backbones can be found in Appendix C.


Figure 1: Illustration of how a single-head URT layer uses a universal representation to produce a task-specific representation. This example assumes the use of four backbones, with each color illustrating their domain-specific sub-vector representation in the universal representation.


Conceptually, the proposed model views the support set $S$ of a task as providing information on how to query and retrieve, from the set $\{r_i\}_{i=1}^m$ of $m$ pre-trained backbones, the most appropriate backbones to build an adapted representation $\phi$ for the task.

We would like the model to support a variety of strategies for retrieving backbones. For example, it might be beneficial for the model to retrieve a single backbone from the set, especially if the domain of the given task matches perfectly that of a domain found in the training set. Alternatively, if some of the training domains benefit from much more training data than others, a better strategy might be to attempt some cross-domain generalization towards the few-shot learning task by blending many backbones together, even if none matches the domain of the task perfectly.

This motivates us to use dot-product self-attention, inspired by layers of Transformer networks (Vaswani et al., 2017). For this reason, we refer to our model as a Universal Representation Transformer (URT) layer. Additionally, since each class of the support set might require a different strategy, we perform attention separately for each class and its support subset $S_c = \{x \mid (x, y) \in S \text{ and } y = c\}$.

3.1 SINGLE-HEAD URT LAYER

We start by describing an URT layer consisting of a single attention head. An illustration of a single-head URT layer is shown in Figure 1. Let $r_i(x)$ be the output vector of the backbone for domain $i$. We then write the universal representation as

$$r(x) = \mathrm{concat}(r_1(x), \ldots, r_m(x)). \tag{2}$$

This representation provides a natural starting point to obtain a representation of a support set class. Specifically, we will note

$$r(S_c) = \frac{1}{|S_c|} \sum_{x \in S_c} r(x) \tag{3}$$

as the representation for the set $S_c$. From this, we can describe the URT layer by defining the queries¹, keys, the attention mechanism and the output of the layer:

Queries $q_c$: For each class $c$, we obtain a query through $q_c = W_q\, r(S_c) + b_q$, where the learnable query linear transformation is represented by the matrix $W_q$ and bias $b_q$.

Keys $k_{i,c}$: For each domain $i$ and class $c$, we define keys as $k_{i,c} = W_k\, r_i(S_c) + b_k$, using a learnable linear transformation $W_k$ and bias $b_k$, and where $r_i(S_c) = \frac{1}{|S_c|} \sum_{x \in S_c} r_i(x)$, using a similar notation as for $r(S_c)$.

¹ Unable to avoid the unfortunate double usage of the term "query" due to conflicting conventions, we highlight the difference between the query sets $Q$ of few-shot tasks and the queries $q_c$ of an attention mechanism.


Algorithm 1 Training of URT layer

Input: number of tasks τtotal, m pre-trained backbones
1: for τ ∈ {1, . . . , τtotal} do
2:   Sample a few-shot task T with support set S and query set Q
3:   # Infer the adapted representation for the task from S
4:   For each class, obtain its representation using the m pre-trained backbones as in Eq. (3)
5:   Obtain attention scores using Eqs. (4, 5) for each head, using the support set S
6:   # Use the adapted representation to predict labels in Q from the support set S
7:   Compute the adapted representation of examples in S and Q as in Eqs. (6, 7)
8:   Compute probabilities of the labels of examples in Q using a Prototypical Network as in Eq. (9)
9:   Compute the loss as in Eqs. (1, 8) and perform a gradient descent step on the URT parameters Θ
10: end for
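Below is a PyTorch-style rendering of one pass of Algorithm 1, restricted to a single head for clarity. The `urt` module and the helpers `prototype_log_probs` and `head_diversity_penalty` refer to the sketches given alongside Eqs. (4)-(6), (8) and (9) below; all interfaces are assumptions of this sketch, not the released code.

```python
import torch

def train_urt(urt, backbones, sample_task, num_tasks=10_000, lam=0.1):
    """One-head sketch of Algorithm 1: episodic training of the URT
    parameters only, with the m backbones kept frozen throughout.
    `sample_task` is assumed to yield batched episode tensors."""
    opt = torch.optim.Adam(urt.parameters(), lr=0.01)
    for _ in range(num_tasks):
        (xs, ys), (xq, yq) = sample_task()           # support / query batches
        with torch.no_grad():                        # frozen backbones, no grads
            rs = torch.stack([b(xs) for b in backbones], dim=1)  # (|S|, m, d)
            rq = torch.stack([b(xq) for b in backbones], dim=1)  # (|Q|, m, d)
        alpha = urt(rs, ys)                          # attention scores, Eqs. (4)-(5)
        phi_s = (alpha[:, None] * rs).sum(dim=1)     # adapted features, Eq. (6)
        phi_q = (alpha[:, None] * rq).sum(dim=1)
        log_p = prototype_log_probs(phi_s, ys, phi_q)            # Eq. (9)
        loss = -log_p[torch.arange(len(yq)), yq].mean()          # Eq. (1)
        loss = loss + lam * head_diversity_penalty(alpha.unsqueeze(0))  # Eq. (8)
        opt.zero_grad(); loss.backward(); opt.step()
```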

Attention scores $\alpha_i$: as for regular Transformer layers, we use scaled dot-product attention

$$\alpha_{i,c} = \frac{\exp(\beta_{i,c})}{\sum_{i'} \exp(\beta_{i',c})}, \qquad \beta_{i,c} = \frac{q_c^\top k_{i,c}}{\sqrt{l}}, \tag{4}$$

where $l$ is the dimensionality of the keys and queries. Then, these per-class scores are aggregated to obtain scores for the full support set by averaging

$$\alpha_i = \frac{\sum_c \alpha_{i,c}}{N}. \tag{5}$$

Equipped with these attention scores, the URT layer can now produce an adapted representation for the task (for the support and query set examples) by computing

$$\phi(x) = \sum_i \alpha_i\, r_i(x). \tag{6}$$

As we can see, this approach has the flexibility of either selecting a single domain-specific backbone (by assigning $\alpha_i = 1$ for a single domain) or blending different domains together (by having $\alpha_i \gg 0$ for multiple backbones).
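Putting Eqs. (2) to (6) together, a single-head URT layer admits a compact implementation. The sketch below assumes all $m$ backbones share a common feature dimension $d$ (so $r(x)$ has dimension $m \cdot d$) and that class labels are re-indexed 0..N-1 within the episode; the names and interfaces are ours, not the released code.

```python
import math
import torch
import torch.nn as nn

class SingleHeadURT(nn.Module):
    """Minimal sketch of the single-head URT layer (Eqs. 2-6)."""

    def __init__(self, feat_dim, key_dim, num_backbones):
        super().__init__()
        self.q_proj = nn.Linear(num_backbones * feat_dim, key_dim)  # W_q, b_q
        self.k_proj = nn.Linear(feat_dim, key_dim)                  # W_k, b_k
        self.key_dim = key_dim

    def forward(self, feats, labels):
        # feats: (n, m, d) per-domain features r_i(x) for n support examples
        # labels: (n,) class indices in {0, ..., N-1}
        n_classes = int(labels.max()) + 1
        # Per-class set representations r_i(S_c), stacked over classes (Eq. 3)
        r_sc = torch.stack([feats[labels == c].mean(0)
                            for c in range(n_classes)])   # (N, m, d)
        q = self.q_proj(r_sc.flatten(1))                  # (N, key_dim): r(S_c) -> q_c
        k = self.k_proj(r_sc)                             # (N, m, key_dim): k_{i,c}
        beta = (k @ q.unsqueeze(-1)).squeeze(-1) / math.sqrt(self.key_dim)  # (N, m)
        alpha = beta.softmax(dim=-1).mean(dim=0)          # Eqs. (4)-(5): average over classes
        return alpha                                      # (m,) weights over backbones

# Usage (Eq. 6): phi_x = (alpha[:, None] * feats_x).sum(0) for feats_x of shape (m, d).
```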

3.2 MULTI-HEAD URT LAYER

The URT layer described so far can only learn to retrieve a single backbone (or blending of backbones). Yet, it might be beneficial to retrieve multiple different (blended) backbones, especially for a few-shot task that includes many classes of varying complexity.

Thus, to achieve such diversity in the adapted representation, we also consider URT layers with multiple heads, i.e. where each head corresponds to the calculation of Equation 6 and each head has its own set of parameters $(W_q, b_q, W_k, b_k)$. Denoting each head now as $\phi^h$, a multi-head URT layer then produces as its output the concatenation of all of its heads:

$$\phi(x) = \mathrm{concat}(\phi^1(x), \ldots, \phi^H(x)). \tag{7}$$

Empirically, we found that the randomness in the initialization of the head weights alone did not lead to unique and complementary heads, so, inspired by Lin et al. (2017), we add a regularizer to avoid duplication of the attention scores:

$$\Omega(\Theta) = \left\| A A^\top - I \right\|_F^2, \tag{8}$$

where $\|\cdot\|_F$ is the Frobenius norm of a matrix and $A \in \mathbb{R}^{H \times m}$ is the matrix of attention scores, with row $A_h$ being the vector of all scores $\alpha_i$ for head $h$. The identity matrix $I$ regularizes each set of attention scores to be more focused, so that multiple heads can attend to different domain-specific backbones.
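A minimal sketch of this penalty follows, assuming the per-head attention scores have been stacked into a matrix of shape (H, m):

```python
import torch

def head_diversity_penalty(attn_scores):
    """Eq. (8): penalize overlapping heads via ||A A^T - I||_F^2,
    where row h of attn_scores holds the scores alpha_i of head h."""
    num_heads = attn_scores.shape[0]
    eye = torch.eye(num_heads, device=attn_scores.device)
    gram = attn_scores @ attn_scores.t()   # (H, H) overlap matrix
    return ((gram - eye) ** 2).sum()       # squared Frobenius norm
```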

3.3 TRAINING STRATEGY

We train the representations produced by the URT layer by following the approach of Prototypical Networks (Snell et al., 2017), where the probability of label $y$ for a query example $x$ given the support set of a task is modeled as:

$$p(y = c\,|\,x, S; \Theta) = \frac{\exp(-d(\phi(x), p_c))}{\sum_{c'=1}^{N} \exp(-d(\phi(x), p_{c'}))}, \tag{9}$$

where $d$ is a distance metric and $p_c = \frac{1}{|S_c|}\sum_{x \in S_c} \phi(x)$ corresponds to the centroid of class $c$, referred to as its prototype. We use (negative) cosine similarity as the distance. The full training algorithm is presented in Algorithm 1.
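In code, Eq. (9) with the (negative) cosine similarity distance reduces to a softmax over class similarities. Below is a minimal sketch over URT-adapted features; the function name and the (n, d) input shapes are assumptions of the sketch.

```python
import torch
import torch.nn.functional as F

def prototype_log_probs(phi_support, labels, phi_query):
    """Log of Eq. (9): class prototypes are centroids of the adapted
    support features, and -d is the cosine similarity to each prototype."""
    n_classes = int(labels.max()) + 1
    protos = torch.stack([phi_support[labels == c].mean(0)
                          for c in range(n_classes)])      # p_c, shape (N, d)
    sims = F.cosine_similarity(phi_query.unsqueeze(1),     # (|Q|, N): -d(phi(x), p_c)
                               protos.unsqueeze(0), dim=-1)
    return sims.log_softmax(dim=-1)                        # log p(y = c | x, S)
```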


Table 1: Test accuracy (mean±CI%95) over 600 few-shot tasks. URT and the most recent methods, listed in the first column, are compared on the data sources of Meta-Dataset (Triantafillou et al., 2020), listed in the first row. The numbers in bold have intersecting confidence intervals with the most accurate method.

Method | ILSVRC | Omniglot | Aircraft | Birds | Textures | Quick Draw | Fungi | VGG Flower | Traffic Signs | MSCOCO | avg. rank
MAML | 37.8±1.0 | 83.9±1.0 | 76.4±0.7 | 62.4±1.1 | 64.1±0.8 | 59.7±1.1 | 33.5±1.1 | 79.9±0.8 | 42.9±1.3 | 29.4±1.1 | 6.4
ProtoNet | 44.5±1.1 | 79.6±1.1 | 71.1±0.9 | 67.0±1.0 | 65.2±0.8 | 64.9±0.9 | 40.3±1.1 | 86.9±0.7 | 46.5±1.0 | 39.9±1.1 | 5.7
ProtoMAML | 46.5±1.1 | 82.7±1.0 | 75.2±0.8 | 69.9±1.0 | 68.3±0.8 | 66.8±0.9 | 42.0±1.2 | 88.7±0.7 | 52.4±1.1 | 41.7±1.1 | 4.1
CNAPs | 50.8±1.1 | 91.7±0.5 | 83.7±0.6 | 73.6±0.9 | 59.5±0.7 | 74.7±0.8 | 50.2±1.1 | 88.9±0.5 | 56.5±1.1 | 39.4±1.1 | 3.6
SUR | 56.1±1.1 | 93.1±0.5 | 84.6±0.7 | 70.6±1.0 | 71.0±0.8 | 81.3±0.6 | 64.2±1.0 | 82.8±0.7 | 53.4±1.0 | 50.1±1.0 | 2.3
SimpleCNAPS | 56.5±1.1 | 91.9±0.6 | 83.8±0.7 | 76.1±0.8 | 70.0±0.8 | 78.3±0.7 | 49.1±1.2 | 91.3±0.6 | 59.2±1.0 | 42.4±1.1 | 2.2
URT (Ours) | 55.7±1.0 | 94.4±0.4 | 85.8±0.6 | 76.3±0.8 | 71.8±0.7 | 82.5±0.6 | 63.5±1.0 | 88.2±0.6 | 51.1±1.1 | 52.2±1.1 | 1.5


4 EXPERIMENTS

In this section, we seek to answer three key experimental questions:

Q1 How does URT compare with the previous state-of-the-art on Meta-Dataset for multi-domain few-shot classification?

Q2 Do the URT attention heads generate interpretable and meaningful attention scores?

Q3 Does the URT layer provide consistent benefits, even when pre-trained backbones are trained in different ways?

In addition, we investigate the architectural choices made, such as our models for keys/queries and their regularization, and study their contribution to achieving strong performance with URT.

4.1 DATASETS AND SETUP

We test our methods on the large-scale few-shot learning benchmark Meta-Dataset (Triantafillou et al., 2020). It consists of ten datasets with various data distributions across different domains, including natural images (Birds, Fungi, VGG Flower), hand-written characters (Omniglot, Quick Draw), and human-created objects (Traffic Signs, Aircraft). Among the ten datasets, eight provide data that can be used during training, validation, or testing (with each class assigned to only one of those sets), while two datasets are solely used for testing. Following Bateni et al. (2020) and Requeima et al. (2019), we also report results on MNIST (LeCun et al., 1998), CIFAR10 and CIFAR100 (Krizhevsky et al., 2009) as additional unseen test datasets. Following Triantafillou et al. (2020), few-shot tasks are sampled with a varying number of classes N, a varying number of shots K, and class imbalance. The performance is reported as the average accuracy over 600 sampled tasks. More details on Meta-Dataset can be found in Triantafillou et al. (2020).

The domain-specific backbones are pre-trained following the setup in Dvornik et al. (2020). Then, we freeze the backbones and train the URT layer for 10,000 episodes, with an initial learning rate of 0.01 and a cosine learning rate scheduler. Following Chen et al. (2020), the training episodes have a 50% probability of coming from the ImageNet data source. Since different pre-trained backbones may produce representations with different vector norms, we normalize the outputs of the backbones as in Dvornik et al. (2020). URT is trained with a parameter weight decay of 1e-5 and with a regularization factor λ = 0.1. The number of heads (H in Equation 7) is set to 2 and the dimension of the keys and queries (l in Equation 4) is set to 1024.


Figure 2: Average attention scores generated by URT with two heads. Rows correspond to the domain of the test tasks and the columns correspond to the pre-trained backbones $r_i(x)$ trained on the eight training domains.

We choose the hyper-parameters based on performance on the validation set. Details of the hyper-parameter selection, and how performance is influenced by it, are outlined in Section 4.5. We find that the bottleneck when training URT is extracting features from the CNN backbones. Since the backbones are frozen during URT training, caching the extracted features for each episode significantly speeds up the training procedure, from days to around 2 hours.
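A minimal sketch of this caching step follows; the episode container and file layout are assumptions of the sketch rather than the released code.

```python
import torch

def cache_episode_features(backbones, episodes, path):
    """Run the frozen backbones once over every sampled episode and
    store the features, so URT training never re-runs the CNNs."""
    cached = []
    with torch.no_grad():
        for (xs, ys), (xq, yq) in episodes:
            cached.append({
                "support": torch.stack([b(xs) for b in backbones], 1).cpu(),
                "support_labels": ys.cpu(),
                "query": torch.stack([b(xq) for b in backbones], 1).cpu(),
                "query_labels": yq.cpu(),
            })
    torch.save(cached, path)   # later: episodes = torch.load(path)
```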

4.2 COMPARISON WITH PREVIOUS APPROACHES

Table 1 presents a comparison of URT with SUR, as well as other baselines based on transfer learning by fine-tuning (Saikia et al., 2020) or meta-learning (Prototypical Networks (Snell et al., 2017), first-order MAML (Finn et al., 2017), ProtoMAML (Triantafillou et al., 2020), CNAPs (Requeima et al., 2019)), and Simple CNAPS (Bateni et al., 2020).

We observe in Table 1 that URT establishes a new state-of-the-art on Meta-Dataset, achieving the top performance on 8 out of the 10 data sources. When comparing to its predecessor, URT outperforms SUR on 4 datasets without compromising performance on the others, which is challenging to achieve in the multi-domain setting. Of note, the average inference time for URT is 0.04 seconds per task, compared to 0.43 for SUR, on a single V100. Thus, getting rid of the per-episode optimization procedure with our meta-trained URT layer also reduces the latency by more than 10×. More results on additional datasets can be found in Appendix A.

4.3 INTERPRETING AND VISUALIZING ATTENTION BY URT

To better understand how the URT model of Section 4.2 uses its two heads to build adapted representations, we visualize the attention scores produced on the test tasks of Meta-Dataset in Figure 2.

The blue (first head) and orange (second head) heatmaps summarize the values of the attention scores (Equation 5), averaged across several tasks for each test domain. Specifically, the element on row t and column i is the averaged attention score αi computed on test tasks from domain t for the backbone from domain i. Note that the last two rows are the two unseen domain datasets. We found that for datasets from the seen domains, i.e. the first eight rows, one head (right, orange) consistently puts most of its weight on the backbone pre-trained on the same domain, while the other head (left, blue) learns relatively smoother weight distributions that blend other related domains. For unseen datasets, the right head puts half of its weight on ImageNet and the left head learned to blend the representations from four backbones.


Table 2: Test accuracy (mean±CI%95) over 600 few-shot tasks. All methods use parametric network family (pf) backbones.

Dataset | SUR-pf | URT-pf | VS.
ILSVRC | 56.0 ± 1.1 | 55.5 ± 1.1 | =
Omniglot | 90.0 ± 0.8 | 90.2 ± 0.6 | =
Aircraft | 79.7 ± 0.8 | 79.8 ± 0.7 | =
Birds | 75.9 ± 0.9 | 77.5 ± 0.8 | =
Textures | 72.5 ± 0.7 | 73.5 ± 0.7 | =
Quick Draw | 76.7 ± 0.7 | 75.8 ± 0.7 | =
Fungi | 49.8 ± 1.1 | 48.1 ± 0.9 | =
VGG Flower | 90.0 ± 0.6 | 91.9 ± 0.5 | +
Traffic Signs | 52.2 ± 0.8 | 52.0 ± 1.4 | =
MSCOCO | 50.2 ± 1.1 | 52.1 ± 1.0 | =
MNIST | 93.2 ± 0.4 | 93.9 ± 0.4 | =
CIFAR10 | 66.4 ± 0.8 | 66.1 ± 0.8 | =
CIFAR100 | 57.1 ± 1.0 | 57.3 ± 1.0 | =

4.4 URT USING FILM MODULATED BACKBONES

As additional evidence of the benefit of URT on universal representations, we also present experiments based on a different set of backbone architectures. Following SUR (Dvornik et al., 2020), we consider the backbones from a parametric network family, obtained by training a base backbone on one dataset (ILSVRC) and then learning separate FiLM layers (Perez et al., 2018) for each other dataset, to modulate the backbone so it is adapted to the other domains. These backbones collectively have only 0.5% more parameters than a single backbone. More details on the backbones can be found in Appendix C.
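For concreteness, FiLM modulation amounts to a per-channel affine transform of a feature map. The minimal sketch below illustrates the idea and is not the exact modulation used in the released backbones; in the parametric family, one such (gamma, beta) set per modulated layer is learned for each domain while the shared base weights stay frozen.

```python
import torch
import torch.nn as nn

class FiLM(nn.Module):
    """Per-channel scale-and-shift modulation (Perez et al., 2018).
    Only gamma and beta are domain-specific, which is why a whole
    family of modulated backbones adds so few extra parameters."""

    def __init__(self, num_channels):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(num_channels))   # scale
        self.beta = nn.Parameter(torch.zeros(num_channels))   # shift

    def forward(self, x):   # x: (batch, C, H, W) feature map
        return self.gamma.view(1, -1, 1, 1) * x + self.beta.view(1, -1, 1, 1)
```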

A comparison between SUR and URT using these backbones (referred to as SUR-pf and URT-pf) is presented in Table 2. Once again, URT improves the performance on VGG Flower without sacrificing performance on the others.

4.5 HYPER-PARAMETER AND ABLATION STUDIES

We analyze the importance of the various components of URT's attention mechanism structure and training strategy in Table 3. First, we analyze the importance of using the support set to model queries and/or keys. To this end, we consider setting the matrices W_q / W_k of the query / key linear transformations to 0, which only leaves the bias term. We found that the support set representation is most crucial for building the keys (row "w/o W_k" in the table) and has minor benefits for queries (row "w/o W_q"). This observation is possibly related to the success of attention-based models with learnable constant queries (Liu et al., 2016; Lin et al., 2017). We also found that adding the regularizer Ω(Θ) of Equation 8 is important for some datasets, specifically VGG Flower and Birds.

Table 3: Meta-Dataset performance variation on ablations of elements of the URT layer.

Ablation | ILSVRC | Omniglot | Aircraft | Birds | Textures | Draw | Fungi | Flower | Signs | MSCOCO
w/o W_q | +0.2 | -0.2 | -0.6 | -0.1 | -0.3 | -0.2 | 0.0 | -0.2 | -0.8 | -0.1
w/o W_k | -14.2 | -2.8 | -10.7 | -18.1 | -7.6 | -9.3 | -22.4 | -3.6 | -0.26 | -10.9
w/o r(S_c) | -14.2 | -2.8 | -10.7 | -18.1 | -7.6 | -9.2 | -22.4 | -3.6 | -0.26 | -10.9
w/o Ω(Θ) | 0.0 | -0.9 | -0.4 | -3.3 | -1.2 | -0.2 | +0.3 | -9.0 | -2.0 | 0.0

An important hyper-parameter in URT is the number of heads H. We chose this hyper-parameter based on the performance on the validation set of tasks in Meta-Dataset. In Table 4, we show the validation performance of URT for a varying number of heads. As suggested by Triantafillou et al. (2020), we considered looking at the rank of the performance achieved by each choice of H on each validation domain, and taking the average across domains as a validation metric. However, since the performances when using two to four heads are similar and yield the same average rank, we instead simply consider the average accuracy as the selection criterion.



Table 4: Validation performance on Meta-Dataset using different numbers of heads.

Number of heads | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8
Average Accuracy | 74.605 | 77.145 | 76.943 | 76.984 | 76.602 | 75.906 | 75.454 | 74.473
Average Rank | 2.875 | 1.000 | 1.000 | 1.000 | 2.250 | 2.250 | 2.250 | 2.500

In general, we observe a large jump in performance when using multiple heads instead of just one. However, since the number of heads controls the capacity, predictably we also observe that having too many heads leads to overfitting.

5 CONCLUSION

We proposed the URT layer to effectively integrate representations from multiple domains and demonstrated improved performance in multi-domain few-shot classification. Notably, our URT approach sets a new state-of-the-art on Meta-Dataset and never performs worse than its predecessor (SUR), while also being 10× more efficient at inference. This work suggests that combining meta-learning with pre-trained universal representations is a promising direction for new few-shot learning methods. Specifically, we hope that future work can investigate the design of richer forms of universal representations that go beyond simply pre-training a single backbone for each domain, and develop meta-learners adapted to those settings.

REFERENCES

Ferran Alet, Tomas Lozano-Perez, and Leslie P. Kaelbling. Modular meta-learning. In Conference on Robot Learning (CoRL), 2018.

Peyman Bateni, Raghav Goyal, Vaden Masrani, Frank Wood, and Leonid Sigal. Improved few-shot visual classification. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2020.

Hakan Bilen and Andrea Vedaldi. Universal representations: The missing link between faces, text, planktons, and cat breeds. arXiv preprint arXiv:1701.07275, 2017.

Wei-Yu Chen, Yen-Cheng Liu, Zsolt Kira, Yu-Chiang Frank Wang, and Jia-Bin Huang. A closer look at few-shot classification. In International Conference on Learning Representations (ICLR), 2019.

Yinbo Chen, Xiaolong Wang, Zhuang Liu, Huijuan Xu, and Trevor Darrell. A new meta-baseline for few-shot learning. arXiv preprint arXiv:2003.04390, 2020.

Guneet Singh Dhillon, Pratik Chaudhari, Avinash Ravichandran, and Stefano Soatto. A baseline for few-shot image classification. In International Conference on Learning Representations (ICLR), 2020.

Nikita Dvornik, Cordelia Schmid, and Julien Mairal. Selecting relevant features from a universal representation for few-shot classification. arXiv preprint arXiv:2003.09338, 2020.

Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep networks. In The International Conference on Machine Learning (ICML), 2017.

Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016.

Alex Krizhevsky, Geoffrey Hinton, et al. Learning multiple layers of features from tiny images. Technical report, University of Toronto, 2009.


Brenden M. Lake, Ruslan Salakhutdinov, and Joshua B. Tenenbaum. Human-level concept learning through probabilistic program induction. Science, 350(6266):1332–1338, 2015.

Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.

Hongyang Li, David Eigen, Samuel Dodge, Matthew Zeiler, and Xiaogang Wang. Finding task-relevant features for few-shot learning by category traversal. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2019.

Zhouhan Lin, Minwei Feng, Cicero Nogueira dos Santos, Mo Yu, Bing Xiang, Bowen Zhou, and Yoshua Bengio. A structured self-attentive sentence embedding. arXiv preprint arXiv:1703.03130, 2017.

Lu Liu, Tianyi Zhou, Guodong Long, Jing Jiang, Lina Yao, and Chengqi Zhang. Prototype propagation networks (PPN) for weakly-supervised few-shot learning on category graph. In International Joint Conferences on Artificial Intelligence (IJCAI), 2019a.

Lu Liu, Tianyi Zhou, Guodong Long, Jing Jiang, and Chengqi Zhang. Learning to propagate for graph meta-learning. In Neural Information Processing Systems (NeurIPS), 2019b.

Lu Liu, Tianyi Zhou, Guodong Long, Jing Jiang, and Chengqi Zhang. Attribute propagation network for graph zero-shot learning. In AAAI Conference on Artificial Intelligence (AAAI), 2020.

Yang Liu, Chengjie Sun, Lei Lin, and Xiaolong Wang. Learning natural language inference using bidirectional LSTM model and inner-attention. arXiv preprint arXiv:1605.09090, 2016.

Boris Oreshkin, Pau Rodríguez López, and Alexandre Lacoste. TADAM: Task dependent adaptive metric for improved few-shot learning. In The Conference on Neural Information Processing Systems (NeurIPS), 2018.

Ethan Perez, Florian Strub, Harm de Vries, Vincent Dumoulin, and Aaron C. Courville. FiLM: Visual reasoning with a general conditioning layer. In AAAI Conference on Artificial Intelligence (AAAI), 2018.

Sylvestre-Alvise Rebuffi, Hakan Bilen, and Andrea Vedaldi. Efficient parametrization of multi-domain deep neural networks. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2018.

Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell, Kevin Swersky, Joshua B. Tenenbaum, Hugo Larochelle, and Richard S. Zemel. Meta-learning for semi-supervised few-shot classification. In International Conference on Learning Representations (ICLR), 2018.

James Requeima, Jonathan Gordon, John Bronskill, Sebastian Nowozin, and Richard E. Turner. Fast and flexible multi-task classification using conditional neural adaptive processes. In The Conference on Neural Information Processing Systems (NeurIPS), pp. 7957–7968, 2019.

Tonmoy Saikia, Thomas Brox, and Cordelia Schmid. Optimized generic feature learning for few-shot classification across domains. arXiv preprint arXiv:2001.07926, 2020.

Jake Snell, Kevin Swersky, and Richard Zemel. Prototypical networks for few-shot learning. In The Conference on Neural Information Processing Systems (NeurIPS), 2017.

Eleni Triantafillou, Tyler Zhu, Vincent Dumoulin, Pascal Lamblin, Utku Evci, Kelvin Xu, Ross Goroshin, Carles Gelada, Kevin Swersky, Pierre-Antoine Manzagol, et al. Meta-Dataset: A dataset of datasets for learning to learn from few examples. In International Conference on Learning Representations (ICLR), 2020.

Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In The Conference on Neural Information Processing Systems (NeurIPS), 2017.

Oriol Vinyals, Charles Blundell, Tim Lillicrap, Daan Wierstra, et al. Matching networks for one shot learning. In The Conference on Neural Information Processing Systems (NeurIPS), 2016.


Xin Wang, Fisher Yu, Ruth Wang, Trevor Darrell, and Joseph E. Gonzalez. TAFE-Net: Task-aware feature embeddings for low shot learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2019.

Yu-Xiong Wang and Martial Hebert. Learning from small sample sets by combining unsupervised meta-training with CNNs. In The Conference on Neural Information Processing Systems (NeurIPS), 2016.

Han-Jia Ye, Hexiang Hu, De-Chuan Zhan, and Fei Sha. Few-shot learning via embedding adaptation with set-to-set functions. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2020.

Sung Whan Yoon, Jun Seo, and Jaekyun Moon. TapNet: Neural network augmented with task-adaptive projection for few-shot learning. In The International Conference on Machine Learning (ICML), 2019.

A EXPERIMENTS ON MORE DATASETS

We also report performance on the MNIST, CIFAR-10 and CIFAR-100 dataset sources in Table 5, and compare with the subset of methods that have reported results on these datasets. There, URT performs comparably to SUR overall, yielding top performance on the MNIST domain but not on the CIFAR-10/CIFAR-100 domains, on which Simple CNAPS has the best performance.

Table 5: Test performance (mean±CI%95) over 600 few-shot tasks on additional datasets.

Method | MNIST | CIFAR10 | CIFAR100 | avg. rank
CNAPs | 92.7 ± 0.4 | 61.5 ± 0.7 | 50.1 ± 1.0 | 4.7
TaskNorm | 92.3 ± 0.4 | 69.3 ± 0.8 | 54.6 ± 1.1 | 3.3
SUR | 94.3 ± 0.4 | 66.8 ± 0.9 | 56.6 ± 1.0 | 2.3
SimpleCNAPS | 93.9 ± 0.4 | 74.3 ± 0.7 | 60.5 ± 1.0 | 1.7
URT (Ours) | 94.8 ± 0.4 | 67.3 ± 0.8 | 56.9 ± 1.0 | 2.0

B MORE RELATED WORKS

Transfer by fine-tuning A simple and effective method for few-shot classification is to perform transfer learning by first learning a neural network classifier on all data available for training, then using its representation to initialize and fine-tune neural networks on the few-shot classification tasks found at test time (Chen et al., 2019; Triantafillou et al., 2020; Dhillon et al., 2020; Saikia et al., 2020). Specifically, Saikia et al. (2020) have shown that competitive performance can be reached using a strong hyper-parameter optimization method applied to a carefully designed validation metric appropriate for few-shot learning.

C STRUCTURES AND TRAINING STRATEGY OF BACKBONES

Our URT layer can be built on top of a set of pretrained backbones. The structures of the backbones follow the approach in Dvornik et al. (2020). For Table 1, we use the ResNet18 architecture (He et al., 2016) for the backbones, where a separate backbone is pretrained on each domain using the corresponding training data in Meta-Dataset. The training domains include ImageNet, Omniglot, Aircraft, CU-Birds, Textures, Quick Draw, Fungi and VGG Flower. For the parametric network family in Table 2, a base ResNet18 is first trained on ImageNet. Then a small number of modulating parameters are trained on the other domains, using their domain-specific training data, while the rest of the base backbone's weights stay fixed. Specifically, FiLM feature modulation (Perez et al., 2018) is used. This type of parametric network family thus allows for a much reduced number of learnable parameters, by reusing the weights from the base network. In experiments, we use the pretrained backbones released by Dvornik et al. (2020) for both cases, without any further finetuning.


End-to-end training of a set of backbones and the URT layer would require prohibitive computational cost, so we fix the pretrained backbones and only train the URT layer.

Implementation details for training backbones The training details of the backbones come from Dvornik et al. (2020). For optimization, SGD with momentum was used, with cosine learning rate annealing. Since the datasets come from different domains, the starting learning rate, the maximum number of training iterations, and the annealing frequency are set individually for each dataset. Data augmentation is applied and a constant weight decay of 7×10^-4 is used. For each dataset, a grid search over batch sizes in [8, 16, 32, 64] was run and the one that maximizes accuracy on the validation set was picked. For the parametric network family, the base ResNet18 trained on ImageNet is the same. For the other backbones, cosine annealing is also used as the learning rate policy, with weight decay and data augmentation employed as above. Please refer to Tables 4 and 5 in the original SUR paper (Dvornik et al., 2020) for the specific values of the hyperparameters for the individual feature networks and the parametric network family, respectively.

D RESULTS ON TRAFFIC SIGNS

The shuffle-buffer bug described in Meta-Dataset issue #54 (https://github.com/google-research/meta-dataset/issues/54) was propagated to previous works such as CNAPs (Requeima et al., 2019). The results for the Traffic Signs dataset are considerably worse after fixing this bug. For instance, URT degrades from 69.4±0.8 to 51.1±1.1. Please visit the official Meta-Dataset GitHub repo for more details. The main paper shows the corrected results (for URT and competing approaches). For completeness, we also provide in this Appendix the tables as they were in the initial version of this paper. Both sets of results overall support the advantageous performance of URT over previous work.

Table 6: Test accuracy (mean±CI%95) over 600 few-shot tasks. URT and the most recent methods, listed in the first column, are compared on the data sources of Meta-Dataset (Triantafillou et al., 2020), listed in the first row. The numbers in bold have intersecting confidence intervals with the most accurate method.

Method | ILSVRC | Omniglot | Aircraft | Birds | Textures | Quick Draw | Fungi | VGG Flower | Traffic Signs | MSCOCO | avg. rank
MAML | 37.8±1.0 | 83.9±1.0 | 76.4±0.7 | 62.4±1.1 | 64.1±0.8 | 59.7±1.1 | 33.5±1.1 | 79.9±0.8 | 42.9±1.3 | 29.4±1.1 | 8.0
ProtoNet | 44.5±1.1 | 79.6±1.1 | 71.1±0.9 | 67.0±1.0 | 65.2±0.8 | 64.9±0.9 | 40.3±1.1 | 86.9±0.7 | 46.5±1.0 | 39.9±1.1 | 7.3
ProtoMAML | 46.5±1.1 | 82.7±1.0 | 75.2±0.8 | 69.9±1.0 | 68.3±0.8 | 66.8±0.9 | 42.0±1.2 | 88.7±0.7 | 52.4±1.1 | 41.7±1.1 | 5.4
CNAPs | 52.3±1.0 | 88.4±0.7 | 80.5±0.6 | 72.2±0.9 | 58.3±0.7 | 72.5±0.8 | 47.4±1.0 | 86.0±0.5 | 60.2±0.9 | 42.6±1.1 | 5.1
BOHB-E | 55.4±1.1 | 77.5±1.1 | 60.9±0.9 | 73.6±0.8 | 72.8±0.7 | 61.2±0.9 | 44.5±1.1 | 90.6±0.6 | 57.5±1.0 | 51.9±1.0 | 4.4
TaskNorm | 50.6±1.1 | 90.7±0.6 | 83.8±0.6 | 74.6±0.8 | 62.1±0.7 | 74.8±0.7 | 48.7±1.0 | 89.6±0.6 | 67.0±0.7 | 43.4±1.0 | 3.8
SUR | 56.3±1.1 | 93.1±0.5 | 85.4±0.7 | 71.4±1.0 | 71.5±0.8 | 81.3±0.6 | 63.1±1.0 | 82.8±0.7 | 70.4±0.8 | 52.4±1.1 | 2.5
SimpleCNAPS | 58.6±1.1 | 91.7±0.6 | 82.4±0.7 | 74.9±0.8 | 67.8±0.8 | 77.7±0.7 | 46.9±1.0 | 90.7±0.5 | 73.5±0.7 | 46.2±1.1 | 2.4
URT (Ours) | 55.7±1.0 | 94.4±0.4 | 85.8±0.6 | 76.3±0.8 | 71.8±0.7 | 82.5±0.6 | 63.5±1.0 | 88.2±0.6 | 69.4±0.8 | 52.2±1.1 | 1.6

Table 7: Test accuracy (mean±CI%95) over 600 few-shot tasks. All methods use parametric network family (pf) backbones.

Dataset | SUR-pf | URT-pf | VS.
ILSVRC | 56.4 ± 1.2 | 55.5 ± 1.1 | =
Omniglot | 88.5 ± 0.8 | 90.2 ± 0.6 | +
Aircraft | 79.5 ± 0.8 | 79.8 ± 0.7 | =
Birds | 76.4 ± 0.9 | 77.5 ± 0.8 | =
Textures | 73.1 ± 0.7 | 73.5 ± 0.7 | =
Quick Draw | 75.7 ± 0.7 | 75.8 ± 0.7 | =
Fungi | 48.2 ± 0.9 | 48.1 ± 0.9 | =
VGG Flower | 90.6 ± 0.5 | 91.9 ± 0.5 | +
Traffic Signs | 65.1 ± 0.8 | 67.5 ± 0.8 | +
MSCOCO | 52.1 ± 1.0 | 52.1 ± 1.0 | =
MNIST | 93.2 ± 0.4 | 93.9 ± 0.4 | =
CIFAR10 | 66.4 ± 0.8 | 66.1 ± 0.8 | =
CIFAR100 | 57.1 ± 1.0 | 57.3 ± 1.0 | =
