
There is a well-known trick in the U-Net architecture: using custom per-pixel weight maps in the loss to increase accuracy. Below are the details:

[Image: details of the U-Net weight-map trick]
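For reference, the original U-Net paper (Ronneberger et al., 2015) defines the weight map as follows. Here w_c is a weight map that balances the class frequencies, d_1 is the distance to the border of the nearest cell, d_2 is the distance to the border of the second-nearest cell, and the paper reports using w_0 = 10 and sigma ≈ 5 pixels. Each pixel's cross-entropy term is then multiplied by w(x):

```latex
w(\mathbf{x}) = w_c(\mathbf{x}) + w_0 \cdot \exp\left(-\frac{\left(d_1(\mathbf{x}) + d_2(\mathbf{x})\right)^2}{2\sigma^2}\right)
```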

By asking here and in several other places, I have learned about two approaches. I want to know which one is correct, or whether there is another approach that is more correct.

  1. The first is to use the torch.nn.functional API in the training loop:

    loss = torch.nn.functional.cross_entropy(output, target, w)  # w is the calculated custom weight

  2. The second is to set reduction='none' when creating the loss function outside the training loop, criterion = torch.nn.CrossEntropyLoss(reduction='none'),

    and then, in the training loop, to multiply the loss by the custom weight:

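    import torch
    import torch.nn as nn

    criterion = nn.CrossEntropyLoss(reduction='none')
    pd = torch.randn(4, 2, 8, 8)         # network output (logits), shape (N, C, H, W); placeholder values
    gt = torch.randint(0, 2, (4, 8, 8))  # ground truth, dtype torch.long, shape (N, H, W)
    W = torch.rand(4, 8, 8) + 1.0        # per-pixel weight map from the U-Net distance map (placeholder), shape (N, H, W)
    loss = criterion(pd, gt)             # per-pixel loss, shape (N, H, W)
    loss = W * loss                      # ensure that the weights are scaled appropriately
    loss = torch.sum(loss.flatten(start_dim=1), dim=1)  # sum the loss per image -> shape (N,)
    loss = torch.mean(loss)              # average across the batch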
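A minimal sketch of the tensor shapes involved in the two approaches, using random placeholder tensors (the sizes and the class weights are arbitrary assumptions). PyTorch documents the `weight` argument of `F.cross_entropy` as one rescaling value per class (size C), while `reduction='none'` returns one loss value per pixel that a weight map of the same shape can multiply element-wise:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

N, C, height, width = 2, 3, 4, 4
logits = torch.randn(N, C, height, width)         # network output, shape (N, C, H, W)
target = torch.randint(0, C, (N, height, width))  # ground truth class indices, dtype torch.long

# Approach 1: `weight` holds one value per class, shape (C,), not one per pixel
class_w = torch.tensor([1.0, 2.0, 0.5])
loss_1 = F.cross_entropy(logits, target, weight=class_w)  # scalar (reduction='mean')

# Approach 2: reduction='none' keeps one loss value per pixel, shape (N, H, W),
# so a per-pixel weight map of the same shape can multiply it element-wise
criterion = nn.CrossEntropyLoss(reduction='none')
pixel_w = torch.rand(N, height, width) + 1.0       # hypothetical weight map, values in [1, 2)
per_pixel = criterion(logits, target)
loss_2 = (pixel_w * per_pixel).flatten(start_dim=1).sum(dim=1).mean()
```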
    

Now I am confused: which one is right? Is there another way, or are both right?

Answers

The weighting portion looks like simple weighted cross entropy, which is done like this, with one weight per class (two classes in the example below):

import torch
import torch.nn as nn
weights = torch.tensor([0.3, 0.7])  # one weight per class
loss_func = nn.CrossEntropyLoss(weight=weights)
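A short sketch of what this class weighting computes on segmentation-shaped inputs (random placeholder tensors, arbitrary sizes). Each pixel's loss is scaled by the weight of its target class, and the default 'mean' reduction divides by the sum of those weights rather than by the pixel count. Because the weight depends only on the class, this can act as a class-balancing term, but it cannot express a weight that varies from pixel to pixel within a class:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
weights = torch.tensor([0.3, 0.7])               # one weight per class
loss_func = nn.CrossEntropyLoss(weight=weights)

logits = torch.randn(4, 2, 8, 8)                 # (N, C, H, W) segmentation logits
target = torch.randint(0, 2, (4, 8, 8))          # (N, H, W) class indices

loss = loss_func(logits, target)

# Same value by hand: scale each pixel's unweighted loss by its target class weight,
# then divide by the sum of the weights that were used
per_pixel = nn.CrossEntropyLoss(reduction='none')(logits, target)
w_per_pixel = weights[target]                    # (N, H, W) weight looked up per pixel
manual = (w_per_pixel * per_pixel).sum() / w_per_pixel.sum()
```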

EDIT:

Have you seen this implementation from Patrick Black?

import torch
import torch.nn.functional as F

# Set properties
batch_size = 10
out_channels = 2
W = 10
H = 10

# Initialize logits etc. with random
logits = torch.FloatTensor(batch_size, out_channels, H, W).normal_()
target = torch.LongTensor(batch_size, H, W).random_(0, out_channels)
weights = torch.FloatTensor(batch_size, 1, H, W).random_(1, 3)

# Calculate log probabilities over the class dimension
logp = F.log_softmax(logits, dim=1)

# Gather log probabilities with respect to target
logp = logp.gather(1, target.view(batch_size, 1, H, W))

# Multiply with weights
weighted_logp = (logp * weights).view(batch_size, -1)

# Rescale so that loss is in approx. same interval
weighted_loss = weighted_logp.sum(1) / weights.view(batch_size, -1).sum(1)

# Average over mini-batch
weighted_loss = -1. * weighted_loss.mean()
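The gather-based code above computes, for each image, a weight-normalized sum of per-pixel negative log-probabilities and then averages over the batch. As a consistency check under the same random setup, the same value can be obtained from `F.cross_entropy(..., reduction='none')`, i.e. the question's second approach with the per-image sum divided by that image's total weight. This sketch only shows the two formulations agree; it makes no claim about which normalization works best:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, out_channels, H, W = 10, 2, 10, 10
logits = torch.randn(batch_size, out_channels, H, W)
target = torch.randint(0, out_channels, (batch_size, H, W))
weights = torch.randint(1, 3, (batch_size, 1, H, W)).float()

# Gather-based formulation, as in the snippet above
logp = F.log_softmax(logits, dim=1).gather(1, target.unsqueeze(1))
gathered = -((logp * weights).view(batch_size, -1).sum(1)
             / weights.view(batch_size, -1).sum(1)).mean()

# Same loss via reduction='none': weight each pixel's cross entropy,
# normalize by each image's total weight, then average over the batch
per_pixel = F.cross_entropy(logits, target, reduction='none')  # (N, H, W)
w = weights.squeeze(1)                                         # (N, H, W)
via_ce = ((w * per_pixel).sum(dim=(1, 2)) / w.sum(dim=(1, 2))).mean()
```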