import torch
import torch.nn as nn
# --- モデル定義 ---
class Model(nn.Module):
def __init__(self):
super().__init__()
self.main = nn.Sequential(
nn.Conv2d(1024, 1024, kernel_size=7, stride=2, padding=3, bias=False),
nn.InstanceNorm2d(1024),
nn.LeakyReLU(0.2)
)
self.last = nn.Sequential(
nn.Conv2d(1024, 1, kernel_size=(4, 2), stride=1, padding=0, bias=False),
nn.InstanceNorm2d(1)
)
def forward(self, x):
out = self.main(x)
out = self.last(out)
out = out.view(out.size(0), -1)
out = torch.tanh(out)
return out
# モデル作成
model = Model()
# ダミー入力(例:バッチサイズ=2、チャンネル=1024、画像サイズ=32x32)
x = torch.randn(2, 1024, 32, 32)
# 順伝播
y = model(x)
print("入力 shape:", x.shape)
print("出力 shape:", y.shape)
#print("出力テンソル例:\n", y)
# パラメータ一覧表示
print("\n--- モデルパラメータ ---")
for name, param in model.named_parameters():
print(name, param.shape)
# モデル構造表示
print("\n--- モデル構造 ---")
print(model)
from torchinfo import summary
summary(model, (2, 1024, 32, 32))