0%

Query2Label: A Simple Transformer Way to Multi-Label Classification

来自清华-博世机器学习研究中心,将Transformer解码器用于多标签分类,将label embedding作为query,计算与feature map的cross-attention。在MS-COCO、PASCAL VOC、NUS-WIDE和Visual Genome上进行了实验,取得了SOTA结果。

image-20211010113901822

Overview

Background

多标签分类主要有两个问题

  • 如何解决标签不平衡问题
  • 如何提取有效的local特征

前者是因为one-vs-all策略采用多个独立的二分类器,后者则是因为全局的池化特征稀释了其他标签,使得难以识别细小物体。

目前的研究方向主要有三类

  • 针对正负例的不平衡问题,改进loss函数,包括focal loss、distribution-balanced loss和今年阿里提出的asymmetric loss
  • 建模label correlations,比如使用label co-occurrence和GCN。
  • 定位感兴趣的区域,比如使用spatial transformer。

[AAAI 2019]《Cross-modality attention with semantic graph embedding for multi-label classification》这篇文章,在裁剪负值后,计算label embedding和feature map的cosine相似度作为attention map。但是这种分法可能会导致attention过于平滑,从而作用有限,难以提取有效的desired feature。

image-20211010122647527

Cross-modality attention

基于上述,本文利用Transformer内置的cross-attention作为特征选择器,提取有效的desired feature。受DETR启发,采用可学习的label embedding作为query,也避免了采用label corrleation等方法带来的噪声。

Method

本文是一个two-stage的方法,第一步采用backbone(如ViT)提取图片的时序特征,第二步将特征和label embedding送入transformer中训练。

image-20211010124725252

Query2Label的总体框架

给定图片,提取特征,后接全连接层并reshape得到特征

构造label embedding,其中为类别数,Transformer的每一层解码层都在更新参数。

在self-attention中,query、key和value都来自label embedding;而在cross-attention中,key和value变成了时序特征。

在经过L层Transformer后,得到最后一层的query向量,使用全连接层+sigmoid进行分类。

本文采用了一种简化的非对称损失以解决类别不平衡问题

在实验中选取

Experiment

使用了一层Transformer encoder和两层Transformer decoder,encoder只是为了更好地学习特征表示,但即使不用encoder只用一层decoder也可以表现很好。

采用Adam优化器,weight decay为1e-2,学习率设为1e-4,训练80epochs。

在四个数据集上刷新SOTA,并做了消融实验。

image-20211010133339563

消融实验