如何取出高维张量中满足一定条件的值(比如大于0.5),其余设为零?

解决方案一,张量花式索引

  代码如下

a = t.randn([2,3])
print(a)
a[a<0.5] = 0
print(a)

  结果
在这里插入图片描述

解决方案二,torch.where()的API

  比上种方法更快,其代码如下

a = t.randn([2,3])
print(a)
a = t.where(a>0.5, a, t.zeros_like(a))# 当a中某一个元素满足第一个形参的条件时,返回a的值,否则返回 t.zeros_like(a)中的值
print(a)

  结果
在这里插入图片描述

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐