你乐谷
首页 > 图文

深度学习笔记56_GAN也很简单_GAN模型的训练(2)

2023-03-16 来源:你乐谷
generator = keras.models.Model(generator_input,x)
# 判别器模型
discriminator_input = layers.Input(shape=(height,width,channels))
# 开始卷积网络训练
x = layers.Conv2D(128,3)(discriminator_input)
x= layers.LeakyReLU()(x)
x = layers.Conv2D(128,4,strides=2)(x)
x= layers.LeakyReLU()(x)
x = layers.Conv2D(128,4,strides=2)(x)
x= layers.LeakyReLU()(x)
x = layers.Conv2D(128,4,strides=2)(x)
x= layers.LeakyReLU()(x)
# 拉伸用来实现全连接层网络
x = layers.Flatten()(x)
# 使用dropout层来实现相关的网络
x = layers.Dropout(0.4)(x)
# 用来分类
x = layers.Dense(1,activation=sigmoid)(x)
# 将判别器模型实例化,它将形状为 (32,32,3)
#的输入转换为一个二进制分类决策(真 / 假)
# 搭建网络
discriminator = keras.models.Model(discriminator_input,x)
# 优化器
discriminator_optimizer = keras.optimizers.RMSprop(
lr = 0.0008,
clipvalue=1.0,# 使用梯度裁剪限制梯度值的范围
decay=1e-8) # 使用学习率衰减
discriminator pile(
optimizer = discriminator_optimizer,
loss =binary_crossentropy
)
#GAN模型
#如果在此过程中可以对判别器的权重进行更新,
# 那么我们就是在训练判别器始终预测“真”
discriminator.trainable = False
# 输入是潜在随机值
gan_input = keras.Input(shape=(latent_dim,))
# 输出就是判别器的判断
gan_output = discriminator(generator(gan_input))
gan = keras.models.Model(gan_input,gan_output)
gan_optimizer = keras.optimizers.RMSprop(lr=0.0004, clipvalue=1.0, decay=1e-8)
ganpile(optimizer=gan_optimizer, loss=binary_crossentropy)
import os
from keras.preprocessing import image
# 加载CIFAR10 数据
(x_train, y_train), (_, _) = keras.datasets.cifar10.load_data()
猜你喜欢