以下の内容はhttps://htn20190109.hatenablog.com/entry/2025/11/16/231046より取得しました。


双方向RNN

 


import torch
import torch.nn as nn

class BiRNN(nn.Module):
    """Bidirectional RNN for per-timestep classification on padded batches.

    The forward and backward hidden states are summed (not concatenated),
    so the classifier head maps hidden_size -> num_classes.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.rnn = nn.RNN(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            bidirectional=True
        )
        # hidden_size (not 2 * hidden_size) because the two directions are
        # summed in forward() before this linear layer.
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x, seq_lengths, masks):
        """Run the RNN over a padded batch and classify each timestep.

        Args:
            x: (batch, max_len, input_size) padded inputs.
            seq_lengths: (batch,) true sequence lengths.
            masks: (batch, max_len) 1.0 for real steps, 0.0 for padding.

        Returns:
            (batch, max_len, num_classes) logits, in the caller's batch order.
        """
        # NOTE: with enforce_sorted=False, pack_padded_sequence handles the
        # permutation internally; the explicit sort is kept from the original
        # and is harmless because the order is restored below.
        seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
        x = x[perm_idx]
        x_pack = nn.utils.rnn.pack_padded_sequence(
            x,
            seq_lengths.cpu(),
            batch_first=True,
            enforce_sorted=False
        )
        # BUG FIX: allocate h0 on the input's device/dtype; the original
        # plain torch.zeros(...) breaks when x lives on the GPU.
        h0 = torch.zeros(
            self.num_layers * 2, x.size(0), self.hidden_size,
            device=x.device, dtype=x.dtype
        )
        out_pack, _ = self.rnn(x_pack, h0)
        # total_length guarantees the padded width matches masks even if the
        # longest sequence is shorter than max_len.
        out, _ = nn.utils.rnn.pad_packed_sequence(
            out_pack, batch_first=True, total_length=masks.size(1)
        )
        # Undo the length-sort so rows line up with the caller's batch order.
        _, unperm_idx = perm_idx.sort(0)
        out = out[unperm_idx]
        # BUG FIX: the original permuted masks into sorted order but applied
        # them AFTER out was unpermuted, zeroing the wrong rows whenever the
        # sort reordered the batch. Use the caller-order masks directly.
        out = out * masks.unsqueeze(-1).float()
        print("RNN 出力:", out.shape)
        # Sum forward and backward direction features.
        out = out[:, :, :self.hidden_size] + out[:, :, self.hidden_size:]
        print("双方向合成後:", out.shape)
        out = self.fc(out)
        return out

class BiRNNLoss(nn.Module):
    """Token-level cross-entropy averaged over unmasked (non-padding) steps."""

    def __init__(self):
        super().__init__()
        # BUG FIX: reduction='none' yields one loss per token. The original
        # used the default reduction='mean', which returned a scalar and made
        # the subsequent mask multiplication a no-op (padding still counted).
        self.loss_fn = nn.CrossEntropyLoss(reduction='none')

    def forward(self, outputs, targets, masks):
        """Compute masked mean cross-entropy.

        Args:
            outputs: (batch, max_len, num_classes) logits.
            targets: (batch, max_len) integer class labels.
            masks: (batch, max_len) 1.0 for real tokens, 0.0 for padding.

        Returns:
            Scalar loss averaged over masked-in tokens only.
            NOTE(review): an all-zero mask yields NaN (0/0) — callers are
            assumed to always pass at least one real token.
        """
        # Flatten to (N, C) logits and (N,) labels/masks.
        targets = targets.view(-1)
        outputs = outputs.view(-1, outputs.size(-1))
        masks = masks.view(-1)

        # Per-token losses, shape (N,) thanks to reduction='none'.
        losses = self.loss_fn(outputs, targets)
        # Zero padded positions, then average over real tokens only.
        loss = torch.sum(losses * masks) / torch.sum(masks)
        return loss


# ========================
# 動作確認 (smoke test)
# ========================
batch_size, max_len = 3, 5
input_size, hidden_size = 4, 6
num_layers, num_classes = 1, 2

model = BiRNN(input_size, hidden_size, num_layers, num_classes)
criterion = BiRNNLoss()

x = torch.randn(batch_size, max_len, input_size)
seq_lengths = torch.tensor([5, 3, 4])

# Build the 0/1 padding mask directly from the true lengths.
masks = torch.stack(
    [torch.arange(max_len) < length for length in seq_lengths]
).float()

out = model(x, seq_lengths, masks)
print("出力 shape:", out.shape)

# Random targets, only meant to exercise the loss computation.
targets = torch.randint(0, num_classes, (batch_size, max_len))
loss = criterion(out, targets, masks)

print("loss:", loss.item())

 

 




以上の内容はhttps://htn20190109.hatenablog.com/entry/2025/11/16/231046より取得しました。
このページはhttp://font.textar.tv/のウェブフォントを使用してます

不具合報告/要望等はこちらへお願いします。
モバイルやる夫Viewer Ver0.14