文本分类(二)复杂场景下分类任务应用介绍

AI 3周前 admin
22 0 0

作者简介Neo,二范数智能AI团队成员,武汉大学硕士,研究方向为信息检索、知识图谱、医疗数据挖掘与分析,对机器学习、深度学习在NLP领域的应用有着浓厚兴趣。



1、引言

文本分类任务在二分类、多分类等简单场景下的应用不胜枚举,当探讨其多标签分类、长文本分类模型压缩加速复杂场景的应用时,传统的文本策略出现了不尽匹配的情况,亟需进行优化迭代,因此,我们在本文带来复杂场景下文本分类任务的介绍。



2、多标签任务

任务简介:

由前文文本分类(一)分类任务与模型介绍的文章我们能够从定义层面厘清多分类和多标签任务的区别:

  • 多分类任务中一条数据只有一个label,但是label有可能存在多个值。

  • 多标签分类任务指的是一条数据可能有一个或者多个label。


本节,我们重点分析二者模型架构层面存在的差别:


算法架构:

(1)模型输入输出

本节仍以新闻分类为例来进行两种任务的探讨

模型输入是新闻文本,如下所示:

  • “今天大盘涨了3%,地产、传媒板块领涨” [金融]

  • “汪峰今天发布了新歌,娱乐圈又将有大事发生”[娱乐、音乐]

假设新闻文本只有[金融、体育、娱乐、音乐]四类,那模型输出便采用one-hot思路对[金融、体育、娱乐、音乐]进行有序排列,二者区别如下表所示:

属性

多标签分类

多分类任务

任务定义

一条数据有多个或者多个label

一条数据只有一个label,但是label有多种

模型输入

一段新闻文本

一段新闻文本

模型输出

[0,0,1,1]

[1,0,0,0]


(2)模型细节阐述
  • 关于多分类任务,模型只需要在输出层阶段执行一次分类器判断,此时输出层的激活函数是softmax,这种输出是一种基于分布的形式,判断哪一类的可能性最大;

  • 关于多标签任务,我们当然也可以采取多分类任务的思路,对每个label进行执行二分类判断,以上面新闻分类为例,就需要训练4个二分类器,这样子不仅费时费力,也在一定程度上损坏了label之间的依赖关系。所以,我们一般采用下面的思路:将传统多分类任务中的输出层的softmax激活函数变换为sigmoid激活函数,对每个节点的值进行一次激活,对单个节点执行0-1判断


同样,根据上述任务实现思路,可以将损失函数进行调整,由多分类任务的多类别交叉熵损失函数(Categorical_crossentropy调整为更适合多标签任务的二分类交叉熵损失函数Binary_crossentropy


二者区别如下表所示:
属性
多标签分类
多分类任务
任务定义
一条数据有多个或者多个label
一条数据只有一个label,但是label有多种
特征提取器
神经网络
神经网络
输出层激活函数

Sigmoid

Softmax
损失函数

Categorical_crossentropy

Binary_crossentropy


在文本分类的复杂场景中,样本总会存在多标签情况,本节主要对多类别和多标签任务的区别进行厘清,方便未来应用。



3、长文本分类

诸如BERT等各种预训练模型目前已经广泛应用于文本分类任务,但是模型仍存在一定的局限性,即它对于输入文本的最大长度有一定的限制,除去[cls]、[sep]标签外,文本最多只能再输入510个token(下文统一把[cls]、[sep]也算作token,即512),但是现实场景中,长于512个token的文本比比皆是,那么如何实现预训练模型在这些长文本分类任务中的应用呢?


实现思路:
首先对于长文本分类,有两个思路,
  • 第一个从数据层面进行解决,即改造我们的文本,使之符合模型的要求;

  • 第二个,从模型层面进行解决,即迭代模型,使之能够容纳更长的文本。


数据层面一般有如下做法:
  1. 截断法:前或者后截断,使文本满足512个字以内

  2. 分段法:分为多个512个字的段

  3. 压缩法:裁剪无意义的句子


模型层面则有如下模型:

Transformer-XL、Longformer等


下面分别对它们进行详细介绍:

首先对数据层面的各个方法进行分析


(1)截断法

截断法主要采取如下方式:

  1. 头截断, 只保留最前面N(如512)个字;

  2. 尾截断, 只保留最后面N个字;

  3. 头+尾截断, 开头结尾各保留一部分;


截断法的特点

  1. 尽管要求最大长度是512个token, 但去除[cls]、[sep]后, 实际是510个token;

  2. 选择头截断、还是尾截断、还是两者结合,主要看数据的关键信息分布;

  3. 截断法适合大量几百字的文本, 如果文本几千字, 粗暴截断会丢失重要信息;


2)分段法

分段法主要采取如下方式:

  1. 将长文本依次划分为n个不超过512字的段(为避免语义丢失,最好考虑断句);

  2. 针对n个段分别进行BERT编码;

  3. 将n段经过BERT后的[CLS]向量进行max-pooling或mean-pooling;

  4. 然后再接一个全连接层做分类;


分段法特点
  1. 考虑到全局信息, 相比截断法, 对几千字的长文本效果较好;

  2. 性能较差, 每个段都要encode一次, 文本越长,速度越慢;

  3. 段落之间联系会丢失, 易出现badcase;


3)压缩法

压缩法主要采取如下方式:

其核心是裁减掉一些无意义的句子,例如:

  1. 一些文章开头或结尾有一些无用“套路话术”, 这些可以删除掉;
  2. 去除url;
  3. 句子筛选,只保留最重要的N个句子,如:计算句子和标题的相似度;


接着分析模型层面的迭代优化:

(1)transformer-xl模型

当数据过长时,如果使用截断法,它没有考虑句子的自然边界,而是根据固定的长度来划分序列,导致分割出来的文本在语义上是不完整的;如果使用分段法,每个句子之间独立训练,不同的token之间最长的依赖关系,就取决于句子的长度。

  • 模型介绍

transformer-xl提出了一个状态复用的块级别循环用以解决长序列问题,虽然这个模型的提出主要是为了解决文本生成任务,但我们可以参考其解决长序列问题的思路。

  • 块级别循环训练阶段介绍

  1. 依然文本是分块(句子)输入, 但在计算当前块的输出时, 会缓存并利用上一个segment中所有layer的隐向量序列

  2. 其中,所有隐向量序列只参与前向计算,不再进行反向传播。


文本分类(二)复杂场景下分类任务应用介绍

2Longformer模型
注意力机制能够快速便捷地从整个文本序列中捕获重要信息。然而传统的注意力机制的时空复杂度与文本的序列长度呈平方的关系,这在很大程度上限制了模型的输入不能太长。
  • 模型介绍

基于这些考虑,Longformer被提出来拓展模型在长序列建模的能力,它提出了一种时空复杂度同文本序列长度呈线性关系的注意力机制,用以保证模型使用更低的时空复杂度建模长文档,并将文本处理长度扩充到了4096。

  • 提出新的注意力机制

下图展示了经典的注意力机制和Longformer提出的注意力机制,其中a是经典的注意力机制,它是一种“全关注的注意力机制,即每个token都要和序列中的其他所有token进行交互,因此它的时空复杂度是O(n²)  。右边的三种模式是Longformer提出来的注意力机制,分别是滑动窗口注意力(Sliding Window Attention)扩张滑动窗口注意力(Dilated Sliding Window)和 全局+滑动窗口注意力(Global+Sliding Window)


文本分类(二)复杂场景下分类任务应用介绍


下面对其进行详细解释:

①滑动窗口注意力
  • 引入固定长度的滑动窗口,即当前词只与相邻的k个词关联

  • 注意力复杂度从0(n²)降到0(nk);

  • 操作类似于卷积操作,单层感受野是k,L层感受野能达到L*k;


②扩张滑动窗口注意力
  • 在滑动窗口注意力基础上引入膨胀卷积, 类似IDCNN,在卷积核中增加空洞,扩充单层感受野,关注到更多上下文。从下图可以看到,同样是尺寸为 3 的卷积核,同样是两层卷积层,传统卷积上下文大小为 5,而膨胀卷积的上下文大小为 7。

文本分类(二)复杂场景下分类任务应用介绍


全局+滑动窗口注意力

  • 首先需要全局注意力关注一些预先设置的位置,即设定某些位置的token能够看见全部的token,同时其他的所有token也能看见这些位置的token,相当于是将这些位置的token“暴露”在最外面。

  • 同时,这些位置的确定和具体的任务有关,例如对于分类任务,这个带有全局视角的token是“CLS”,确保其能Attention到整个序列;对于问答任务,这些带有全局视角的token是Question对应的这些token。


4、模型压缩加速策略—模型蒸馏

BERT参数过多导致模型笨重,硬件受限下,如何实现模型压缩与加速?除了onnx推理加速, 知识蒸馏(Knowledge Distillation)也是一种非常常用的方法,本节将带来蒸馏的技术的介绍。


  • 蒸馏定义:

蒸馏是用teacher模型指导student模型训练,以期提升student模型精度。一般来说,teacher模型精度高,不过计算复杂度也大,不适合在终端设备部署,而student模型计算复杂度虽符合终端设备要求,但精度不够,所以可以采取模型蒸馏(distillation)解决这一问题。


  • 蒸馏目标:

用推理效率更高的、轻量的学生模型, 近似达到老师的大模型的效果,一般老师的模型size(参数量)要大过学生, 比如用BERT-large去教BERT-base。


  • 蒸馏过程:
蒸馏,即老师将知识(Embedding/hidden/logits)教给学生的过程。


下面将以经典完成蒸馏的预训练模型进行介绍:

(1)DistilBERT

  • 基本介绍:
DistilBERT是一个6层的BERT, 由12层的BERT_Base当老师, 在预训练阶段蒸馏得到。


  • 蒸馏流程:

  1. DistilBERT直接使用老师模型的前6层进行初始化(各层之间维度相同)

  2. DistilBERT只进行MLM任务,没有进行NSP任务(该任务被认为是无效策略)

  3. 另外注意的是,学生模型在学习时,除了要利用真实的label,还需要学习老师模型的隐层输出(hidden)和输出概率(soft_label)


流程如下图所示:


文本分类(二)复杂场景下分类任务应用介绍

 

  • 模型细节:

其蒸馏过程中最重要的就是loss的学习,下面我们将分析蒸馏的loss如何定义:

文本分类(二)复杂场景下分类任务应用介绍

其中,

第一项为有监督MLM损失:

被mask的部分作为label,与学生输出计算交叉熵
文本分类(二)复杂场景下分类任务应用介绍

第二项为蒸馏MLM损失:

学生的输出si向老师输出ti看齐,两者计算交叉熵:

 

文本分类(二)复杂场景下分类任务应用介绍

蒸馏时,老师的输出ti也称作soft_label,它是logits经过softmax后的概率

并且需要注意的是,这里的softmax一般带温度系数T,训练时设置T=8,推理时设置T=1
文本分类(二)复杂场景下分类任务应用介绍

第三项为输出层余弦损失:

学生的last hidden 向老师的last hidden看齐,计算余弦距离

文本分类(二)复杂场景下分类任务应用介绍

  • 模型效果
  1. 从模型大小来看,DistilBERT模型参数由BERT-base的110M降为66M
  2. 从推理速度来看,推理速度获得40%的提升
  3. 从模型效果来看,下游任务直接微调时, 获得97%的BERT-base效果

我们看到DistilBERT模型仅学习老师模型的最后部分,那么是否可以向老师模型学习到更多的结构呢?


2TinyBERT

TinyBERT能够很好的解决上述问题,

  • 首先,在模型结构层面,它对于模型学习得更彻底,基于 transformer 的模型专门设计的知识蒸馏方法,即将Embedding层和中间层都进行蒸馏,如下图所示。

  • 其次,在学习阶段层面,它使用了两阶段蒸馏,即在预训练和微调阶段均进行了蒸馏


文本分类(二)复杂场景下分类任务应用介绍


  • 蒸馏流程:
  1. TinyBERT提出了一种两阶段学习框架,包括通用形式蒸馏和特定任务的蒸馏,如下图所示。
  2. 在通用蒸馏阶段,使用原始BERT,无需进行任何微调即可将其用作teacher,并使用大型文本语料库作为训练数据。通过对来自一般领域的文本执行Transformer蒸馏,获得了可以针对下游任务进行微调的常规TinyBERT,通用形式蒸馏帮助TinyBERT学习预训练BERT中嵌入的丰富知识,这在改进TinyBERT的泛化能力中起着重要作用。
  3. 在特定任务的蒸馏阶段,使用增强的特定任务的数据集,重新执行Transformer蒸馏,特定任务的蒸馏进一步向TinyBERT教学了经过微调的BERT的知识。

 

文本分类(二)复杂场景下分类任务应用介绍


  • 模型细节:

下面我们将分析TinyBERT的loss如何定义:
文本分类(二)复杂场景下分类任务应用介绍

其中,

第一项为词向量损失:

计算学生词向量和老师词向量的均方误差,因为两者维度未必一致,所以需要引入映射e

文本分类(二)复杂场景下分类任务应用介绍

第二项为中间层损失:

若学生4层,老师12层,则老师的(3,6,9,12)层分别蒸馏到学生的(1,2,3,4)层,中间层的损失由隐层均方误差损失和注意力损失组成:

文本分类(二)复杂场景下分类任务应用介绍

其中隐层均方误差损失:
文本分类(二)复杂场景下分类任务应用介绍

学生的第i层隐层输出和老师的第j层隐层输出计算MSE,用h做映射

其中注意力损失:

文本分类(二)复杂场景下分类任务应用介绍

学生第i层多头注意力矩阵 和老师第j层多头注意力矩阵计算MSE,K为head数


第三项为预测层损失:

和DistilBERT一样,学生学习老师的soft_label并计算交叉熵:

文本分类(二)复杂场景下分类任务应用介绍

模型效果

  • 4层的TinyBERT, 能够达到老师(BERT-base)效果的96.8%、参数量缩减为原来的 13.3%、仅需要原来10.6%的推理时间

  • DistilBERT可以不微调蒸馏, 但 TinyBERT最好要做微调蒸馏, 仅4 层的它直接微调效果可能下降明显

  • 预训练蒸馏时TinyBERT没有使用预测层损失,主要因为预训练阶段主要学习文本表示


进一步思考:

DistilBERT和TinyBERT主要将模型变浅,已有研究证实,相比于模型变窄,模型变浅让精度损失更大,那么能否通过降低模型宽度来实现蒸馏?


3MobileBERT
  • 模型结构

MobileBERT为上述问题的解决提供了思路,直接对其进行微调, 便可以达到BERT-Base 99.2%的效果、参数量小了4倍、推理 速度快了5.5倍,形象结构如下图所示:

 

文本分类(二)复杂场景下分类任务应用介绍


(a)图是标准的BERT, L层 transformer;

(b)图是Teacher模型, 是一个 Inverted-Bottleneck BERT_Large;Bottleneck结构是一个线性层, 主要将模型加宽;

(c)图是MobileBERT学生模型, 它的Bottleneck结构主要将模型变窄;


  • 模型对比,如下表所示:

  • IB-BERT将521的 hidden加宽到1024来,近似标准的BERT_Large

  • MobileBERT的细节则是:

  • 将512的 hidden变窄到128

  • 堆了更多的Feed Forward层,防止FFN的HHA的参数数量比例失衡;

  • 移除了LayerNorm,替换Gelu为Relu激活;

  • Embedding层为128,通过kernel size为3的1维卷积转化为512维;

 

文本分类(二)复杂场景下分类任务应用介绍


  • 蒸馏流程

MobileBERT使用渐进式知识迁移蒸馏
  1. 最开始的Embedding层和最后的分类层直接从老师拷贝到学生

  2. 由于老师学生层数相同,学生逐层学习老师的hidden和attention

  3. 当学生在学习i层时,前面的所有层 (小于i层)参数均不更新

 

文本分类(二)复杂场景下分类任务应用介绍


  • 模型细节:

下面我们将分析其蒸馏的loss如何定义:

文本分类(二)复杂场景下分类任务应用介绍

主要围绕四个损失进行计算:第一项为有监督MLM损失;第二项为有监督NSP损失;

第三项为隐层蒸馏损失;第四项为注意力矩阵损失。

各项损失的计算方法基本与前面一致, 除了注意力矩阵损失,使用KL散度替代MSE

文本分类(二)复杂场景下分类任务应用介绍

(4)蒸馏工具的使用

我们可以通过TextBrewer工具,自定义各种蒸馏策略。

  • 特点如下

  1. 适用范围广支持多种模型结构(如Transformer、 RNN)和多种NLP任务(如文本分类、阅 读理解和序列标注等) ;

  2. 配置方便灵活:知识蒸馏过程由配置对象(Configurations) 配置。通过配置对象可自由组合 多种知识蒸馏方法;

  3. 多种蒸馏方法与策略:TextBrewer不仅提供了标准和常见的知识蒸馏方法,也包括了计算 机视觉(CV) 领域中的一些蒸馏技术。

  4. 简单易用:为了使用TextBrewer蒸馏模型, 用户无须修改模型部分的代码,并且可复用已 有训练脚本的大部分代码, 如模型初始化、数据处理和任务评估, 仅需额外完成一些准备工作。


  • 架构如下

其架构主要围绕DistillersConfigurations展开:


  • Distillers 是TextBrewer的核心,用来训练蒸馏模型保存模型调用回调函数。目前,工具包中提供了五种 Distillers。

  1. BasicDistiller: 进行最基本的知识蒸馏;

  2. GeneralDistiller:相比于BasicDistiller,额外 提供中间层损失函数(Interme- diate Loss Functions)的支持;

  3. MultiTeacherDistiller:多教师单任务知识蒸馏, 将多个同任务的教师模型蒸馏到一一个学生模型;

  4. MultiTaskDistiller: 多教师多任务知识蒸馏, 将多个不同任务的教师模型蒸馏到一个学生模型;

  5. BasicTrainer:用于在有标签数据上有监督地训 备、模型储存频率和评测频率等; 练教师模型


  • Configurations:Distillers训练 或蒸馏模型的具体方式由两个配 置对象——TrainingConfigDistillationConfig指定。

  1. TrainingConfig:定义了深度学习实验的通用配置,如日志目录与模型储存目录、运行设备、模型存储频率和评测频率等。

  2. DistillationConfig:定义了和知识蒸馏密切相关的配置,如知识蒸馏损失的类型、知识蒸馏温度、硬标签损失的权重、调节器和中间隐含层状态损失函数等。调节器用于动态调整损失权重和温度。


代码参考:
Longformer可参考huggingface中的longformer-chinese-base-4096,调用正确接口即可:
elif 'longformer' in bert_base_model_dir.lower():  # # 自动加载longformer模型            self.bert_tokenizer = BertTokenizer.from_pretrained(bert_base_model_dir)            # # longformer-chinese-base-4096模型参数prefix为bert而非标准的longformer,这是个坑            LongformerModel.base_model_prefix = 'bert'            self.bert_model = LongformerModel.from_pretrained(bert_base_model_dir)

关于蒸馏的核心代码如下:
from textbrewer import DistillationConfig, TrainingConfig, GeneralDistiller# 获取老师模型、启用return_extra# 通过BertFCPredictor获取teacher modelteacher_predictor = BertFCPredictor(    '../model/chinese-roberta-wwm-ext', '../tmp/bertfc', enable_parallel=enable_parallel)teacher_model = teacher_predictor.modelteacher_model.forward = partial(teacher_model.forward, return_extra=True)  # 启用return_extraprint('teacher模型加载成功,label mapping:', teacher_predictor.vocab.id2tag)
# 获取学生模型、启用return_extra# 通过BertFCTrainer获取student modelpretrained_model, model_dir = './model/TinyBERT_4L_zh', './tmp/bertfc'student_trainer = BertFCTrainer(pretrained_model, model_dir, enable_parallel=enable_parallel)student_trainer.vocab.build_vocab(labels=train_labels, build_texts=False, with_build_in_tag_id=False)student_trainer._build_model()student_trainer.vocab.save_vocab('{}/{}'.format(student_trainer.model_dir, student_trainer.vocab_name))student_trainer._save_config()student_model = student_trainer.modelstudent_model.forward = partial(student_model.forward, return_extra=True) # 启用return_extraprint('student模型加载成功,label mapping:', student_trainer.vocab.id2tag) # 确保学生老师label mapping要一致
# 蒸馏配置distill_config = DistillationConfig( # 设置温度系数temperature, tiny-bert论文作者使用1表现最好,一般大于1比较好 temperature=4, # 设置ground truth loss权重 hard_label_weight=1, # 设置预测层蒸馏loss(即soft label损失)为交叉熵,并稍微放大其权重 kd_loss_type='ce', kd_loss_weight=1.2, # 配置中间层蒸馏映射 intermediate_matches=[ # 配置hidden蒸馏映射、维度映射 {'layer_T': 0, 'layer_S': 0, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1, 'proj': ['linear', 312, 768]}, # embedding层输出 {'layer_T': 3, 'layer_S': 1, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1, 'proj': ['linear', 312, 768]}, {'layer_T': 6, 'layer_S': 2, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1, 'proj': ['linear', 312, 768]}, {'layer_T': 9, 'layer_S': 3, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1, 'proj': ['linear', 312, 768]}, {'layer_T': 12, 'layer_S': 4, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1, 'proj': ['linear', 312, 768]}, # 配置attention矩阵蒸馏映射,注意layer序号从0开始 {"layer_T": 2, "layer_S": 0, "feature": "attention", "loss": "attention_mse", "weight": 1}, {"layer_T": 5, "layer_S": 1, "feature": "attention", "loss": "attention_mse", "weight": 1}, {"layer_T": 8, "layer_S": 2, "feature": "attention", "loss": "attention_mse", "weight": 1}, {"layer_T": 11, "layer_S": 3, "feature": "attention", "loss": "attention_mse", "weight": 1}, ])
# 训练配置epoch = 20 # 使用大一点的epochoptimizer = AdamW(student_model.parameters(), lr=1e-4) # 使用大一点的lrtrain_config = TrainingConfig( output_dir=model_dir, log_dir='./log', data_parallel=enable_parallel, ckpt_frequency=1 # 一个epoch存1次模型)
# 配置model中logits hiddens attentions losses的获取方法def simple_adaptor(batch, model_outputs): return { 'logits': model_outputs[-1]['logits'], 'hidden': model_outputs[-1]['hiddens'], 'attention': model_outputs[-1]['attentions'], 'losses': model_outputs[1], }
# 蒸馏distiller = GeneralDistiller( train_config=train_config, distill_config=distill_config, model_T=teacher_model, model_S=student_model, adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)with distiller: print('开始蒸馏') distiller.train(optimizer, train_dataloader, num_epochs=epoch) print('蒸馏结束')


关于二范数智能:二范数AI教育是一家新锐的AI+科创公司,团队主要来自阿里巴巴,成员毕业于华科、武大、东南大学等知名高校。我们在自然语言处理、计算机视觉、推荐系统等领域有深厚的技术积累,同时也具备多年的教育经验。AI培训,我们是最专业的!


欢迎大家通过下面联系方式联系我们:

文本分类(二)复杂场景下分类任务应用介绍

原文始发于微信公众号(二范数智能):文本分类(二)复杂场景下分类任务应用介绍

版权声明:admin 发表于 2022年11月11日 下午10:16。
转载请注明:文本分类(二)复杂场景下分类任务应用介绍 | CTF导航

相关文章

暂无评论

暂无评论...