import torch
from torchvision import datasets
from torchvision.transforms import v2
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# --- Noise function ---
def add_gaussian_noise(img: torch.Tensor, mean: float = 0.0, std: float = 0.1) -> torch.Tensor:
    """Return ``img`` plus Gaussian noise, clamped to the range [0, 1].

    Args:
        img: Tensor to perturb. It needs a floating-point dtype, since the
            noise is drawn with ``torch.randn_like(img)``.
        mean: Mean of the added noise.
        std: Standard deviation of the added noise.

    Returns:
        A new tensor with the same shape as ``img``; the input tensor is
        not modified in place.
    """
    # No per-call shape print here: this function runs once per sample when
    # used inside a dataset transform, so printing would flood stdout.
    noise = torch.randn_like(img) * std + mean
    # Clamp so the noisy result stays a valid [0, 1] image.
    return torch.clamp(img + noise, 0.0, 1.0)
# --- Data transforms ---
# v2.ToTensor() is deprecated in torchvision's v2 API; ToImage followed by
# ToDtype(torch.float32, scale=True) is the documented replacement and also
# produces float32 values scaled to [0, 1].
transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    # Noise is applied after scaling, so std is expressed on the [0, 1] scale.
    v2.Lambda(lambda img: add_gaussian_noise(img, mean=0.0, std=0.05)),
])
# --- Load the CIFAR-10 training split ---
# download=True asks torchvision to fetch the data under ./data if needed;
# each sample is passed through `transform`.
train_dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

# --- DataLoader ---
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=8)

# --- Visualize a sample batch ---
images, labels = next(iter(train_loader))
# Split the batch into per-image tensors, join them along their last axis
# (side by side), then move the leading axis last, as imshow expects.
grid = torch.cat(images.unbind(dim=0), dim=2).permute(1, 2, 0)

ax = plt.figure(figsize=(12, 3)).gca()
ax.imshow(grid)
ax.set_title("CIFAR-10 with Gaussian Noise (std=0.05)")
ax.axis("off")
plt.show()