import torch
import torch.nn as nn
import math
# ==============================
# Scaled Dot-Product Attention
# ==============================
class ScaledDotProductAttention(nn.Module):
    def __init__(self, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, q, k, v):
        # Attention(Q, K, V) = softmax(QK^T / sqrt(d_k * scale_factor)) V;
        # with scale_factor=1.0 this is the standard 1/sqrt(d_k) scaling.
        dk = k.size(-1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(dk * self.scale_factor)
        attention = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention, v)
        return output
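# --- Illustrative shape check (my addition, not part of the original listing) ---
# With scale_factor=1.0 the module computes plain softmax(QK^T / sqrt(d_k)) V,
# so for q, k, v of shape (batch, seq_len, d_k) the output keeps that shape.
# The dimensions below are arbitrary example values.
_attn = ScaledDotProductAttention(scale_factor=1.0)
_q = torch.randn(2, 3, 8)
_k = torch.randn(2, 3, 8)
_v = torch.randn(2, 3, 8)
assert _attn(_q, _k, _v).shape == (2, 3, 8)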
# ==============================
# Multi-Head Attention
# ==============================
class MultiheadAttention(nn.Module):
    def __init__(self, num_heads, input_size, head_size):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.q_linear = nn.Linear(input_size, num_heads * head_size)
        self.k_linear = nn.Linear(input_size, num_heads * head_size)
        self.v_linear = nn.Linear(input_size, num_heads * head_size)
        self.attention = ScaledDotProductAttention(scale_factor=1.0)
        self.fc = nn.Linear(num_heads * head_size, input_size)

    def forward(self, q, k, v):
        bs = q.size(0)
        # Project Q, K, V and split into heads: (bs, num_heads, seq_len, head_size)
        q = self.q_linear(q).view(bs, -1, self.num_heads, self.head_size).transpose(1, 2)
        k = self.k_linear(k).view(bs, -1, self.num_heads, self.head_size).transpose(1, 2)
        v = self.v_linear(v).view(bs, -1, self.num_heads, self.head_size).transpose(1, 2)
        # Scaled dot-product attention, applied to every head in parallel
        attention_scores = self.attention(q, k, v)
        # Concatenate the heads: (bs, seq_len, num_heads * head_size)
        attention = attention_scores.transpose(1, 2).contiguous().view(bs, -1, self.num_heads * self.head_size)
        output = self.fc(attention)
        return output
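# --- Illustrative shape check (added here for clarity, not in the original) ---
# MultiheadAttention projects the input to num_heads * head_size features,
# runs attention per head, concatenates the heads, and maps back to input_size,
# so the output shape equals the input shape. Example dimensions are arbitrary.
_mha = MultiheadAttention(num_heads=2, input_size=8, head_size=4)
_x = torch.randn(2, 3, 8)
assert _mha(_x, _x, _x).shape == (2, 3, 8)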
# ==============================
# Positional Encoding
# ==============================
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
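# --- Illustrative check of the encoding table (added; not in the original) ---
# pe[0, pos, 2i]   = sin(pos / 10000^(2i/d_model))
# pe[0, pos, 2i+1] = cos(pos / 10000^(2i/d_model))
# The buffer broadcasts over the batch dimension, so any (batch, seq_len, d_model)
# input with seq_len <= max_len works. Example dimensions are arbitrary.
_pe = PositionalEncoding(d_model=8, dropout=0.0, max_len=16)
_out = _pe(torch.zeros(1, 4, 8))
assert _out.shape == (1, 4, 8)
assert torch.allclose(_out[0, 0, 0::2], torch.zeros(4))  # sin(0) = 0 at position 0
assert torch.allclose(_out[0, 0, 1::2], torch.ones(4))   # cos(0) = 1 at position 0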
# ==============================
# Transformer Block
# ==============================
class TransformerBlock(nn.Module):
    def __init__(self, input_size, num_heads, head_size, dropout_rate=0.1):
        super().__init__()
        self.multihead_attention = MultiheadAttention(num_heads, input_size, head_size)
        self.layer_norm1 = nn.LayerNorm(input_size)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.feedforward = nn.Sequential(
            nn.Linear(input_size, 4 * input_size),
            nn.ReLU(),
            nn.Linear(4 * input_size, input_size)
        )
        self.layer_norm2 = nn.LayerNorm(input_size)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x):
        # Self-attention sub-layer with residual connection and post-layer-norm
        attn_output = self.multihead_attention(x, x, x)
        x = self.layer_norm1(x + self.dropout1(attn_output))
        # Position-wise feed-forward sub-layer with residual connection
        ff_output = self.feedforward(x)
        x = self.layer_norm2(x + self.dropout2(ff_output))
        return x
# ==============================
# Example usage
# ==============================
batch_size = 2
seq_len = 5
input_size = 16
num_heads = 4
head_size = 4

model = TransformerBlock(input_size, num_heads, head_size)
pos_enc = PositionalEncoding(d_model=input_size)

x = torch.randn(batch_size, seq_len, input_size)
x = pos_enc(x)
out = model(x)
print("Output shape:", out.shape)  # torch.Size([2, 5, 16])