最近在研究分类网络MSD-Net,打算跑cifar10和cifar100数据集来复现一下论文的研究结果。
GitHub:MSD-Net的Pytorch版

总体代码写完后,跑cifar10数据集非常完美,分类准确率比ResNet和DenseNet都有提升。但是跑cifar100数据集的时候就报了题目的错

C:/w/1/s/windows/pytorch/aten/src/THCUNN/ClassNLLCriterion.cu:106: block: [0,0,0], thread: [31,0,0] Assertion `t >= 0 && t < n_classes` failed.
THCudaCheck FAIL file=C:/w/1/s/windows/pytorch/aten/src\THCUNN/generic/ClassNLLCriterion.cu line=110 error=710 : device-side assert triggered
Traceback (most recent call last):
  File "C:/Users/15338/Desktop/pycharm_ssh/lzh/cifar_MSDNet.py", line 203, in <module>
    train(criterion, optimizer, trainloader)
  File "C:/Users/15338/Desktop/pycharm_ssh/lzh/cifar_MSDNet.py", line 91, in train
    loss += criterion(outputs[j], labels_var)
  File "C:\Pycharm Pro\Project1\lib\site-packages\torch\nn\modules\module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Pycharm Pro\Project1\lib\site-packages\torch\nn\modules\loss.py", line 916, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)
  File "C:\Pycharm Pro\Project1\lib\site-packages\torch\nn\functional.py", line 2009, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "C:\Pycharm Pro\Project1\lib\site-packages\torch\nn\functional.py", line 1838, in nll_loss
    ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: cuda runtime error (710) : device-side assert triggered at C:/w/1/s/windows/pytorch/aten/src\THCUNN/generic/ClassNLLCriterion.cu:110

这是因为程序在计算loss的时候发现分类数量和标签数量不一致导致出错。在网上找了很多类似的错误,RuntimeError: cuda runtime error (59) : device-side assert triggered,他们都是因为label小于0或者大于分类数量,只要简单地label = label -1或者+1就能解决问题。

然而我的label打印出来后并没有问题,范围是0—99(cifar100数据集有100个分类)

print(label_var.data)
==>tensor([47, 22, 50,  8, 24, 43, 25, 51, 60, 30, 54, 65, 58, 88, 20, 64, 83, 83,
           17, 60, 75, 68, 88, 24, 25, 65, 30, 99, 51, 95, 69, 49, 50,  7, 74, 66,
           33, 33,  0, 49, 74, 38, 39, 11, 12, 32, 74, 63, 25, 84, 94, 82, 98, 12,
           58, 15,  1, 77, 81, 22, 81, 11, 42, 94], device='cuda:0')

翻了几十页百度和各大论坛帖子,终于在这一篇文章中找到了契机:sunflower_sara的文章
经过仔细翻查作者的分类网络结构,发现了问题所在

if args.data.startswith('cifar100'):
    self.classifier.append(
        self._build_classifier_cifar(nIn * args.grFactor[-1], 100))
elif args.data.startswith('cifar10'):
    self.classifier.append(
        self._build_classifier_cifar(nIn * args.grFactor[-1], 10))
elif args.data == 'ImageNet':
    self.classifier.append(
        self._build_classifier_imagenet(nIn * args.grFactor[-1], 1000))
else:
    raise NotImplementedError

后面的10,100,1000就是网络的分类输出,可以肯定作者的代码默认用的是cifar10的分类,所以我跑cifar100就出错了。然后我去args.py里把cifar10改成cifar100就能完美运行了。

Python新手小白,欢迎各位留言评论,一起学习进步

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐