以下の内容はhttps://htn20190109.hatenablog.com/entry/2025/11/29/171724より取得しました。


Mixup

import torch
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# --- Mixup関数 ---
def mixup(X, y, alpha=0.9, num_classes=10):
    """Mixupデータ拡張"""
    batch_size = X.size(0)
    indices = torch.randperm(batch_size)
    lam = np.random.beta(alpha, alpha)
    lam = torch.tensor(lam, dtype=torch.float).to(X.device)
    
    X_mix = lam * X + (1 - lam) * X[indices, :]
    
    # One-hot変換
    y_onehot = torch.nn.functional.one_hot(y, num_classes=num_classes).float()
    y_mix = lam * y_onehot + (1 - lam) * y_onehot[indices, :]
    
    return X_mix, y_mix, lam


# --- CIFAR-10 読み込み ---
transform = transforms.Compose([
    transforms.ToTensor()
])

train_dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    transform=transform,
    download=True
)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# --- サンプル取得 ---
images, labels = next(iter(train_loader))

# --- Mixup適用 ---
images_mix, labels_mix, lam = mixup(images, labels, alpha=0.9)
print("λ (lambda) =", lam.item())
print("ラベルの例(Mixup後):\n", labels_mix[:2])


# --- 可視化 ---
grid = torch.cat([img for img in images_mix], dim=2).permute(1, 2, 0)
grid.shape
plt.figure(figsize=(12, 3))
plt.imshow(grid)
plt.title("CIFAR-10 Mixup Example (Left: Original, Right: Mixed)")
plt.axis("off")
plt.show()

 




以上の内容はhttps://htn20190109.hatenablog.com/entry/2025/11/29/171724より取得しました。
このページはhttp://font.textar.tv/のウェブフォントを使用してます

不具合報告/要望等はこちらへお願いします。
モバイルやる夫Viewer Ver0.14