import torch
import torch.nn as nn
class GRUNet(nn.Module):
    """GRU encoder followed by a linear classifier applied to the last time step.

    Args:
        input_size: Number of features per time step of the input sequence.
        hidden_size: Number of features in the GRU hidden state.
        num_layers: Number of stacked GRU layers.
        num_classes: Number of logits produced by the classifier head.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # batch_first=True: the GRU consumes and emits tensors laid out as
        # (batch, seq, feature) instead of PyTorch's default (seq, batch, feature).
        self.gru = nn.GRU(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x, h=None, *, verbose=False):
        """Return class logits for each sequence in *x*.

        Args:
            x: Batched input of shape (batch, seq_len, input_size), matching
                the GRU's ``batch_first=True`` layout.
            h: Optional initial hidden state of shape
                (num_layers, batch, hidden_size). When None, nn.GRU
                initializes it with zeros internally.
            verbose: If True, print the intermediate GRU output and the
                logits (shape and values) for debugging. Defaults to False
                so ordinary forward passes do not write to stdout.

        Returns:
            Logits of shape (batch, num_classes).
        """
        # The final hidden state returned by the GRU is not needed here.
        seq_out, _ = self.gru(x, h)
        if verbose:
            print(seq_out.shape)
            print(seq_out)
        # Classify from the GRU output at the last time step of each sequence.
        logits = self.fc(seq_out[:, -1, :])
        if verbose:
            print(logits.shape)
            print(logits)
        return logits
# Demo run: push one random batch through the model and report the output shape.
batch_size, seq_len, input_size = 4, 5, 3
hidden_size, num_layers, num_classes = 8, 1, 2

model = GRUNet(input_size, hidden_size, num_layers, num_classes)
# Random input laid out as (batch, seq_len, input_size), matching the
# GRU's batch_first layout.
x = torch.randn(batch_size, seq_len, input_size)
out = model(x)
print("出力 shape:", out.shape)