focal loss的tensorflow实现
最近在进行分类任务的时候,发现了数据存在类别不平衡问题。除了类别不平衡问题之外还有难学样本和易学样本之间的不平衡问题。因此考虑使用了focal loss。这里直接上代码:def focal_loss(logits, labels, gamma):''':param logits:[batch_size, n_class]:param labels: [batch...
文章共400字 · 阅读需要大约2分钟
一键AI生成摘要,助你高效阅读
问答
·
最近在进行分类任务的时候,发现了数据存在类别不平衡问题。除了类别不平衡问题之外还有难学样本和易学样本之间的不平衡问题。因此考虑使用了focal loss。这里直接上代码:
def focal_loss(logits, labels, gamma):
'''
:param logits: [batch_size, n_class]
:param labels: [batch_size]
:return: -(1-y)^r * log(y)
'''
softmax = tf.reshape(tf.nn.softmax(logits), [-1]) # [batch_size * n_class]
labels = tf.range(0, logits.shape[0]) * logits.shape[1] + labels
prob = tf.gather(softmax, labels)
weight = tf.pow(tf.subtract(1., prob), gamma)
loss = -tf.reduce_mean(tf.multiply(weight, tf.log(prob)))
return loss
附上链接:
论文:Focal Loss for Dense Object Detection
更好地理解focal loss
更多推荐
已为社区贡献1条内容
所有评论(0)