[关闭]
@w460461339 2018-09-17T14:03:22.000000Z 字数 9747 阅读 2282

生成对抗网络:CGAN,InfoGAN

MachineLearning


1、CGAN

这个真的是好简单呃…但是效果出奇的好。

https://blog.csdn.net/liuxiao214/article/details/73485333
https://zhuanlan.zhihu.com/p/44170959
https://zhuanlan.zhihu.com/p/25542274

1.1 基本思想

传统GAN,特别是生成器,在生成的时候,往往会生成很多随机的内容(毕竟输入就是个随机向量,还想怎么样)。现在人们想通过给生成器一定的约束,让它能够在生成的时候有一点规律。

放到公式里,就是这样:
传统GAN


CGAN

即,我们需要求在给定条件y的情况下,D和G的最优解。

那么y是什么呢:

1、对于生成器而言,就是这次你希望生成器生成的内容满足一个什么样的条件。
   在之前有的例子是生成二次元萌妹的,没有任何约束的时候,生成的都是随机的妹子;但是CGAN中,如果你让y=‘金发碧眼’,那么生成的所有妹子都是金发碧眼。
2、对于判别器而言,在没有y之前,只需要判断这个是生成数据还是真实数据。但是现在,生成器起到了一点分类器的作用:
    2.1 对于真实数据,他不仅要判断这是一个真实数据,还需要判断这个真实数据是不是满足输入条件y。
    2.2 对于生成数据,他要在输入约束y的条件下,去判断这个x是不是生成数据。
1.2 网络结构

image_1cn3koihsjm13qnogeov112t819.png-60.7kB

这个图也许不够清楚,那么看看下面这个:
image_1cn3kqqn76hk1dqhe3gm7tuht3m.png-184.2kB

网上的文章说的很少,我一开始还以为是没人看,结果是太简单了。

什么条件概率啥的,放到网络里,就是无脑拼接= -。

1.3 改进策略

感觉CGAN和InfoGAN就是除了随机噪声外,还加入一些隐变量,然后通过一些办法,让隐变量起作用。

当然,为了让模型更好的训练,作者还加入了一些调优策略。

1.4 搭建网络

算来算去,这是我写的第一个GAN,先上效果图:(左边是初始状态,右边CPU跑了1000轮)
image_1cn3l0llc1pa26n9119g1a3k117a6j.png-100.7kB

下面代码静下心来看看,应该很好看懂,不想静心的话,可以直接看看训练部分,弄清楚每个输入数据的格式,shape,就能够大概明白整个流程。

代码:

  1. # load data
  2. from tensorflow.examples.tutorials.mnist import input_data
  3. # 把mnist数据放在MNIST_data下就好
  4. mnist = input_data.read_data_sets('MNIST_data',one_hot=True)
  5. OUTPUT_DIR='samples'
  6. if not os.path.exists(OUTPUT_DIR):
  7. os.mkdir(OUTPUT_DIR)
  8. tf.reset_default_graph()
  9. X = tf.placeholder(dtype=tf.float32, shape=[None, HEIGHT, WIDTH, 1], name='X')
  10. y_label = tf.placeholder(dtype=tf.float32, shape=[None, HEIGHT, WIDTH, LABEL], name='y_label')
  11. noise = tf.placeholder(dtype=tf.float32, shape=[None, z_dim], name='noise')
  12. y_noise = tf.placeholder(dtype=tf.float32, shape=[None, LABEL], name='y_noise')
  13. is_training = tf.placeholder(dtype=tf.bool, name='is_training')
  14. def lrelu(x,leak=0.2):
  15. return tf.maximum(x,leak*x)
  16. def sigmoid_cross_entropy_with_logits(x,y):
  17. return tf.nn.sigmoid_cross_entropy_with_logits(logits=x,labels=y)
  18. # discriminator
  19. def discriminator(image,label,reuse=None,is_training=is_training):
  20. momentum=0.9
  21. with tf.variable_scope('discriminator',reuse=reuse):
  22. # 把输入图像和condition进行拼接 【无脑拼接】
  23. h0 = tf.concat([image,label],axis=3)
  24. h0 = tf.layers.conv2d(h0,kernel_size=5,filters=64,strides=2,padding='same')
  25. h0 = lrelu(h0)
  26. # 对于batch_norm,训练的时候需要,推断的时候就不需要。
  27. h1 = tf.layers.conv2d(h0,kernel_size=5,filters=128,strides=2,padding='same')
  28. h1 = lrelu(tf.contrib.layers.batch_norm(h1,is_training=is_training,decay=momentum))
  29. h2 = tf.layers.conv2d(h1,kernel_size=5,filters=256,strides=2,padding='same')
  30. h2 = lrelu(tf.contrib.layers.batch_norm(h2,is_training=is_training,decay=momentum))
  31. h3 = tf.layers.conv2d(h2,kernel_size=5,filters=512,strides=2,padding='same')
  32. h3 = lrelu(tf.contrib.layers.batch_norm(h3,is_training=is_training,decay=momentum))
  33. # 最后一层输出拉平+全连接,没有用FCN。
  34. h4 = tf.contrib.layers.flatten(h3)
  35. h4 = tf.layers.dense(h4,units=1)
  36. # h4是一个batch_size*1的向量吧
  37. return tf.nn.sigmoid(h4),h4
  38. # generator
  39. def generator(z,label,is_training=is_training):
  40. momentum=0.9
  41. with tf.variable_scope('generator',reuse=None):
  42. d =3
  43. z = tf.concat([z,label],axis=1)
  44. h0 = tf.layers.dense(z,units=d*d*512)
  45. h0 = tf.reshape(h0,shape=[-1,d,d,512])
  46. h0 = tf.nn.relu(tf.contrib.layers.batch_norm(h0,is_training=is_training,decay=momentum))
  47. h1 = tf.layers.conv2d_transpose(h0,kernel_size=5,filters=256,strides=2,padding='same')
  48. h1 = tf.nn.relu(tf.contrib.layers.batch_norm(h1,is_training=is_training,decay=momentum))
  49. h2 = tf.layers.conv2d_transpose(h1,kernel_size=5,filters=128,strides=2,padding='same')
  50. h2 = tf.nn.relu(tf.contrib.layers.batch_norm(h2,is_training=is_training,decay=momentum))
  51. h3 = tf.layers.conv2d_transpose(h2,kernel_size=5,filters=64,strides=2,padding='same')
  52. h3 = tf.nn.relu(tf.contrib.layers.batch_norm(h3,is_training=is_training,decay=momentum))
  53. h4 = tf.layers.conv2d_transpose(h3,kernel_size=5,filters=1,strides=1,
  54. padding='valid',activation=tf.nn.tanh,name='g')
  55. return h4
  56. # loss
  57. # 1、拿到生成的生成内容
  58. g = generator(noise, y_noise)
  59. # 2、用真实数据训练判别器
  60. d_real, d_real_logits = discriminator(X, y_label)
  61. # 3、用生成数据训练判别器
  62. d_fake, d_fake_logits = discriminator(g, y_label, reuse=True)
  63. # 喵喵喵???
  64. vars_g = [var for var in tf.trainable_variables() if var.name.startswith('generator')]
  65. vars_d = [var for var in tf.trainable_variables() if var.name.startswith('discriminator')]
  66. # 4、通过判别器拿到V(G,D)中第一项
  67. loss_d_real=tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_real_logits,tf.ones_like(d_real)))
  68. # 5、通过判别器拿到V(G,D)中第二项
  69. loss_d_fake=tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_fake_logits,tf.zeros_like(d_fake)))
  70. # 6、通过判别器定义生成器的loss
  71. loss_g=tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_fake_logits,tf.ones_like(d_fake)))
  72. # 7、组合得到判别器的loss
  73. loss_d=loss_d_fake+loss_d_real
  74. # optimizer
  75. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  76. with tf.control_dependencies(update_ops):
  77. optimizer_d = tf.train.AdamOptimizer(learning_rate=0.0002,beta1=0.5).minimize(loss_d,var_list=vars_d)
  78. optimizer_g = tf.train.AdamOptimizer(learning_rate=0.0002,beta1=0.5).minimize(loss_g,var_list=vars_g)
  79. # 就是为了方便看结果,和模型没关系
  80. def montage(images):
  81. if isinstance(images,list):
  82. images = np.array(images)
  83. img_h = images.shape[1]
  84. img_w = images.shape[2]
  85. n_plots = int(np.ceil(np.sqrt(images.shape[0])))
  86. m = np.ones((images.shape[1]*n_plots+n_plots+1,images.shape[2]*n_plots+n_plots+1))*0.5
  87. for i in range(n_plots):
  88. for j in range(n_plots):
  89. this_filter = i*n_plots+j
  90. if this_filter < images.shape[0]:
  91. this_img = images[this_filter]
  92. m[1 + i + i * img_h:1 + i + (i + 1) * img_h,
  93. 1 + j + j * img_w:1 + j + (j + 1) * img_w] = this_img
  94. return m
  95. sess=tf.Session()
  96. sess.run(tf.global_variables_initializer())
  97. z_samples = np.random.uniform(-1.0,1.0,[batch_size,z_dim]).astype(np.float32)
  98. y_samples = np.zeros([batch_size,LABEL])
  99. for i in range(LABEL):
  100. for j in range(LABEL):
  101. y_samples[i*LABEL+j,i]=1
  102. samples=[]
  103. loss={'d':[],'g':[]}
  104. import time
  105. '''
  106. batch_size=100
  107. z_dim=100
  108. WIDTH=28
  109. HEIGHT=28
  110. LABEL=10
  111. '''
  112. for i in tqdm(range(60000)):
  113. # 0、generator每次也就100个输入,每个输入为长100的向量
  114. n = np.random.uniform(-1.0,1.0,[batch_size,z_dim]).astype(np.float32)
  115. # 1、拿到的训练集图片,像素值已经被归一化到0~1之间
  116. batch,label=mnist.train.next_batch(batch_size=batch_size)
  117. # 1.1 从batch_size,784 reshape到 batch_size,28,28; 最后那个1表示channel
  118. batch = np.reshape(batch,[batch_size,HEIGHT,WIDTH,1])
  119. # 1.2 将0~1的图像像素值变成-1~1之间,满足tanh的要求。
  120. batch = (batch-0.5)*2
  121. yn = np.copy(label)
  122. # 2、能够理解这么做是为了在判别器中将label和图片接在一起,只是这种做法有点迷= -
  123. # label从 batch,1,1,LABEL_LEN的张量,变成,batch,28,28,LABEL_LEN的张量
  124. y1 = np.reshape(label,[batch_size,1,1,LABEL])
  125. y1 = y1*np.ones([batch_size,HEIGHT,WIDTH,LABEL])
  126. # 3、先跑一下loss,保存loss结果,为了后面画图用
  127. d_ls,g_ls = sess.run([loss_d,loss_g],feed_dict={X:batch,noise:n,y_label:y1,
  128. y_noise:yn,is_training:True})
  129. loss['d'].append(d_ls)
  130. loss['g'].append(g_ls)
  131. # 4、真正训练在这里,基本就是判别器训练一次,生成器训练2次。
  132. # 4.1 其实可以这么理解,训练一次判别器,需要真实数据(batch),以及生成器生成的数据。
  133. # 4.2 而生成器一次就只生成batch_size大小的数据。
  134. # 4.3 因此判别器一次训练,所使用的数据量,是判别器的两倍,所以训练一次判别器,需要训练两次生成器。
  135. sess.run(optimizer_d,feed_dict={X: batch, noise: n, y_label: y1,
  136. y_noise: yn, is_training: True})
  137. sess.run(optimizer_g, feed_dict={X: batch, noise: n, y_label: y1,
  138. y_noise: yn, is_training: True})
  139. sess.run(optimizer_g, feed_dict={X: batch, noise: n, y_label: y1,
  140. y_noise: yn, is_training: True})
  141. if i % 1000 == 0:
  142. print(i, d_ls, g_ls)
  143. # 5、每次到了节点,把这个时刻生成器生成的内容输出一下。
  144. gen_imgs = sess.run(g, feed_dict={noise: z_samples, y_noise: y_samples,
  145. is_training: False})
  146. # 6、生成器输出的图像,像素在-1~1之间,现在需要把它们转化回0~1之间。
  147. gen_imgs = (gen_imgs + 1) / 2
  148. # 7、逐个拿一下图片,拼图啥的。
  149. imgs = [img[:, :, 0] for img in gen_imgs]
  150. gen_imgs = montage(imgs)
  151. plt.axis('off')
  152. plt.imshow(gen_imgs, cmap='gray')
  153. imageio.imsave(os.path.join(OUTPUT_DIR, 'sample_%d.jpg' % i), gen_imgs)
  154. plt.show()
  155. samples.append(gen_imgs)
  156. plt.plot(loss['d'], label='Discriminator')
  157. plt.plot(loss['g'], label='Generator')
  158. plt.legend(loc='upper right')
  159. plt.savefig('Loss.png')
  160. plt.show()
  161. imageio.mimsave(os.path.join(OUTPUT_DIR, 'samples.gif'), samples, fps=5)
  162. saver = tf.train.Saver()
  163. saver.save(sess, './mnist_cgan', global_step=60000)

使用生成器

  1. import tensorflow as tf
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. batch_size = 100
  5. z_dim = 100
  6. LABEL = 10
  7. def montage(images):
  8. if isinstance(images, list):
  9. images = np.array(images)
  10. img_h = images.shape[1]
  11. img_w = images.shape[2]
  12. n_plots = int(np.ceil(np.sqrt(images.shape[0])))
  13. m = np.ones((images.shape[1] * n_plots + n_plots + 1, images.shape[2] * n_plots + n_plots + 1)) * 0.5
  14. for i in range(n_plots):
  15. for j in range(n_plots):
  16. this_filter = i * n_plots + j
  17. if this_filter < images.shape[0]:
  18. this_img = images[this_filter]
  19. m[1 + i + i * img_h:1 + i + (i + 1) * img_h,
  20. 1 + j + j * img_w:1 + j + (j + 1) * img_w] = this_img
  21. return m
  22. sess = tf.Session()
  23. sess.run(tf.global_variables_initializer())
  24. saver = tf.train.import_meta_graph('./mnist_cgan-60000.meta')
  25. saver.restore(sess, tf.train.latest_checkpoint('./'))
  26. graph = tf.get_default_graph()
  27. g = graph.get_tensor_by_name('generator/g/Tanh:0')
  28. noise = graph.get_tensor_by_name('noise:0')
  29. y_noise = graph.get_tensor_by_name('y_noise:0')
  30. is_training = graph.get_tensor_by_name('is_training:0')
  31. n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
  32. y_samples = np.zeros([batch_size, LABEL])
  33. for i in range(LABEL):
  34. for j in range(LABEL):
  35. y_samples[i * LABEL + j, i] = 1
  36. gen_imgs = sess.run(g, feed_dict={noise: n, y_noise: y_samples, is_training: False})
  37. gen_imgs = (gen_imgs + 1) / 2
  38. imgs = [img[:, :, 0] for img in gen_imgs]
  39. gen_imgs = montage(imgs)
  40. plt.axis('off')
  41. plt.imshow(gen_imgs, cmap='gray')
  42. plt.show()

2、InfoGAN

https://blog.csdn.net/dagekai/article/details/53953513
https://blog.csdn.net/u011699990/article/details/71599067
https://www.jianshu.com/p/b8c34c6f09ad

2.1 基本理论

对于原始的GAN而言,生成器的输入z和输出x之间,是没有什么可解释性的。即,我们不知道什么样的z能够生成什么样的x。

在CGAN中,我们通过给定一个隐向量c,以及让判别器去判断生成/真实图像是否符合描述,来约束生成器的行为。这一定程度上,可以认为是,描述c控制了输出x中的与c相关的特征。

然后,呃,有的大佬或许觉得这样不过瘾吧,说我想控制所有的特征;然而我并不能知道一幅图像有哪些特征。于是,我给一个向量作为隐向量,和GAN网络一起训练,然后可以认为,这个向量的每一维,都描述了图像的某一种特征。

问题:

对于这样的隐向量,我没有办法知道它到底描述的是哪几个特征,因此没办法使用CGAN里面的策略来判断这个隐向量是否生效了。

解决:

隐向量最直接的影响应该是生成器的输出,那么只要判断隐向量和生成器输出之间的互信息,就能够知道隐向量是否起作用了。

最终的COST——function:其中L1是经过优化的互信息下界。
image_1cn44dkouifg1mjpk63gf414u8h.png-48.7kB

详细推断:
image_1cn443p9s1191ki61q9113pu1g6777.png-628.4kB

2.2 网络结构

网络的逻辑结构就是这样,具体的结构可以和CGAN差不多。
image_1cn44ca0u1itsehnhde1mpm22v84.png-99.1kB

2.3 具体实现

网上找的一版感觉有点问题= -,之后试试keras版本的。

添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注