import torch
import torch.nn as nn
class BiRNN(nn.Module):
    """Bidirectional vanilla RNN that emits class logits for every timestep.

    The forward and backward hidden states are summed (not concatenated)
    before the final linear layer, so the classifier sees ``hidden_size``
    features per timestep.

    Args:
        input_size: Number of features per input timestep.
        hidden_size: Number of hidden units per direction.
        num_layers: Number of stacked RNN layers.
        num_classes: Number of output classes per timestep.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            bidirectional=True,
        )
        # Directions are summed in forward(), hence hidden_size inputs rather
        # than 2 * hidden_size.
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(
        self,
        x: torch.Tensor,
        seq_lengths: torch.Tensor,
        masks: torch.Tensor,
    ) -> torch.Tensor:
        """Run the RNN over a padded batch and return per-timestep logits.

        Args:
            x: Padded inputs, indexed as ``(batch, time, input_size)``
                (the RNN is built with ``batch_first=True``).
            seq_lengths: 1-D tensor with the true length of each sequence,
                in the same batch order as ``x``. Need not be sorted.
            masks: ``(batch, time)`` tensor, non-zero at real timesteps and
                zero at padding, in the same batch order as ``x``.

        Returns:
            Tensor of shape ``(batch, time, num_classes)``. At masked
            positions the RNN features are zeroed before the linear layer,
            so the value there is just the linear layer's bias.
        """
        # enforce_sorted=False lets PyTorch sort by length internally and
        # restore the original batch order on unpacking, so x, masks and the
        # output all stay aligned without any manual permutation.
        x_pack = nn.utils.rnn.pack_padded_sequence(
            x,
            seq_lengths.cpu(),  # lengths must live on the CPU
            batch_first=True,
            enforce_sorted=False,
        )

        # Allocate the initial state with x's device and dtype so the module
        # also works when it has been moved off the CPU.
        h0 = x.new_zeros(self.num_layers * 2, x.size(0), self.hidden_size)
        out_pack, _ = self.rnn(x_pack, h0)

        # total_length pads the output back to the input's time dimension;
        # otherwise it would be truncated to max(seq_lengths) and no longer
        # line up with masks.
        out, _ = nn.utils.rnn.pad_packed_sequence(
            out_pack,
            batch_first=True,
            total_length=x.size(1),
        )

        out = out * masks.unsqueeze(-1).to(out.dtype)
        print("RNN 出力:", out.shape)

        # Last dimension is [forward | backward]; merge the two by summation.
        out = out[:, :, :self.hidden_size] + out[:, :, self.hidden_size:]
        print("双方向合成後:", out.shape)

        out = self.fc(out)
        return out
class BiRNNLoss(nn.Module):
    """Cross-entropy averaged over the unmasked (non-padding) timesteps only."""

    def __init__(self):
        super().__init__()
        # reduction="none" keeps one loss value per element so padding can be
        # excluded before averaging; the default "mean" would already have
        # averaged over padded positions as well.
        self.loss_fn = nn.CrossEntropyLoss(reduction="none")

    def forward(
        self,
        outputs: torch.Tensor,
        targets: torch.Tensor,
        masks: torch.Tensor,
    ) -> torch.Tensor:
        """Return the masked mean cross-entropy.

        Args:
            outputs: Logits whose last dimension is the number of classes.
            targets: Integer class indices, one per logit vector.
            masks: Weights with one entry per target; zero marks padding.

        Returns:
            Scalar tensor ``sum(loss * mask) / sum(mask)``, or 0 when every
            position is masked out.
        """
        # Flatten to one row per timestep. reshape (unlike view) also accepts
        # non-contiguous tensors.
        targets = targets.reshape(-1)
        outputs = outputs.reshape(-1, outputs.size(-1))

        # Per-element losses, shape [N].
        losses = self.loss_fn(outputs, targets)
        masks = masks.reshape(-1).to(losses.dtype)

        # Zero the loss at padded positions and average over the rest. The
        # tiny lower bound on the denominator turns 0/0 into 0 for an
        # all-padding batch without affecting any non-empty mask.
        denom = masks.sum().clamp_min(torch.finfo(losses.dtype).tiny)
        loss = (losses * masks).sum() / denom
        return loss
# ========================
# Sanity check
# ========================
# Toy problem dimensions.
batch_size, max_len = 3, 5
input_size, hidden_size = 4, 6
num_layers, num_classes = 1, 2

model = BiRNN(input_size, hidden_size, num_layers, num_classes)
criterion = BiRNNLoss()

x = torch.randn(batch_size, max_len, input_size)
seq_lengths = torch.tensor([5, 3, 4])

# 1.0 at real timesteps, 0.0 at padding: position t of sample i is real
# exactly when t < seq_lengths[i].
masks = (torch.arange(max_len).unsqueeze(0) < seq_lengths.unsqueeze(1)).float()

out = model(x, seq_lengths, masks)
print("出力 shape:", out.shape)

# Dummy targets
targets = torch.randint(0, num_classes, (batch_size, max_len))
loss = criterion(out, targets, masks)
print("loss:", loss.item())