InfoGAN 生成时序序列

简介

完整代码:https://github.com/SongDark/timeseries_infogan

  本文介绍用InfoGAN生成多维时序序列。


数据

数据集下载地址

NameClassDimensionTrain SizeTest SizeTruncated
CharacterTrajectories20314221436182

样本介绍

  CharacterTrajectories数据集是小写英文字符轨迹的数据集,包含20个类’a’ ‘b’ ‘c’ ‘d’ ‘e’ ‘g’ ‘h’ ‘l’ ‘m’ ‘n’ ‘o’ ‘p’ ‘q’ ‘r’ ‘s’ ‘u’ ‘v’ ‘w’ ‘y’ ‘z’。每个样本是一个时序序列,有三个维度,分别是 x x x轴坐标、 y y y轴坐标和笔尖力度。作者将所有样本都进行了截断或补零,统一变成长度为182的序列。样本进行了一阶差分处理,需要通过累积(cumsum)来恢复原本的坐标轨迹。

样本处理

  只选用样本的前两维: x x x轴坐标、 y y y轴坐标。
  尺寸为 [ N , 182 , 2 ] [N,182,2] [N,182,2]的mini-batch应该处理(Reshape)成 [ N , 182 , 2 , 1 ] [N,182,2,1] [N,182,2,1],而不是 [ N , 182 , 1 , 2 ] [N,182,1,2] [N,182,1,2],前者将样本作为长182宽2的单通道图像,后者将样本作为长182宽1的双通道图像。经试验,后者无法生成,前者可以。


InfoGAN

论文地址InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets

  论文认为普通的GAN做生成的时候,输入的噪声对输出的图像造成的影响十分不明确,噪声与图像之间的映射关系可能非常复杂。InfoGAN希望GAN能达到这样的效果:改变z的某个维度,能让生成的样本有明显改变。
  InfoGAN将噪声分成两部分: z z z c c c,其中 z z z是普通噪声部分,用于生成复杂的样本, c c c是编码部分,用于控制生成样本的关键信息(例如类别、旋转等)。InfoGAN由三部分组成:生成器 G G G(用于生成样本)、判别器 D D D(用于和G对抗)、分类器 Q Q Q(用于恢复 c c c)。分类器 Q Q Q的任务是猜测样本是由什么样的 c c c控制生成的,它要求 G G G不能把一些无意义的东西塞进 c c c里,这里 G G G Q Q Q的组合其实有点像AutoEncoder。然而, G G G有可能直接把 c c c拼在输入的某处直接送给 Q Q Q,为了避免这样的情况,判别器 D D D要求 G G G必须生成足够真实的样本。


生成效果

c c c 的构成

   c c c 有12位,前10位是one-hot过的标签,用于控制类别,后2位是随机数,控制属性未知,需要网络学习。

1. 随机生成

2. 固定噪声

3. 固定标签,噪声连续变化

在MNIST中,两位的噪声分别可以控制生成字符的宽度和旋转角度,但是在生成时序序列的任务中,两位噪声的作用不太显著。


一些经验和心得

1. 生成器最后一层

有些GAN的实现里,生成器最后一层加了sigmoid或tanh,我这里没有任何激活函数,也能生成。

2. 样本的shape

样本处理成 [ N , 182 , 2 , 1 ] [N,182,2,1] [N,182,2,1]尺寸能生成, [ N , 182 , 1 , 2 ] [N,182,1,2] [N,182,1,2]不能。

3. 类别控制

尽管 c c c中有类别,生成的样本并不是与给定的类别一致的。例如在 c c c中编码字符a,原本是希望能生成a的样本,但实际上生成的可能是b的样本。然而幸运的是,尽管类别是乱序的,但编码了某个类别后,生成的必然是某一类的样本,不会是多类的。原本我以为是我代码写错了,但github上一个高star代码的结果也是乱序的,应该不是我的问题。

4. 属性控制

MNIST做数据集时, c c c的最后两位可以控制宽度和旋转角度,而在时序序列的生成任务里,似乎效果不显著,可能有这样的效果:增大 c c c的数值,字符的尾巴会变长。其实InfoGAN原论文里也只是说作者发现改变这两位刚好对应了不同宽度和旋转的MNIST数字,要是能够在训练前就指定好某一位就是控制宽度、某一位就是控制旋转就好了。


完整代码

https://github.com/SongDark/timeseries_infogan

参考资料

tensorflow-generative-model-collections
timeseries_gan

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐