这次看下wgan脚本,这里使用fastai来完成wgan的训练和使用。
老三样,我就不加标题了
%reload_ext autoreload
%autoreload 2
%matplotlib inline
1 重要的包
from fastai.vision import *
from fastai.vision.gan import *
其中gan包是在../fastai/vision下的文件夹。大家从fastai的官网上下载源代码可以看见。
2 数据
本教程使用的脚本是LSun bedroom数据集,该数据集是卧室的图片,我们目的是使用fastai使用wgan生成卧室的图片。
这里使用了该数据集的一小部分,因为原始数据集实在是太大了。
使用data.show_batch()显示部分数据。
图1 部分数据示例
3 模型
生成对抗网络GAN有很多种。
这里使用WGAN(Wassertein GAN)。
模型训练过程如下:
WGAN有两部分:
- 生成器(generator)
- 对抗器(critic)
- freeze generator,然后训练critic。
- 喂一个批次的真数据和假数据
- 根据critic的损失,训练critic的网络。原则是奖励真目标,惩罚假目标。
2 freeze critic,然后训练generator。
- 生成一个批次的假数据
- 使用critic得到损失函数
- 由损失函数训练generator网络。原则是奖励真目标,惩罚假目标。
4 模型训练
训练使用fit_one_cycle,max_lr=2e-4。
这里对比了下Google和kaggle提供的免费模型训练平台。
因此我选择了kaggle的kernel,速度是真的快!但是不能超过9个小时,这个要注意。
这里的fastai需要事先定义:
generator = basic_generator(in_size=64, n_channels=3, n_extra_layers=1)
critic = basic_critic (in_size=64, n_channels=3, n_extra_layers=1)
learn = GANLearner.wgan(data, generator, critic, switch_eval=False,
opt_func = partial(optim.Adam, betas = (0.,0.99)), wd=0.)
# 模型训练
learn.fit(30,2e-4)
5 结果
learn.gan_trainer.switch(gen_mode=True)
learn.show_results(ds_type=DatasetType.Train, rows=16, figsize=(8,8))
图2 生成结果









网友评论