最近在进行分类任务的时候,发现了数据存在类别不平衡问题。除了类别不平衡问题之外还有难学样本和易学样本之间的不平衡问题。因此考虑使用了focal loss。这里直接上代码:

def focal_loss(logits, labels, gamma):
    '''
    :param logits:  [batch_size, n_class]
    :param labels: [batch_size]
    :return: -(1-y)^r * log(y)
    '''
    softmax = tf.reshape(tf.nn.softmax(logits), [-1])  # [batch_size * n_class]
    labels = tf.range(0, logits.shape[0]) * logits.shape[1] + labels
    prob = tf.gather(softmax, labels)
    weight = tf.pow(tf.subtract(1., prob), gamma)
    loss = -tf.reduce_mean(tf.multiply(weight, tf.log(prob)))
    return loss   

附上链接:
论文:Focal Loss for Dense Object Detection
更好地理解focal loss

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐