track_running_stats=True ※デフォルト
学習時: 各バッチの平均・分散でデータを正規化(同時に移動平均・分散 running_mean / running_var を更新)
推論時: 学習時に蓄積した移動平均・分散でデータを正規化
track_running_stats=False
学習時: 各バッチの平均・分散でデータを正規化(running_mean / running_var は None のまま)
推論時: 推論バッチ自体の平均・分散でデータを正規化
→ 推論バッチが小さい・分布が偏っていると、推論結果が安定しない可能性あり
import torch
import torch.nn as nn
# --- モデル定義 ---
class Model(nn.Module):
    """Minimal conv -> batch-norm -> ReLU network for demonstrating
    the effect of BatchNorm2d's ``track_running_stats`` flag.

    Args:
        track_running_stats: forwarded verbatim to ``nn.BatchNorm2d``.
            When True (default) the layer maintains running mean/var;
            when False those buffers are None and batch statistics are
            used even in eval mode.
    """

    def __init__(self, track_running_stats=True):
        super().__init__()
        # 3 input channels -> 4 feature maps; padding keeps spatial size.
        self.conv1 = nn.Conv2d(3, 4, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(4, track_running_stats=track_running_stats)

    def forward(self, x):
        # conv -> normalize -> non-linearity, as a single pipeline.
        return torch.relu(self.bn1(self.conv1(x)))
# --- Demo data ---
x_train = torch.randn(8, 3, 32, 32)  # training-phase batch (batch size 8)
x_test = torch.randn(4, 3, 32, 32)   # inference-phase batch (batch size 4)

# --- One model per setting ---
model_true = Model(track_running_stats=True)
model_false = Model(track_running_stats=False)


def _report(model, out, label):
    """Print the BN running statistics and output distribution for one pass.

    NOTE: running_mean / running_var are None when the model was built
    with track_running_stats=False — that is the point of the demo.
    """
    print(f"--- {label} ---")
    print(f"running_mean : {model.bn1.running_mean}")
    print(f"running_var : {model.bn1.running_var}")
    print(f"mean : {out.mean().item():.4f}, std: {out.std().item():.4f}")


# track_running_stats=True: training updates the running statistics...
model_true.train()
out = model_true(x_train)
_report(model_true, out, "track_running_stats=True / train")

# ...and eval normalizes with those accumulated statistics.
model_true.eval()
with torch.no_grad():  # inference only; no autograd bookkeeping needed
    out = model_true(x_test)
_report(model_true, out, "track_running_stats=True / eval")

# track_running_stats=False: batch statistics are used in BOTH modes,
# and the running buffers stay None throughout.
model_false.train()
out = model_false(x_train)
_report(model_false, out, "track_running_stats=False / train")

model_false.eval()
with torch.no_grad():
    out = model_false(x_test)
_report(model_false, out, "track_running_stats=False / eval")