import torch
import numpy as np
from torchvision import datasets
from torchvision.transforms import v2
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
# --- RandAugment 実装 ---
class RandAugment:
    """Simplified RandAugment: apply ``n`` randomly chosen operations per call.

    The operations are drawn without replacement and applied in the sampled
    order. Each operation draws its own random parameters; the magnitude
    ``m`` is stored but not used by any operation.
    """

    def __init__(self, n=2, m=10):
        """
        Args:
            n: number of operations applied per call. Operations are sampled
                without replacement, so ``n`` may not exceed the number of
                available operations.
            m: magnitude. Currently unused; each operation picks its own
                random parameters instead.

        Raises:
            ValueError: if ``n`` is negative or larger than the number of
                available operations.
        """
        self.n = n
        self.m = m
        self.operations = [
            self.randaugment_operation('RandomErasing'),
            self.randaugment_operation('RandomCrop', (32, 32)),
            self.randaugment_operation('AdjustContrast', 0.5, 1.5),
            self.randaugment_operation('AdjustBrightness', 0.5, 1.5),
            self.randaugment_operation('RandomRotation', -30, 30),
        ]
        # Sampling in __call__ is without replacement; reject an impossible
        # n here instead of failing on the first image.
        if not 0 <= n <= len(self.operations):
            raise ValueError(
                f'n must be between 0 and {len(self.operations)}, got {n}'
            )

    def __call__(self, img):
        """Apply ``self.n`` distinct, randomly ordered operations to ``img``."""
        # Sample indices rather than the callables themselves so NumPy does not
        # need to wrap the lambdas in an object array (same RNG consumption).
        chosen = np.random.choice(len(self.operations), size=self.n, replace=False)
        for idx in chosen:
            img = self.operations[idx](img)
        return img

    def randaugment_operation(self, opname, *args):
        """Build one augmentation callable.

        Args:
            opname: one of 'RandomErasing', 'RandomCrop', 'AdjustContrast',
                'AdjustBrightness', 'RandomRotation'.
            *args: operation parameters -- the output ``(height, width)`` for
                'RandomCrop'; the ``(low, high)`` bounds of the uniformly
                sampled factor/angle for the Adjust*/Rotation operations.

        Returns:
            A callable that takes an image and returns the augmented image.

        Raises:
            ValueError: if ``opname`` is not recognised.
        """
        if opname == 'RandomErasing':
            # torchvision v2 transforms are directly callable.
            return v2.RandomErasing(p=0.3)
        elif opname == 'RandomCrop':
            # Pad the borders, then crop at a random offset. A fixed top-left
            # crop of the full output size would leave a 32x32 input unchanged.
            return v2.RandomCrop(args[0], padding=4, pad_if_needed=True)
        elif opname == 'AdjustContrast':
            return lambda img: F.adjust_contrast(img, float(np.random.uniform(*args)))
        elif opname == 'AdjustBrightness':
            return lambda img: F.adjust_brightness(img, float(np.random.uniform(*args)))
        elif opname == 'RandomRotation':
            # Angle in degrees, drawn uniformly from [args[0], args[1]).
            return lambda img: F.rotate(img, float(np.random.uniform(*args)))
        else:
            raise ValueError('Invalid operation name')
# --- Transform definition ---
# PIL image -> float32 tensor scaled to [0, 1] (C, H, W), then RandAugment.
# v2.ToImage + v2.ToDtype(..., scale=True) is the replacement torchvision
# recommends for the deprecated v2.ToTensor.
transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    RandAugment(n=2),
])
# --- Load the CIFAR-10 training split ---
# Downloaded into ./data if missing; each sample passes through `transform`.
train_dataset = datasets.CIFAR10(
    "./data",
    train=True,
    download=True,
    transform=transform,
)
# Shuffled mini-batches of 8 samples.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=8,
    shuffle=True,
)
# --- Sample visualisation ---
# One shuffled batch. The transform yields (C, H, W) tensors, which the
# DataLoader stacks into (batch, C, H, W). `labels` is currently unused.
images, labels = next(iter(train_loader))
# Place the images side by side along the width axis, giving (C, H, batch*W),
# then move channels last (H, W, C) as matplotlib's imshow expects.
grid = torch.cat(images.unbind(0), dim=2).permute(1, 2, 0)
plt.figure(figsize=(12, 3))
plt.imshow(grid)
# NOTE: the title hard-codes "2 ops"; keep it in sync with RandAugment(n=...).
plt.title("CIFAR-10 with RandAugment (2 ops)")
plt.axis("off")
plt.show()