import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# ===== Device setup =====
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# ===== Hyperparameters =====
epochs = 2
batch_size = 64
nz = 100          # dimensionality of the generator's input noise vector
lrD = 0.0004      # discriminator learning rate
lrG = 0.0001      # generator learning rate
beta1 = 0.5       # Adam beta1 (DCGAN convention)
image_size = 32   # must match the 32x32 images the Generator produces

# ===== Dataset (CIFAR-10) =====
# BUG FIX: the original pipeline resized real images to 64x64, while the
# Generator below outputs 32x32 images. The fully-convolutional Discriminator
# then produced 9x9 = 81 sigmoid patch scores per 64x64 real image but a
# single score per 32x32 fake, so real and fake passes optimized inconsistent
# objectives. Keeping CIFAR-10's native 32x32 size aligns both paths.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    # Explicit per-channel RGB normalization to [-1, 1], matching Tanh output.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)

from torch.utils.data import Subset  # kept local, as in the original script

# Train on a small 1000-image subset to keep this demo fast.
dataset2 = Subset(dataset, list(range(1000)))
dataloader = DataLoader(dataset2, batch_size=batch_size, shuffle=True)
print(len(dataloader))
# ===== Generator モデル =====
class Generator(nn.Module):
    """DCGAN generator: maps (N, nz, 1, 1) noise to (N, 3, 32, 32) images.

    Each transposed-convolution stage doubles the spatial size
    (1 -> 4 -> 8 -> 16 -> 32); the final Tanh squashes pixels to [-1, 1].
    """

    def __init__(self, nz):
        super().__init__()
        # (in_ch, out_ch, kernel, stride, padding) for the BN+ReLU stages.
        # Layers are created in the same order as the original so that random
        # weight initialization consumes the RNG identically.
        stage_specs = [
            (nz, 512, 4, 1, 0),    # 1x1  -> 4x4
            (512, 256, 4, 2, 1),   # 4x4  -> 8x8
            (256, 128, 4, 2, 1),   # 8x8  -> 16x16
        ]
        layers = []
        for cin, cout, k, s, p in stage_specs:
            layers.append(nn.ConvTranspose2d(cin, cout, k, s, p, bias=False))
            layers.append(nn.BatchNorm2d(cout))
            layers.append(nn.ReLU(True))
        # Output stage: 16x16 -> 32x32 RGB, no BatchNorm, Tanh activation.
        layers.append(nn.ConvTranspose2d(128, 3, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Generate a batch of images from noise x of shape (N, nz, 1, 1)."""
        return self.main(x)
# ===== Discriminator モデル =====
class Discriminator(nn.Module):
    """DCGAN discriminator: scores images as real (~1) or fake (~0).

    For (N, 3, 32, 32) input the convolutions reduce the spatial size
    32 -> 16 -> 8, and the final 8x8 convolution collapses it to one
    sigmoid probability per image; forward returns them flattened.
    """

    def __init__(self):
        super().__init__()
        modules = [
            # 32x32 -> 16x16; no BatchNorm on the first layer (DCGAN convention).
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 16x16 -> 8x8
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 8x8 -> 1x1 score, squashed to a probability.
            nn.Conv2d(128, 1, 8, 1, 0, bias=False),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*modules)

    def forward(self, x):
        """Return a flat tensor of per-image real/fake probabilities."""
        return self.main(x).view(-1)
# ===== Model initialization =====
# NOTE: net_g is built before net_d, matching the original order so that
# random weight initialization consumes the RNG identically.
net_g = Generator(nz).to(device)
net_d = Discriminator().to(device)

# ===== Optimizers and loss function =====
criterion = nn.BCELoss()  # binary cross-entropy over real/fake probabilities
shared_betas = (beta1, 0.999)
optimizer_d = optim.Adam(net_d.parameters(), lr=lrD, betas=shared_betas)
optimizer_g = optim.Adam(net_g.parameters(), lr=lrG, betas=shared_betas)
# ===== Training loop =====
for epoch in range(epochs):
    loss_d_sum = 0.0
    loss_g_sum = 0.0
    for i, (real_images, _) in enumerate(dataloader):
        # --- Train the discriminator ---
        net_d.zero_grad()

        # Real batch: target label 1. ones_like matches the output's
        # shape/dtype/device, so it stays correct for a smaller final batch
        # (the original built labels with a verbose ones(...).to(...).to(...)).
        real_images = real_images.to(device)
        output_real = net_d(real_images)
        err_d_real = criterion(output_real, torch.ones_like(output_real))

        # Fake batch: target label 0. detach() blocks gradients into G here.
        noise = torch.randn(real_images.size(0), nz, 1, 1, device=device)
        fake_images = net_g(noise)
        output_fake = net_d(fake_images.detach())
        err_d_fake = criterion(output_fake, torch.zeros_like(output_fake))

        # D minimizes the combined real+fake loss.
        err_d = err_d_real + err_d_fake
        err_d.backward()
        optimizer_d.step()

        # --- Train the generator: make D predict 1 on fakes ---
        net_g.zero_grad()
        output = net_d(fake_images)  # no detach: gradients must reach G
        err_g = criterion(output, torch.ones_like(output))
        err_g.backward()
        optimizer_g.step()

        loss_d_sum += err_d.item()
        loss_g_sum += err_g.item()
        if i % 200 == 0:
            print(f"[{epoch}/{epochs}] Step: {i} "
                  f"Loss_D: {err_d.item():.4f} Loss_G: {err_g.item():.4f}")

    # FIX: the running sums were accumulated but never reported in the
    # original; print per-epoch mean losses so the accumulation has a purpose.
    n_batches = max(len(dataloader), 1)
    print(f"[{epoch}/{epochs}] Epoch mean "
          f"Loss_D: {loss_d_sum / n_batches:.4f} "
          f"Loss_G: {loss_g_sum / n_batches:.4f}")
# ===== Sample generation =====
net_g.eval()
with torch.no_grad():
    fixed_noise = torch.randn(16, nz, 1, 1, device=device)
    samples = net_g(fixed_noise).cpu()

# Map Tanh output from [-1, 1] to [0, 1] for display.
samples = (samples + 1) / 2.0
# Tile the 16 images into one horizontal strip (concatenate along width),
# then move channels last for imshow: (3, H, 16*W) -> (H, 16*W, 3).
strip = torch.cat(list(samples), dim=2).permute(1, 2, 0)

plt.figure(figsize=(12, 4))
plt.imshow(strip)
plt.title("Generated CIFAR-10 style images (DCGAN)")
plt.axis("off")
plt.show()