import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, models, transforms
# --- Data preprocessing ---
# Convert each image to a tensor, then normalise all three channels with
# mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# --- CIFAR-10 dataset ---
# Both splits are stored under ./data (downloaded when missing) and share
# the same transform. Only the training loader shuffles.
# NOTE(review): num_workers=2 loads batches in worker processes; on platforms
# that start processes with "spawn" (e.g. Windows) the code that iterates
# these loaders has to run under an `if __name__ == "__main__":` guard --
# confirm for the target platform.
trainset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)

testset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
testloader = torch.utils.data.DataLoader(
    dataset=testset,
    batch_size=4,
    shuffle=False,
    num_workers=2,
)

# Human-readable label names, one per class.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
# --- Load the VGG16 model ---
# Start from the ImageNet-pretrained weights shipped with torchvision.
net = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
# Inference mode (changes the behaviour of layers such as BatchNorm/Dropout).
# NOTE(review): the code further down configures fine-tuning; the training
# loop should call net.train() first, otherwise Dropout stays disabled.
net.eval()
print(net)  # shows the architecture before the output layer is replaced

# --- Replace the output layer for CIFAR-10 ---
# Derive both sizes instead of hard-coding 4096 and 10: the input width comes
# from the layer being replaced, the output width from the label tuple.
net.classifier[6] = nn.Linear(
    in_features=net.classifier[6].in_features,
    out_features=len(classes),
)
# --- Parameter update settings ---
# Trainable parameters are collected into three lists so the optimizer can
# give each list its own learning rate.
params_to_update_1 = []
params_to_update_2 = []
params_to_update_3 = []

# Group 1 is matched by substring, groups 2 and 3 by exact parameter name.
update_param_names_1 = ["features"]
update_param_names_2 = [
    "classifier.0.weight", "classifier.0.bias",
    "classifier.3.weight", "classifier.3.bias",
]
update_param_names_3 = ["classifier.6.weight", "classifier.6.bias"]

for name, param in net.named_parameters():
    # Check every entry of the substring list (not just the first one) so
    # that adding another entry to update_param_names_1 takes effect.
    if any(key in name for key in update_param_names_1):
        param.requires_grad = True
        params_to_update_1.append(param)
    elif name in update_param_names_2:
        param.requires_grad = True
        params_to_update_2.append(param)
    elif name in update_param_names_3:
        param.requires_grad = True
        params_to_update_3.append(param)
    else:
        # Anything that is not listed is frozen.
        param.requires_grad = False
# --- Optimizer ---
# SGD with one parameter group per list; each group carries its own learning
# rate, while momentum is shared by all groups.
optimizer = optim.SGD(
    [
        dict(params=group, lr=rate)
        for group, rate in (
            (params_to_update_1, 0.0002),
            (params_to_update_2, 0.0008),
            (params_to_update_3, 0.001),
        )
    ],
    momentum=0.9,
)
# --- Sanity check ---
# Report how many parameter tensors ended up in each group, one line per
# group, preceded by a header line.
print(
    "更新対象パラメータ数:",
    f"Group1 (features): {len(params_to_update_1)}",
    f"Group2 (classifier middle): {len(params_to_update_2)}",
    f"Group3 (classifier final): {len(params_to_update_3)}",
    sep="\n",
)