PyTorch学习笔记 —— Categorical函数
一、介绍Categorical函数来自包 torch.distributions,官方定义的接口如下:class torch.distributions.Categorical(probs)作用是创建以参数probs为标准的类别分布,样本是来自 “0 … K-1” 的整数,其中 K是probs参数的长度。也就是说,按照传入的probs中给定的概率,在相应的位置处进行取样,取样返回...
文章共235字 · 阅读需要大约1分钟
一键AI生成摘要,助你高效阅读
问答
·
一、介绍
Categorical函数来自包 torch.distributions,官方定义的接口如下:
class torch.distributions.Categorical(probs)
作用是创建以参数probs为标准的类别分布,样本是来自 “0 … K-1” 的整数,其中 K 是probs参数的长度。也就是说,按照传入的probs中给定的概率,在相应的位置处进行取样,取样返回的是该位置的整数索引。
如果 probs
是长度为 K
的一维列表,则每个元素是对该索引处的类进行抽样的相对概率。
如果 probs
是二维的,它被视为一批概率向量。
二、使用示例
probs = torch.FloatTensor([[0.05,0.1,0.85],[0.05,0.05,0.9]])
dist = Categorical(probs)
print(dist)
# Categorical(probs: torch.Size([2, 3]))
index = dist.sample()
print(index.numpy())
# [2 2]
更多推荐
已为社区贡献3条内容
所有评论(0)