看了 UeFan写的tensorflow tf.where使用方法,大于某个值为1,小于为0文章,依然有些模糊,自己结合实际注解了一下。

	a=tf.random.uniform([2,2])    #随机生成2*2矩阵a
	
	one = tf.ones_like(a)            #生成与a大小一致的值全部为1的矩阵
    zero = tf.zeros_like(a)
    label = tf.where(a <0.5, x=zero, y=one)     #0.5为阈值
    

得到结果:

tensorflow
一个面向所有人的开源机器学习框架
a: tf.Tensor(
	[[0.25626993 0.53764176]
	[0.27858937 0.92834556]], shape=(2, 2), dtype=float32)

label: tf.Tensor(
			[[0. 1.]
 			[0. 1.]], shape=(2, 2), dtype=float32)

在numpy中:

threshold=0.5
x = np.where(pred < threshold, 0, 1)  #数值小于0.50,大于0.51
推荐内容
GitHub 加速计划 / te / tensorflow
26
4
下载
一个面向所有人的开源机器学习框架
最近提交(Master分支:2 个月前 )
4f64a3d5 Instead, check for this case in `ResolveUsers` and `ResolveOperand`, by querying whether the `fused_expression_root` is part of the `HloFusionAdaptor`. This prevents us from stepping into nested fusions. PiperOrigin-RevId: 724311958 2 个月前
aa7e952e Fix a bug in handling negative strides, and add a test case that exposes it. We can have negative strides that are not just -1, e.g. with a combining reshape. PiperOrigin-RevId: 724293790 2 个月前
Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐