# References:
# https://qiita.com/age884/items/7b8d5c583e59e755aaf0
# https://tfull.hatenablog.jp/entry/2020/10/07/122132
from transformers import BertTokenizer, BertForMaskedLM, BertForNextSentencePrediction
import torch
# Single checkpoint shared by the tokenizer and both model heads. Keeping it in
# one place guarantees the tokenizer's vocabulary matches the models it feeds.
MODEL_NAME = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
# The same pretrained BERT weights loaded with two different heads
# (from_pretrained returns the models in eval mode, so dropout is off below).
mlm_model = BertForMaskedLM.from_pretrained(MODEL_NAME)  # masked-token prediction
nsp_model = BertForNextSentencePrediction.from_pretrained(MODEL_NAME)  # sentence-pair classification
# Tokenizer round-trip demo: turn the same WordPiece tokens into IDs two ways
# and map each ID list back to tokens to compare them.
#   - tokenizer.encode also adds the model's special tokens ([CLS] ... [SEP]).
#   - tokenizer.convert_tokens_to_ids is a plain vocabulary lookup.
text = "Hello, BERT! How are you?"
tokens = tokenizer.tokenize(text)
print(f"Tokens: {tokens}")
for to_ids in (tokenizer.encode, tokenizer.convert_tokens_to_ids):
    token_ids = to_ids(tokens)
    print(f"Token_IDs: {token_ids}")
    decoded_tokens = tokenizer.convert_ids_to_tokens(token_ids)
    print(f"Decoded Tokens: {decoded_tokens}")
# Masked Language Modeling (MLM): hide one token behind [MASK] and let BERT
# predict it from the surrounding context.
text_a = "I enjoy walking with my cute dog."
tokens = tokenizer.tokenize(text_a)
print("Tokens:", tokens)
masked_index = 2  # index into `tokens` of the token to hide
tokens[masked_index] = tokenizer.mask_token
print(f"Masked Tokens: {tokens}")
# `tokens` are already WordPiece pieces, so convert them straight to IDs.
# Passing them back through tokenizer(..., is_split_into_words=True) would
# re-tokenize every piece (e.g. a "##ing" sub-word becomes '#', '#', 'ing'),
# corrupting the input and shifting every later position.
# prepare_for_model adds [CLS]/[SEP] plus token_type_ids and attention_mask.
encoded_input_mlm = tokenizer.prepare_for_model(
    tokenizer.convert_tokens_to_ids(tokens),
    return_tensors='pt',
    prepend_batch_axis=True,  # batch of one, as the model expects
)
with torch.no_grad():
    output = mlm_model(**encoded_input_mlm)
predictions = output.logits  # (batch_size, sequence_length, vocab_size)
# Find [MASK] in the actual model input rather than assuming it sits at
# masked_index + 1 (i.e. exactly one special token before it).
mask_position = encoded_input_mlm['input_ids'][0].tolist().index(tokenizer.mask_token_id)
predicted_token_id = torch.argmax(predictions[0, mask_position]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_token_id])[0]
print(f"Predicted token for [MASK]: {predicted_token}")
# Next Sentence Prediction (NSP): ask BERT whether text_b follows text_a.
text_a = "I enjoy walking with my cute dog."
text_b = "He's very playful."
# Alternative pair to try, presumably unrelated to text_a:
# text_b = "This is a pen."
encoded_input_nsp = tokenizer(
    text_a,
    text_b,
    max_length=512,
    truncation=True,
    padding=True,
    return_tensors='pt',
)
with torch.no_grad():
    output = nsp_model(**encoded_input_nsp)
logits = output.logits
# NSP label convention: 0 -> IsNextSentence, 1 -> NotNextSentence.
predicted_label = logits.argmax(dim=1).item()
is_next = predicted_label == 0
print(f"output: {output}")
print(f"is_next: {is_next}")