美文网首页
GAN-生成对抗网络

GAN-生成对抗网络

作者: 学了忘了学 | 来源:发表于2022-11-11 23:06 被阅读0次

GAN的基本思想

生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。模型通过框架中(至少)两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生相当好的输出。

  • 生成模型:给一系列猫的图片,生成一张新的猫咪(不在数据集里)
  • 判别模型:给定一张图,判断这张图里的动物是不是猫

假如你是一名篮球运动员,你想在下次比赛中得到上场机会
于是在每一次训练赛之后你跟教练进行沟通:

你:教练,我想打球
教练:(评估你的训练赛表现之后)... 算了吧
(你通过跟其他人比较,发现自己的运球很差,于是你苦练了一段时间)

你:教练,我想打球
教练:... 嗯 还不行
(你发现大家投篮都很准,于是你苦练了一段时间的投篮)

你:教练,我想打球
教练: ... 嗯 还有所欠缺
(你发现你的身体不够壮,被人一碰就倒,于是你去泡健身房)
......

通过这样不断的努力和被拒绝,你最终在某一次训练赛之后得到教练的赞赏,获得了上场的机会。
值得一提的是在这个过程中,所有的候选球员都在不断地进步和提升。因而教练也要不断地通过对比场上球员和候补球员来学习分辨哪些球员是真正可以上场的,并且要“观察”得比球员更频繁。随着大家的成长教练也会会变得越来越严格。

GAN

GAN训练过程

定义了一个生成器(Generator)来生成手写数字,一个判别器(Discrimnator)来判别手写数字是否是真实的,和一些真实的手写数字数据集。那么我们怎样来进行训练呢?

生成器(Generator)

对于生成器,输入需要一个n维度向量,输出为图片像素大小的图片。因而首先我们需要得到输入的向量。
Tips: 这里的生成器可以是任意可以输出图片的模型,比如最简单的全连接神经网络,又或者是反卷积网络等。这里大家明白就好。

这里输入的向量我们将其视为携带输出的某些信息,比如说手写数字为数字几,手写的潦草程度等等。由于这里我们对于输出数字的具体信息不做要求,只要求其能够最大程度与真实手写数字相似(能骗过判别器)即可。所以我们使用随机生成的向量来作为输入即可,这里面的随机输入最好是满足常见分布比如均值分布,高斯分布等。
Tips: 假如我们后面需要获得具体的输出数字等信息的时候,我们可以对输入向量产生的输出进行分析,获取到哪些维度是用于控制数字编号等信息的即可以得到具体的输出。而在训练之前往往不会去规定它。

判别器(Discrimnator)

对于判别器不用多说,往往是常见的判别器,输入为图片,输出为图片的真伪标签。
Tips: 同理,判别器与生成器一样,可以是任意的判别器模型,比如全连接网络,或者是包含卷积的网络等等。

训练过程

训练过程
之所以要训练k次判别器,再训练生成器,是因为要先拥有一个好的判别器,使得能够教好地区分出真实样本和生成样本之后,才好更为准确地对生成器进行更新 生成器判别器与样本示意图
图中的黑色虚线表示真实的样本的分布情况,蓝色虚线表示判别器判别概率的分布情况,绿色实线表示生成样本的分布。

我们的目标是使用生成样本分布(绿色实线)去拟合真实的样本分布(黑色虚线),来达到生成以假乱真样本的目的。

可以看到在(a)状态处于最初始的状态的时候,生成器生成的分布和真实分布区别较大,并且判别器判别出样本的概率不是很稳定,因此会先训练判别器来更好地分辨样本。
通过多次训练判别器来达到(b)样本状态,此时判别样本区分得非常显著和良好。然后再对生成器进行训练。
训练生成器之后达到(c)样本状态,此时生成器分布相比之前,逼近了真实样本分布。
经过多次反复训练迭代之后,最终希望能够达到(d)状态,生成样本分布拟合于真实样本分布,并且判别器分辨不出样本是生成的还是真实的(判别概率均为0.5)。也就是说我们这个时候就可以生成出非常真实的样本啦,目的达到。

代码实现

import numpy as np

import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
import torch.nn.functional as F
import torch.optim as optom
from torchvision.datasets import MNIST # Training dataset
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# 数据准备

# 归一化
transform = transforms.Compose([
    transforms.ToTensor(), # channe,height,width
    transforms.Normalize(0.5,0.5)
])

train_ds = MNIST('../dataset/mnist/',train = True,
                                     transform = transform)

dataloader = torch.utils.data.DataLoader(train_ds,batch_size = 64,shuffle=True)


imgs,_ = next(iter(dataloader))

print(imgs.shape)

# 生成器

# 输入是长度为100的噪声(正太分布随机数)
# 输出为(1,28,18)的图片
"""
linear1:100--256
linear2:256--5112
linear3:521--28*28
reshape:28*28--(1,28,28)
"""



class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.main = nn.Sequential(
                                nn.Linear(100,256),
                                nn.ReLU(),
                                nn.Linear(256,512),
                                nn.ReLU(),
                                nn.Linear(512,28*28),
                                nn.Tanh()    # -1,1
        
        )
    
    def forward(self,x):  # 长度为100的噪声
        img = self.main(x)
        img = img.view(-1,28,28,1)   # channel 在后面
        return img



# 判别器

# 输入为 (1,28,18)的图片
# 输入为 二分类的概率值,使用sigmoid激活0-1
# BCEloss:损失函数,交叉熵

# LeakyReLU: f(x):x>0输出x,x<0,输出 a*x a表示很小的斜率值,比如0.1
# 班别其中一般推荐使用 LeakyReLU

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.main = nn.Sequential(
                                nn.Linear(28*28,512),
                                nn.LeakyReLU(),
                                nn.Linear(512,256),
                                nn.LeakyReLU(),
                                nn.Linear(256,1),
                                nn.Sigmoid()
        
        )
    def forward(self,x): # 输入为(1,28,18)的图片
        x = x.view(-1,28*28)
        x = self.main(x)
        return x

# 初始化模型、优化器及损失函数

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

gen = Generator().to(device)
dis = Discriminator().to(device)

d_optim = torch.optim.Adam(dis.parameters(),lr = 0.0001)
g_optim = torch.optim.Adam(dis.parameters(),lr = 0.0001)

loss_fn = torch.nn.BCELoss()



# 绘图函数

# 绘制出生成器的图片

def gen_img_plot(model,test_input):
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize = (4,4))
    for i in range(16):
        plt.subplot(4,4,i+1)
        plt.imshow((prediction[i]+1)/2) # 恢复到0-1之间
        plt.axis('off')
    plt.show

test_input = torch.randn(16,100,device = device)



# GAN 的训练

D_loss = []
G_loss = []

# 训练循环
for epoch in range(20):
    d_epoch_loss = 0
    g_epoch_loss = 0
    count = len(dataloader) # 返回批次数
    for step,(img,_) in enumerate(dataloader):
        img = img.to(device)
        size = img.size(0)
        random_noise = torch.randn(size,100,device = device)
        
        #训练判别器
        d_optim.zero_grad()
        ## 真实图片训练判别器
        real_output = dis(img) #对判别器输入真实的图片,对真实图片的预测结果
        d_real_loss = loss_fn(real_output,
                              torch.ones_like(real_output)) #判别器在真实图像上的损失
        
        ## 生成图片训练判别器
        gen_img = gen(random_noise)
        
        fake_output = dis(gen_img.detach())  # 判别器输入生成的图片,对生成图片的预测
        d_fake_loss = loss_fn(fake_output,
                              torch.zeros_like(fake_output)) #判别器在生成图像上的损失
        d_loss = d_real_loss+d_fake_loss
        d_loss.backward()
        d_optim.step()
        
        # 训练生成器
        g_optim.zero_grad()
        
        fake_output = dis(gen_img)
        g_loss = loss_fn(fake_output,
                        torch.ones_like(fake_output)) #生成器的损失
        g_loss.backward()
        g_optim.step()
        
        with torch.no_grad():      
            d_epoch_loss+=d_loss
            g_epoch_loss+=g_loss
    with torch.no_grad():
        d_epoch_loss/=count
        g_epoch_loss/=count
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        print('Epoch:',epoch)
        gen_img_plot(gen,test_input)


加入CNN

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
from matplotlib import pyplot as plt
import os

if not os.path.exists('./cnn_img'):
    os.mkdir('./cnn_img')


def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out


batch_size = 128
num_epoch = 20
z_dimension = 100    #噪声维度

# Image processing
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])
# MNIST dataset
mnist = datasets.MNIST(
    root='../dataset/mnist/', train=True, transform=img_transform, download=True)
# Data loader
dataloader = torch.utils.data.DataLoader(
    dataset=mnist, batch_size=batch_size, shuffle=True)


# Discriminator
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 5, padding=2),    # batch,32,28,28
            nn.LeakyReLU(0.2),
            nn.MaxPool2d(2, stride=2)     # batch,32,14,14
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, padding=2),  # batch, 64, 14, 14
            nn.LeakyReLU(0.2),
            nn.MaxPool2d(2, stride=2)  # batch, 64, 7, 7
        )
        self.fc = nn.Sequential(
            nn.Linear(64 * 7 * 7, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


# Generator
class generator(nn.Module):
    def __init__(self, input_size, num_feature):
        super(generator, self).__init__()
        # 1.第一层线性变换
        self.fc = nn.Linear(input_size, num_feature)  # batch, 3136=1x56x56
        self.br = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.ReLU(True)
        )
        self.conv1_g = nn.Sequential(
            nn.Conv2d(1, 50, 3, stride=1, padding=1),  # batch, 50, 56, 56
            nn.BatchNorm2d(50),
            nn.ReLU(True)
        )
        self.conv2_g = nn.Sequential(
            nn.Conv2d(50, 25, 3, stride=1, padding=1),  # batch, 25, 56, 56
            nn.BatchNorm2d(25),
            nn.ReLU(True)
        )
        self.conv3_g = nn.Sequential(
            nn.Conv2d(25, 1, 2, stride=2),  # batch, 1, 28, 28
            nn.Tanh()
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 1, 56, 56)
        x = self.br(x)
        x = self.conv1_g(x)
        x = self.conv2_g(x)
        x = self.conv3_g(x)
        return x


D = discriminator()
G = generator(z_dimension, 3136)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

D = D.to(device)
G = G.to(device)
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0001)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0001)

# Start training
for epoch in range(num_epoch):
    for i, (img, _) in enumerate(dataloader):
        num_img = img.size(0)
        #print(img.shape)             # inputs:img=[128,1,28,28]
        # =================train discriminator
        #img = img.view(num_img, -1)        # img.shape: [128, 784]
        real_img = Variable(img).to(device)
        #print(real_img.shape)


        # compute loss of real_img
        real_out = D(real_img)
        real_label = Variable(torch.ones_like(real_out)).to(device)

        d_loss_real = criterion(real_out, real_label)
        real_scores = real_out  # closer to 1 means better

        # compute loss of fake_img
        z = Variable(torch.randn(num_img, z_dimension)).to(device)
        fake_img = G(z)
        fake_out = D(fake_img)
        fake_label = Variable(torch.zeros_like(fake_out)).to(device)
        d_loss_fake = criterion(fake_out, fake_label)
        fake_scores = fake_out  # closer to 0 means better

        # bp and optimize
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # ===============train generator
        # compute loss of fake_img
        z = Variable(torch.randn(num_img, z_dimension)).to(device)
        fake_img = G(z)
        output = D(fake_img)
        g_loss = criterion(output, real_label)

        # bp and optimize
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
                  'D real: {:.6f}, D fake: {:.6f}'.format(
                epoch, num_epoch, d_loss.item(), g_loss.item(),
                real_scores.data.mean(), fake_scores.data.mean()))
    if epoch == 0:
        real_images = to_img(real_img.cpu().data)
        save_image(real_images, './cnn_img/real_images.png')

    fake_images = to_img(fake_img.cpu().data)
    save_image(fake_images, './cnn_img/fake_images-{}.png'.format(epoch + 1))

torch.save(G.state_dict(), './cnn_img/generator.pth')
torch.save(D.state_dict(), './cnn_img/discriminator.pth')

参考

通俗理解生成对抗网络GAN - 知乎 (zhihu.com)
GAN 生成对抗网络 - 知乎 (zhihu.com)
GAN原理及简单mnist生成图片_JWangwen的博客-CSDN博客_gan生成图片

相关文章

网友评论

      本文标题:GAN-生成对抗网络

      本文链接:https://www.haomeiwen.com/subject/ltbttdtx.html