一、Softmax函数作用

Softmax函数是一个非线性转换函数,通常用在网络输出的最后一层,输出的是概率分布(比如在多分类问题中,Softmax输出的是每个类别对应的概率),计算方式如下:

得到的是第i个位置对应的概率,每个位置的概率之和为1(可以看出Softmax仅进行计算,没有需要学习的参数)。

二、PyTorch计算方式

在PyTorch中,包 torch.nn.functional 中实现了Softmax函数,官方文档接口定义如下:

torch.nn.functional.softmax(input, dim=None, _stacklevel=3, dtype=None)

input是我们输入的数据,dim是在哪个维度进行Softmax操作(如果没有指定,默认dim=1)。

举例如下:

import torch
import torch.nn.functional as F

data=torch.FloatTensor([[1.0,2.0,3.0],[4.0,6.0,8.0]])
print(data)
print(data.shape)
print(data.type())

prob = F.softmax(data,dim=0) # dim = 0,在列上进行Softmax;dim=1,在行上进行Softmax
print(prob)
print(prob.shape)
print(prob.type())

输出为:

tensor([[1., 2., 3.],
        [4., 6., 8.]])
torch.Size([2, 3])
torch.FloatTensor

tensor([[0.0474, 0.0180, 0.0067],
        [0.9526, 0.9820, 0.9933]])
torch.Size([2, 3])
torch.FloatTensor

有时候也会见到dim = -1,dim = -2的情况,对于二维输入来说,dim = -1表示行,dim = -2 表示列。

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐