源代码:

dice_target = nn.functional.one_hot(dice_target, num_classes).float()

解决:
dice_target.to(torch.int64)
修改后的代码:

dice_target = nn.functional.one_hot(dice_target.to(torch.int64), num_classes).float()

亲测有效
原文链接在此

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐