import torch
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# --- Mixup function ---
def mixup(X, y, alpha=0.9, num_classes=10):
    """Apply Mixup data augmentation to a batch.

    Every sample is blended with a randomly chosen partner from the same
    batch, ``X_mix = lam * X + (1 - lam) * X[perm]``, and the one-hot labels
    are blended with the same coefficient ``lam``.

    Args:
        X: Batch of inputs; dimension 0 is the batch dimension.
        y: Integer class indices, one per sample in ``X``.
        alpha: Parameter of the symmetric Beta(alpha, alpha) distribution
            that ``lam`` is drawn from. If ``alpha <= 0`` mixing is
            disabled and ``lam`` is 1, so inputs and labels are unchanged.
        num_classes: Number of classes used for the one-hot encoding.

    Returns:
        Tuple ``(X_mix, y_mix, lam)``: the mixed inputs (same shape as
        ``X``), the mixed float labels of shape ``(batch, num_classes)``,
        and ``lam`` as a 0-d float tensor on ``X.device``.
    """
    batch_size = X.size(0)
    # Partner for sample i is sample indices[i].
    indices = torch.randperm(batch_size)
    # np.random.beta requires alpha > 0; like the reference Mixup code,
    # treat a non-positive alpha as "no mixing" instead of crashing.
    lam_value = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    lam = torch.tensor(lam_value, dtype=torch.float).to(X.device)
    # X[indices] permutes along dim 0 for inputs of any rank
    # (X[indices, :] would fail on 1-D inputs).
    X_mix = lam * X + (1 - lam) * X[indices]
    # One-hot encode so the labels can be blended with the same weight.
    y_onehot = torch.nn.functional.one_hot(y, num_classes=num_classes).float()
    y_mix = lam * y_onehot + (1 - lam) * y_onehot[indices]
    return X_mix, y_mix, lam
# --- Load the CIFAR-10 training split ---
# The only preprocessing is ToTensor, so images reach mixup() as tensors.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.CIFAR10(
    "./data",
    train=True,
    download=True,
    transform=transform,
)
# Small shuffled batches of 8; only the first batch is used by this script.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=8)
# --- Take one batch from the loader ---
first_batch = next(iter(train_loader))
images, labels = first_batch
# --- Apply Mixup to that batch ---
images_mix, labels_mix, lam = mixup(X=images, y=labels, alpha=0.9)
# Report the sampled mixing coefficient and the soft labels of two samples.
print("λ (lambda) =", float(lam))
print("ラベルの例(Mixup後):\n", labels_mix[:2])
# --- Visualization ---
# Show the batch as two horizontal strips: the original images on top and
# their Mixup results below. Mixed image i blends original image i with a
# randomly chosen partner from the same batch, so the columns line up.
# Each image is channels-first (C, H, W); concatenating along dim=2 (width)
# places the images side by side.
original_row = torch.cat(list(images), dim=2)
mixed_row = torch.cat(list(images_mix), dim=2)
# Stack the strips along height (dim=1), then move channels last for imshow.
grid = torch.cat([original_row, mixed_row], dim=1).permute(1, 2, 0)
plt.figure(figsize=(12, 3))
plt.imshow(grid.numpy())
plt.title("CIFAR-10 Mixup Example (Top: Original, Bottom: Mixed)")
plt.axis("off")
plt.show()