import torch
from torchvision import datasets
from torchvision.transforms import v2
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# --- Data transforms ---
# v2.ToTensor() is deprecated in torchvision.transforms.v2; the documented
# replacement is ToImage() + ToDtype(float32, scale=True), which converts the
# PIL image to a CHW tensor and rescales uint8 [0, 255] values to [0.0, 1.0].
transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    # Random photometric jitter; hue=0 leaves the hue unchanged.
    v2.ColorJitter(contrast=0.2, brightness=0.2, saturation=0.5, hue=0),
])

# --- Load the CIFAR-10 dataset ---
# download=True fetches the archive into ./data if it is not already there.
train_dataset = datasets.CIFAR10(root="./data", train=True, transform=transform, download=True)

# --- DataLoader ---
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# --- Visualize a sample batch ---
images, labels = next(iter(train_loader))
# `images` is stacked along dim 0, so each unbound element is one (C, H, W)
# image. Concatenating along dim=2 (width) lays them side by side, and
# permute(1, 2, 0) moves channels last, as matplotlib's imshow expects.
grid = torch.cat(torch.unbind(images, dim=0), dim=2).permute(1, 2, 0)
plt.figure(figsize=(12, 3))
plt.imshow(grid.numpy())
plt.title("CIFAR-10 with ColorJitter")
plt.axis("off")
plt.show()