元学习——原型网络(Prototypical Networks)

1. 基本介绍

1.1 本节引入

在之前的的文章中,我们介绍了关于连体网络的相关概念,并且给出了使用Pytorch实现的基于连体网络的人脸识别网络的小样本的学习过程。在接下来的内容中,我们来继续介绍另外一种小样本学习的神经网络结构——原型网络。这种网络的特点是拥有能够不仅仅应用在当前数据集的泛化分类能力。在接下来的内容中,我们将介绍以下几个内容:

  1. 原型网络的基本结构。
  2. 原型网络算法描述。
  3. 将原型网络应用于分类任务。
1.2 原型网络引入

相比于连体网络,原型网络是另外一种简单,高效的小样本的学习方式。与连体网络的学习目标类似。原型网络的目标也是学习到一个向量空间来实现文本分类任务。

原型网络的基本思路是对于每一个分类来创建一个原型表示(protoypicla representation)。并且对于一个需要分类的查询,采用计算分类的原型向量和查询点的距离来进行确定。

确定基本思路之后,下面从一个例子开始,对于原型网络进行具体描述。

2 原型网络

2.1 从一个例子开始

现在,我们拥有一个支持集(support set),内部包含狮子,大象,狗三个分类的图片。也就是说,对于分类任务,我们一共拥有三个分类:{狮子,大象,狗}。现在,我们需要对于每一个分类创建一个原型表示。建立的基本流程如下图所示:

  1. 首先,我们对于每一个样本使用编码的方式 f φ ( ) f_φ() fφ(),学习到每一个样本的编码表示(信息抽取)。举个例子,我们可以使用卷积操作来实现对于图片编码信息的抽取。
    在这里插入图片描述
  2. 在学习到每一个样本的编码表示之后,我们对于每一个分类下的所有的样本编码进行求和求取平均的操作,将结果作为分类的原型表示。因此,一个分类的原型表示使用向量求和求平均的过程过程进行表示。
    在这里插入图片描述

当一个新的数据样本被输入到网络中的时候,我们需要的是对于这个样本预测出其分类情况。
3. 第一步,我对于这个新的数据样本使用 f φ ( ) f_φ() fφ()生成其编码表示。如下图所示:
在这里插入图片描述
4. 接下来,我们需要做的就是计算新的样本的编码表示和每一个分类的原型表示之间的距离情况,通过最下距离来确定查询样本属于哪一个分类。对于距离计算,并没有特殊的要求,可以使用欧式距离或者Cos相似度等等计算方式。
在这里插入图片描述

  1. 最后在计算出所有的分类之间的距离之后,我们使用softmax的方式将距离转换成概率的形式。我们有三个分类,那么对于样本在softmax之后,获取到的就是对于这三个分类的距离情况。

在本节的最后,我们回到我们的学习过程,我们希望的是网络从小样本的数据集中进行学习。所以我们在训练的时候,我们对于每一个分类随机的生成少量的样本,我们成这些少量的样本集合为支持集,在整个的训练过程,我们只需要使用到支持集即可。而不需要所有的数据集。同理,我们随机的从数据集中抽取一个样本作为查询点并且对其进行分类的预测。这样就完成了我们从小样本学习的方式。

2.2 原型网络的整体架构

首先,我们给出原型网络的整体架构图:

在这里插入图片描述

我们从整体的架构上来分析一下这种网络结构:

  1. 第一步,我们对于支持集中的每一个样本点生成一个编码表示,通过通过求和平均的方式来生成每一个分类的原型表示。同时,对于我们的查询样本,我们也对其生成一个向量表示
  2. 同时,我们需要计算每一个查询点和每一个分类原型表示的距离情况。并计算softmax概率结果。生成对于各个分类的概率分布情况。

进一步,对于原型网络而言,其应用的范围不仅仅在单样本/小样本的学习过程中,同时还可以应用在零样本的学习方式。对于这种应用的思路是:尽管我们没有当前分类的数据样本,但是如果能够在更高的层次中生成分类的原型表示(元信息)。通过这种元信息,我们也可以完成和上面类似的计算,完成我们的分类任务。

2.3 算法描述

这里我们结合网络结构和数学公式来对原型网络进行算法描述:

  1. 假设我们当前的数据集为D,其内部的样本的表示形式为{ ( x 1 , y 1 ) , ( x 2 , y 2 ) , . . . . , ( x n , y n ) (x_1,y_1),(x_2,y_2),....,(x_n,y_n) (x1,y1),(x2,y2),....,(xn,yn)},其中x表示的向量表示,y表示分类分类标签。
  2. 对于每一个分类,我们随机的从总的样本集中为其生成n个样本点,对于每一个分类,我们生成最后支持集为S。
  3. 同理,我们随机的从总的样本集中为每一个分类选择n个样本点来生成查询集Q。
  4. 对于支持集内部的样本点,使用编码公式 f φ f_φ fφ来为每一个分类生成一个原型表示,这里的编码公式 f φ f_φ fφ可以是任意的一种信息抽取的方式。例如CNN,LSTM等等。
  5. 对于每一个分类,我们生成其原型表示为 :
    i . e . C l a s s P r o t o t y p e ( c ) = 1 S ∑ ( x i , y i ) ∈ S f φ ( x i ) i.e. Class Prototype(c)=\frac{1}{S}∑_{(x_i,y_i)∈S}f_φ(x_i) i.e.ClassPrototype(c)=S1(xi,yi)Sfφ(xi)
  6. 类似的是,我们对于查询集也生成查询集的编码。
  7. 进一步,我们需要计算的是查询集和支持集的原型表示的距离情况。
  8. 最后,需要计算的是当前样本属于每一个分类的概率 p w ( y = k ∣ x ) p_w(y=k|x) pw(y=kx),这里使用softmax的计算方式:
    i . e . p φ ( y = k ∣ x ) = e x p ( − d ( f φ ( x ) , c ) ) ∑ k e x p ( − d ( f φ ( x ) , c ) ) i.e. p_φ(y=k|x)=\frac{exp(-d(f_φ(x),c))}{∑_kexp(-d(f_φ(x),c))} i.e.pφ(y=kx)=kexp(d(fφ(x),c))exp(d(fφ(x),c))
  9. 最终,我们计算损失函数为 J ( φ ) J(φ) J(φ)
    J ( φ ) = − l o g p w ( y = k ∣ x ) J(φ)=-logp_w(y=k|x) J(φ)=logpw(y=kx)
2.4 代码描述

这里,我们选择自定义了一个简单的评论数据集,一共两个分类,每一个分类下面有5个数据,每个分类我们选择3个作为支持集,3个作为查询集,其具体的实现如下:

#encoding=utf-8

import torch
import torch.nn as nn
import torch.nn.functional as F
import jieba
import random
import torch.optim as optim



def createData():
    text_list_pos = ["电影内容很好","电影题材很好","演员演技很好","故事很感人","电影特效很好"]
    text_list_neg = ["电影内容垃圾","电影是真的垃圾","表演太僵硬了","故事又臭又长","电影太让人失望了"]
    test_pos = ["电影","很","好"]
    test_neg = ["电影","垃圾"]
    words_pos = [[item for item in jieba.cut(text)] for text in text_list_pos]
    words_neg = [[item for item in jieba.cut(text)] for text in text_list_neg]
    words_all = []
    for item in words_pos:
        for key in item:
            words_all.append(key)
    for item in words_neg:
        for key in item:
            words_all.append(key)
    vocab = list(set(words_all))
    word2idx = {w:c for c,w in enumerate(vocab)}
    idx_words_pos = [[word2idx[item] for item in text] for text in words_pos]
    idx_words_neg = [[word2idx[item] for item in text] for text in words_neg]
    idx_test_pos = [word2idx[item] for item in test_pos]
    idx_test_neg = [word2idx[item] for item in test_neg]
    return vocab,word2idx,idx_words_pos,idx_words_neg,idx_test_pos,idx_test_neg
def createOneHot(vocab,idx_words_pos,idx_words_neg,idx_test_pos,idx_test_neg):
    input_dim = len(vocab)
    features_pos = torch.zeros(size=[len(idx_words_pos),input_dim])
    features_neg = torch.zeros(size=[len(idx_words_neg), input_dim])
    for i in range(len(idx_words_pos)):
        for j in idx_words_pos[i]:
            features_pos[i,j] = 1.0

    for i in range(len(idx_words_neg)):
        for j in idx_words_neg[i]:
            features_neg[i,j] = 1.0
    features = torch.cat([features_pos,features_neg],dim=0)
    labels = [1,1,1,1,1,0,0,0,0,0]
    labels = torch.LongTensor(labels)
    test_x_pos = torch.zeros(size=[1,input_dim])
    test_x_neg = torch.zeros(size=[1,input_dim])
    for item in idx_test_pos:
        test_x_pos[0,item] = 1.0
    for item in idx_test_neg:
        test_x_neg[0,item] = 1.0
    test_x = torch.cat([test_x_pos,test_x_neg],dim=0)
    test_labels = torch.LongTensor([1,0])
    return features,labels,test_x,test_labels
def randomGenerate(features):
    N = features.shape[0]
    half_n = N // 2
    support_input = torch.zeros(size=[6, features.shape[1]])
    query_input = torch.zeros(size=[4,features.shape[1]])
    postive_list = list(range(0,half_n))
    negtive_list = list(range(half_n,N))
    support_list_pos = random.sample(postive_list,3)
    support_list_neg = random.sample(negtive_list,3)
    query_list_pos = [item for item in postive_list if item not in support_list_pos]
    query_list_neg = [item for item in negtive_list if item not in support_list_neg]
    index = 0
    for item in support_list_pos:
        support_input[index,:] = features[item,:]
        index += 1
    for item in support_list_neg:
        support_input[index,:] = features[item,:]
        index += 1
    index = 0
    for item in query_list_pos:
        query_input[index,:] = features[item,:]
        index += 1
    for item in query_list_neg:
        query_input[index,:] = features[item,:]
        index += 1
    query_label = torch.LongTensor([1,1,0,0])
    return support_input,query_input,query_label




class fewModel(nn.Module):
    def __init__(self,input_dim,hidden_dim,num_class):
        super(fewModel,self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_class = num_class
        # 线性层进行编码
        self.linear = nn.Linear(input_dim,hidden_dim)


    def embedding(self,features):
        result = self.linear(features)
        return result

    def forward(self,support_input,query_input):

        support_embedding = self.embedding(support_input)
        query_embedding = self.embedding(query_input)
        support_size = support_embedding.shape[0]
        every_class_num  = support_size // self.num_class
        class_meta_dict = {}
        for i in range(0,self.num_class):
            class_meta_dict[i] = torch.sum(support_embedding[i*every_class_num:(i+1)*every_class_num,:],dim=0) / every_class_num
        class_meta_information = torch.zeros(size=[len(class_meta_dict),support_embedding.shape[1]])
        for key,item in class_meta_dict.items():
            class_meta_information[key,:] = class_meta_dict[key]
        N_query = query_embedding.shape[0]
        result = torch.zeros(size=[N_query,self.num_class])
        for i in range(0,N_query):
            temp_value = query_embedding[i].repeat(self.num_class,1)
            cosine_value = torch.cosine_similarity(class_meta_information,temp_value,dim=1)
            result[i] = cosine_value
        result = F.log_softmax(result,dim=1)
        return result

hidden_dim = 4
n_class = 2
lr = 0.01
epochs = 1000
vocab,word2idx,idx_words_pos,idx_words_neg,idx_test_pos,idx_test_neg = createData()
features,labels,test_x,test_labels = createOneHot(vocab,idx_words_pos,idx_words_neg,idx_test_pos,idx_test_neg)

model = fewModel(features.shape[1],hidden_dim,n_class)
optimer = optim.Adam(model.parameters(),lr=lr,weight_decay=5e-4)

def train(epoch,support_input,query_input,query_label):
    optimer.zero_grad()
    output = model(support_input,query_input)
    loss = F.nll_loss(output,query_label)
    loss.backward()
    optimer.step()
    print("Epoch: {:04d}".format(epoch),"loss:{:.4f}".format(loss))

if __name__ == '__main__':
    for i in range(epochs):
        support_input, query_input, query_label = randomGenerate(features)
        train(i,support_input,query_input,query_label)


 
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐