以下の内容はhttps://htn20190109.hatenablog.com/entry/2025/11/27/205239より取得しました。


RandAugment

 

import torch
import numpy as np
from torchvision import datasets
from torchvision.transforms import v2
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


# --- RandAugment 実装 ---
class RandAugment:
    def __init__(self, n=2, m=10):
        """
        n: 1回に適用する操作数
        m: 強度(未使用、ここでは各操作がランダムにパラメータ決定)
        """
        self.n = n
        self.m = m
        self.operations = [
            self.randaugment_operation('RandomErasing'),
            self.randaugment_operation('RandomCrop', (32, 32)),
            self.randaugment_operation('AdjustContrast', 0.5, 1.5),
            self.randaugment_operation('AdjustBrightness', 0.5, 1.5),
            self.randaugment_operation('RandomRotation', -30, 30),
        ]
        
    def __call__(self, img):
        ops = np.random.choice(self.operations, size=self.n, replace=False)
        for op in ops:
            img = op(img)
        return img
    def randaugment_operation(self, opname, *args):
        if opname == 'RandomErasing':
            # torchvision v2 ではTransformを呼び出し可能
            eraser = v2.RandomErasing(p=0.3)
            return lambda img: eraser(img)
        elif opname == 'RandomCrop':
            return lambda img: F.crop(img, top=0, left=0, height=args[0][0], width=args[0][1])
        elif opname == 'AdjustContrast':
            return lambda img: F.adjust_contrast(img, np.random.uniform(*args))
        elif opname == 'AdjustBrightness':
            return lambda img: F.adjust_brightness(img, np.random.uniform(*args))
        elif opname == 'RandomRotation':
            return lambda img: F.rotate(img, np.random.uniform(*args))
        else:
            raise ValueError('Invalid operation name')


# --- Transform 定義 ---
transform = v2.Compose([
    v2.ToTensor(),
    RandAugment(n=2),
])


# --- CIFAR-10 読み込み ---
train_dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    transform=transform,
    download=True
)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# --- サンプル可視化 ---
images, labels = next(iter(train_loader))
grid = torch.cat([img for img in images], dim=2).permute(1, 2, 0)

plt.figure(figsize=(12, 3))
plt.imshow(grid)
plt.title("CIFAR-10 with RandAugment (2 ops)")
plt.axis("off")
plt.show()




以上の内容はhttps://htn20190109.hatenablog.com/entry/2025/11/27/205239より取得しました。
このページはhttp://font.textar.tv/のウェブフォントを使用してます

不具合報告/要望等はこちらへお願いします。
モバイルやる夫Viewer Ver0.14