以下の内容はhttps://htn20190109.hatenablog.com/entry/2025/11/16/232431より取得しました。


GRUNet

 

import torch
import torch.nn as nn

class GRUNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
    
        self.gru = nn.GRU(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True
        )
        self.fc = nn.Linear(hidden_size, num_classes)
    
    def forward(self, x, h=None):
        # h が None の場合は PyTorch が内部で 0 初期化するのでそのままOK
        out, h = self.gru(x, h)
        
        print(out.shape)
        print(out)
        # 系列最後の隠れ状態を分類に利用
        out = self.fc(out[:, -1, :])
        
        print(out.shape)
        print(out)
        return out


# デモ実行
batch_size = 4
seq_len = 5
input_size = 3
hidden_size = 8
num_layers = 1
num_classes = 2

model = GRUNet(input_size, hidden_size, num_layers, num_classes)

x = torch.randn(batch_size, seq_len, input_size)

out = model(x)

print("出力 shape:", out.shape)

 




以上の内容はhttps://htn20190109.hatenablog.com/entry/2025/11/16/232431より取得しました。
このページはhttp://font.textar.tv/のウェブフォントを使用してます

不具合報告/要望等はこちらへお願いします。
モバイルやる夫Viewer Ver0.14