@NumberFairy
2018-07-04T09:46:09.000000Z
字数 804
阅读 963
README_pix2pix
Pix2pix的实验到目前为止结构上并没有进行大的改变,现在做的只是将语义标签加入到模型训练中。
index.py中包含该实验的所有代码。main()函数是该程序入口,在会话session创建之前有两件事需要做。第一,定义变量input和target;第二,构建模型(见build_model()方法)。以上两步骤几乎没有自己发挥的内容,完全参考最初pix2pix的模型,然后在这里进行了重建。
开启session之后,要分train和test两块内容。如果是train,总结为一句话就是加载数据然后迭代训练。我们要做的工作就是加载数据这块,与最初pix2pix实验相比我们需要添加自己的语义到模型中,简单来说就是我们构建不一样的输入,之前的输入shape为(?,256,256,3),而现在我们的输入shape为(?,256, 256, 3+num_labels)。在代码中体现的话请参考load_batch_data()方法,该方法返回我们需要的batch数据,shape即为(?,256,256, 3+num_labels)。sess.run(XXX)之后把最后一组训练参数保存到ckpt。
如果是test的话,加载之前保存的ckpt模型参数,sess.run()运行生成器得到输出结果,此时得到的结果generator_output的shape为(?,256,256, 3+num_labels),我们需要降通道处理(这里采取的措施为只取channels的前三个通道值【此方法合理与否仍值得商榷】,对应到代码中的方法为process_generator_output()),得到rgb图像。
OVER!!!