以下の内容はhttps://htn20190109.hatenablog.com/entry/2025/12/30/114444より取得しました。


DCGAN

 


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from torchvision import datasets, transforms

import matplotlib.pyplot as plt

# ----- Device selection: prefer CUDA when available -----
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# ----- Hyperparameters -----
epochs = 2
batch_size = 64
nz = 100       # dimensionality of the latent noise vector fed to the generator
# Two-timescale update rule: discriminator learns faster than the generator.
lrD = 0.0004
lrG = 0.0001
beta1 = 0.5    # Adam beta1 (DCGAN paper recommends 0.5 instead of 0.9)

# ----- Dataset (CIFAR-10) -----
# Images are upscaled to 64x64 and mapped from [0, 1] to [-1, 1] so they
# match the generator's Tanh output range.  The single-channel mean/std
# broadcasts across all three RGB channels (equivalent to (0.5, 0.5, 0.5)).
IMAGE_SIZE = 64
SUBSET_SIZE = 1000  # train on a small slice of the 50k images to keep the demo fast

transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)
dataset2 = Subset(dataset, list(range(SUBSET_SIZE)))
dataloader = DataLoader(dataset2, batch_size=batch_size, shuffle=True)
print(len(dataloader))

# ===== Generator モデル =====
class Generator(nn.Module):
    """DCGAN generator: maps a latent (nz, 1, 1) vector to a 3x32x32 image.

    Transposed convolutions upsample the spatial size 1 -> 4 -> 8 -> 16 -> 32
    while halving the channel count each step; Tanh squashes pixels to [-1, 1].
    NOTE(review): the data pipeline resizes real images to 64x64 while this
    network emits 32x32 — confirm whether a fourth upsampling stage was intended.
    """

    def __init__(self, nz):
        super().__init__()

        def up_block(c_in, c_out, kernel, stride, pad):
            # ConvTranspose + BatchNorm + ReLU, the standard DCGAN upsampling unit.
            return [
                nn.ConvTranspose2d(c_in, c_out, kernel, stride, pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]

        layers = []
        layers += up_block(nz, 512, 4, 1, 0)    # 1x1  -> 4x4
        layers += up_block(512, 256, 4, 2, 1)   # 4x4  -> 8x8
        layers += up_block(256, 128, 4, 2, 1)   # 8x8  -> 16x16
        layers.append(nn.ConvTranspose2d(128, 3, 4, 2, 1, bias=False))  # 16x16 -> 32x32
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)


# ===== Discriminator モデル =====
class Discriminator(nn.Module):
    """DCGAN discriminator: strided CNN emitting flattened sigmoid scores.

    NOTE(review): with 64x64 inputs the final kernel-8 conv leaves a 9x9 map,
    so .view(-1) yields 81 scores per image instead of 1; with 32x32 inputs it
    yields exactly 1 per image.  Confirm the 64x64 behavior is intended.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),    # H,W halved; no BN on the first layer (DCGAN convention)
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),  # H,W halved again
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 1, 8, 1, 0, bias=False),   # kernel-8 reduction to score map
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.main(x)
        return scores.view(-1)


# ----- Model construction -----
net_g = Generator(nz).to(device)
net_d = Discriminator().to(device)

# ----- Loss and optimizers -----
# BCE on sigmoid outputs; both optimizers share the DCGAN-style Adam betas,
# but the discriminator gets the higher learning rate (TTUR).
criterion = nn.BCELoss()
adam_betas = (beta1, 0.999)
optimizer_d = optim.Adam(net_d.parameters(), lr=lrD, betas=adam_betas)
optimizer_g = optim.Adam(net_g.parameters(), lr=lrG, betas=adam_betas)

# ----- Training loop -----
for epoch in range(epochs):
    loss_d_sum = 0.0
    loss_g_sum = 0.0
    for i, (real_images, _) in enumerate(dataloader):
        # --- Train the discriminator ---
        net_d.zero_grad()
        # Real batch: push D(real) toward 1.  Labels are sized from D's
        # actual output (ones_like), so this keeps working even if the
        # discriminator emits more than one score per image.
        real_images = real_images.to(device)
        output_real = net_d(real_images)
        err_d_real = criterion(output_real, torch.ones_like(output_real))

        # Fake batch: push D(G(z)) toward 0.  detach() blocks gradients
        # from flowing into the generator during the D update.
        noise = torch.randn(real_images.size(0), nz, 1, 1, device=device)
        fake_images = net_g(noise)
        output_fake = net_d(fake_images.detach())
        err_d_fake = criterion(output_fake, torch.zeros_like(output_fake))

        # One optimizer step on the combined real + fake loss.
        err_d = err_d_real + err_d_fake
        err_d.backward()
        optimizer_d.step()

        # --- Train the generator: push D(G(z)) toward 1 ---
        # ones_like(output) fixes the fragile assumption that D returns
        # exactly one score per fake image.
        net_g.zero_grad()
        output = net_d(fake_images)
        err_g = criterion(output, torch.ones_like(output))
        err_g.backward()
        optimizer_g.step()

        loss_d_sum += err_d.item()
        loss_g_sum += err_g.item()
        if i % 200 == 0:
            print(f"[{epoch}/{epochs}] Step: {i} "
                  f"Loss_D: {err_d.item():.4f} Loss_G: {err_g.item():.4f}")

    # Report per-epoch averages (the running sums were previously
    # accumulated but never used).
    n_batches = max(len(dataloader), 1)
    print(f"Epoch {epoch} avg Loss_D: {loss_d_sum / n_batches:.4f} "
          f"avg Loss_G: {loss_g_sum / n_batches:.4f}")


# ----- Sample generation -----
net_g.eval()  # switch BatchNorm layers to inference statistics
with torch.no_grad():
    fixed_noise = torch.randn(16, nz, 1, 1, device=device)
    fake_images = net_g(fixed_noise).cpu()
    fake_images = (fake_images + 1) / 2.0  # [-1,1] → [0,1]

# Lay the 16 samples out side by side (concatenate along the width axis),
# then move channels last (H, W, C) for matplotlib.
grid = torch.cat(list(fake_images), dim=2).permute(1, 2, 0)
plt.figure(figsize=(12, 4))
plt.imshow(grid)
plt.title("Generated CIFAR-10 style images (DCGAN)")
plt.axis("off")
plt.show()




以上の内容はhttps://htn20190109.hatenablog.com/entry/2025/12/30/114444より取得しました。
このページはhttp://font.textar.tv/のウェブフォントを使用しています

不具合報告/要望等はこちらへお願いします。
モバイルやる夫Viewer Ver0.14