hw8

代码

任务描述

用自编码器进行异常检测,训练数据都是正常的数据,测试数据有正常有异常的,让你检测出其中的异常数据。训练一个自编码器使其能够还原输入的图像,使用MSEloss作为损失函数。使用测试数据均方差损失函数的数值作为Anomaly score,使用ROC AUC score作为评价指标,ROC AUC score知道TP, TN, FP, FN 的意思再结合助教给的例子就很容易理解。ROC AUC score越大表示输出的Anomaly score越好。
本次任务的 strong baseline:AUC > 0.77196 boss baseline: AUC>0.79506

实验方法

本次实验我只过了strong baseline, 使用的方法也很简单,助教一共提供了四个模型:线性模型,CNN,VAE和Resnet,后面三个看起来很牛B,感觉结果肯定不会差,但你实际跑了一下才知道这三个模型来做异常检测得到的结果都不太行,反而是fcn表现得很好,虽然还没有过boss baseline ,但对于它来说,这个结果已经很棒了!
模型结构如下:

class fcn_autoencoder(nn.Module):
    def __init__(self):
        super(fcn_autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(64 * 64 * 3, 2048),
            nn.ReLU(),
            nn.Linear(2048, 1024),
            nn.ReLU(), 
            nn.Linear(1024, 512), 
            nn.ReLU(), 
            nn.Linear(512, 256),
            nn.ReLU(), 
            nn.Linear(256, 128),
           
        )
        
        self.decoder = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2048),
            nn.ReLU(),
            nn.Linear(2048, 64 * 64 * 3), 
            nn.Tanh())

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

提交结果:
在这里插入图片描述

过程总结

之前我说过每个任务的目标一定是要过boss baseline不然没有意义,你看我写的解析感觉就一个fcn调调参数,你还有脸写解析,但其实这个任务我是投入了大量时间的,只是最终还是没有过boss baseline,只能无奈先放弃了,后续我会把我尝试的代码贴出来,大家如果对异常检测有兴趣的话,希望我的尝试过程可以帮助到你。

根据助教的提示,我找到了一篇异常检测论文,(OCGAN)。
论文中网络的结构如下:
在这里插入图片描述
整体模型共由四个子网络组成:一个自编码器,两个判别器,和一个分类器。使用GAN的方法来进行训练。
先说一下我对于这篇论文方法的感受,这么多子网络要一起train,还是GAN,我的评价是巨难train,但是既然人家能够投中,那么存在即合理,还是硬着头皮开始看它给的代码,它给的代码使用的框架是mxnet,在官方文档的帮助下也不难,基本知道代码的意思,然后就是让代码能够跑起来了。主要做了以下几个工作:

  1. 安装mxnet框架,看我之前的文章,过程也是很曲折。
  2. 给的代码有问题,跑不起来,问题是**train_data.reset()**这行代码(也找了好久)。
  3. 它原来进行异常检测的数据集是 MNIST,现在我要重写数据处理部分,让模型能够处理hw08的数据集。
    终于,我可以使用它的模型来跑这个任务了,但是很不幸的是就像我之前预料的一样,train不起来,判别器和分类器的正确率要么一直是0,要么一直是1,我感觉还是因为它原来的数据集是MNIST太简单了,而hw8的数据集长这样:
    在这里插入图片描述
    在这里插入图片描述

上面第一张图片是训练数据都是正常的,而下面的图片是测试数据,让你找出其中异常的数据,这相比与它任务的异常检测(训练数据是只有某一个数字,而测试数据是多种数字混合)要难多了,而且MNIST数据集只有一个通道,而这个数据集的图片是五颜六色的人脸。
综上,我感觉是任务难度导致了我无法在hw08的数据集上训练起来它的模型,我的代码后面会提供给大家,如果有同学感兴趣可以运行一下,看看是不是训练不起来,当然也有可能我的代码有错误,非常欢迎大家指出。
再说一下助教提供的网络结构
在这里插入图片描述
加随机向量,以及分类器的思想都和OCGAN的思想很类似,不过助教的网络结构要比OCGAN要更简单点,训练起来也更有可能成功。那我为什么没有尝试呢?太多的不确定性了,只有一个大致的草图,而没有较详细的指导,我觉得实现难度过大,若有大佬实现了并过了boss baseline,还请dd我。

hw8因为尝试了很多,但是都没有成功,所以过程总结写了大量的篇幅,我的博客不只是告诉你过strong/boss baseline 应该怎么做,而且要分享我的实验过程,成功的和失败的我都会分享,来告诉你我是怎么一步步地达到目标的。当然也有很多不足的地方,欢迎大佬们指正。

Logo

分享最新、最前沿的AI大模型技术,吸纳国内前几批AI大模型开发者

更多推荐