你乐谷
首页 > 图文

深度学习笔记56_GAN也很简单_GAN模型的训练(3)

2023-03-16 来源:你乐谷
# 选择青蛙图像(类别编号为 6)
x_train = x_train[y_train.flatten() == 6]
# 数据标准化
x_train = x_train.reshape((x_train.shape[0],) 
(height, width, channels)).astype(float32) / 255.
# 迭代的次数 10000次
iterations = 10000
batch_size = 20
save_dir = result
start = 0
for step in range(iterations):
# 在潜在空间中采样随机点
random_latent_vectors = np.random.normal(
size=(batch_size,
latent_dim))
# 通过生成模型,将这些点解码为虚假图像
generated_images = generator.predict(random_latent_vectors)
# 本次迭代的需要处理多少张图片
stop = start batch_size
# 获得真实的图片
real_images = x_train[start:stop]
# 将这些虚假图像与真实图像合在一起
combined_images = np.concatenate([generated_images,real_images])
# 合并标签,区分真实和虚假的图像
labels = np.concatenate([np.ones((batch_size,1)),
np.zeros((batch_size,1))])
# 向标签中添加随机噪声
labels  = 0.05 * np.random.random(labels.shape)
# 训练判别器
d_loss = discriminator.train_on_batch(combined_images,labels)
# 在潜在空间中采样随机点
random_latent_vectors = np.random.normal(size=(batch_size,latent_dim))
# 合并标签,全部是“真实图像”(这是在撒谎)
misleading_targets = np.zeros((batch_size,1))
# 通过 gan 模型来训练生成器( 此 时 冻 结 判别器权重)
a_loss = gan.train_on_batch(random_latent_vectors,
misleading_targets)
start  = batch_size
if startlen(x_train)-batch_size:
start = 0
# 每 100 步保存并绘图
if step % 100 ==0:
# 保存模型权重
gan.save_weights(gan.h5)
print(discriminator loss:, d_loss)
print(adversarial loss:, a_loss)
# 保存一张生成图像
img = image.array_to_img(generated_images[0] * 255., scale=False)
猜你喜欢