import torch
import torch.nn as nn
from torchinfo import summary
# --- モデル定義 ---
class Model(nn.Module):
def __init__(self):
super().__init__()
# 入力チャンネル3、出力チャンネル6
self.conv = nn.Conv2d(3, 6, kernel_size=3, stride=1, padding=1)
# GroupNorm: グループ数3, チャンネル数6 → 各グループに2チャンネルずつ
self.norm = nn.GroupNorm(3, 6)
def forward(self, x):
x = self.conv(x)
x = self.norm(x)
x = torch.relu(x)
return x
# --- モデル作成 ---
model = Model()
# --- モデル構造表示 ---
print("\n--- モデル構造 ---")
print(model)
# --- モデル要約情報 ---
# 入力テンソル形状を正しく設定(batch=2, channel=3, height=32, width=32)
summary(model, (2, 3, 32, 32))
# --- 実際の動作確認 ---
x = torch.randn(2, 3, 32, 32)
y = model(x)
print("\n出力 shape:", y.shape)