Tensorflow2.10使用BERT從文本中抽取答案實現(xiàn)詳解
前言
本文詳細介紹了用 tensorflow-gpu 2.10 版本實現(xiàn)一個簡單的從文本中抽取答案的過程。
數(shù)據(jù)準(zhǔn)備
這里主要用于準(zhǔn)備訓(xùn)練和評估 SQuAD(Standford Question Answering Dataset)數(shù)據(jù)集的 Bert 模型所需的數(shù)據(jù)和工具。
首先,通過導(dǎo)入相關(guān)庫,包括 os、re、json、string、numpy、tensorflow、tokenizers 和 transformers,為后續(xù)處理數(shù)據(jù)和構(gòu)建模型做好準(zhǔn)備。 然后,設(shè)置了最大長度為384 ,并創(chuàng)建了一個 BertConfig 對象。接著從 Hugging Face 模型庫中下載預(yù)訓(xùn)練模型 bert-base-uncased 模型的 tokenizer ,并將其保存到同一目錄下的名叫 bert_base_uncased 文件夾中。 當(dāng)下載結(jié)束之后,使用 BertWordPieceTokenizer 從已下載的文件夾中夾在 tokenizer 的詞匯表從而創(chuàng)建分詞器 tokenizer 。
剩下的部分就是從指定的 URL 下載訓(xùn)練和驗證集,并使用 keras.utils.get_file() 將它們保存到本地,一般存放在 “用戶目錄.keras\datasets”下 ,以便后續(xù)的數(shù)據(jù)預(yù)處理和模型訓(xùn)練。
import os
import re
import json
import string
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tokenizers import BertWordPieceTokenizer
from transformers import BertTokenizer, TFBertModel, BertConfig
max_len = 384
configuration = BertConfig()
slow_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
save_path = "bert_base_uncased/"
if not os.path.exists(save_path):
os.makedirs(save_path)
slow_tokenizer.save_pretrained(save_path)
tokenizer = BertWordPieceTokenizer("bert_base_uncased/vocab.txt", lowercase=True)
train_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json"
train_path = keras.utils.get_file("train.json", train_data_url)
eval_data_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json"
eval_path = keras.utils.get_file("eval.json", eval_data_url)
打?。?/p>
Downloading data from https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json
30288272/30288272 [==============================] - 131s 4us/step
Downloading data from https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json
4854279/4854279 [==============================] - 20s 4us/step
模型輸入、輸出處理
這里定義了一個名為 SquadExample 的類,用于表示一個 SQuAD 數(shù)據(jù)集中的問題和對應(yīng)的上下文片段、答案位置等信息。
該類的構(gòu)造函數(shù) __init__() 接受五個參數(shù):問題(question)、上下文(context)、答案起始字符索引(start_char_idx)、答案文本(answer_text) 和所有答案列表 (all_answers) 。
類還包括一個名為 preprocess() 的方法,用于對每個 SQuAD 樣本進行預(yù)處理,首先對context 、question 和 answer 進行預(yù)處理,并計算出答案的結(jié)束位置 end_char_idx 。接下來,根據(jù) start_char_idx 和 end_char_idx 在 context 的位置,構(gòu)建了一個表示 context 中哪些字符屬于 answer 的列表 is_char_in_ans 。然后,使用 tokenizer 對 context 進行編碼,得到 tokenized_context。
接著,通過比較 answer 的字符位置和 context 中每個標(biāo)記的字符位置,得到了包含答案的標(biāo)記的索引列表 ans_token_idx 。如果 answer 未在 context 中找到,則將 skip 屬性設(shè)置為 True ,并直接返回空結(jié)果。
最后,將 context 和 question 的序列拼接成輸入序列 input_ids ,并根據(jù)兩個句子的不同生成了同樣長度的序列 token_type_ids 以及與 input_ids 同樣長度的 attention_mask 。然后對這三個序列進行了 padding 操作。
class SquadExample:
def __init__(self, question, context, start_char_idx, answer_text, all_answers):
self.question = question
self.context = context
self.start_char_idx = start_char_idx
self.answer_text = answer_text
self.all_answers = all_answers
self.skip = False
def preprocess(self):
context = self.context
question = self.question
answer_text = self.answer_text
start_char_idx = self.start_char_idx
context = " ".join(str(context).split())
question = " ".join(str(question).split())
answer = " ".join(str(answer_text).split())
end_char_idx = start_char_idx + len(answer)
if end_char_idx >= len(context):
self.skip = True
return
is_char_in_ans = [0] * len(context)
for idx in range(start_char_idx, end_char_idx):
is_char_in_ans[idx] = 1
tokenized_context = tokenizer.encode(context)
ans_token_idx = []
for idx, (start, end) in enumerate(tokenized_context.offsets):
if sum(is_char_in_ans[start:end]) > 0:
ans_token_idx.append(idx)
if len(ans_token_idx) == 0:
self.skip = True
return
start_token_idx = ans_token_idx[0]
end_token_idx = ans_token_idx[-1]
tokenized_question = tokenizer.encode(question)
input_ids = tokenized_context.ids + tokenized_question.ids[1:]
token_type_ids = [0] * len(tokenized_context.ids) + [1] * len(tokenized_question.ids[1:])
attention_mask = [1] * len(input_ids)
padding_length = max_len - len(input_ids)
if padding_length > 0:
input_ids = input_ids + ([0] * padding_length)
attention_mask = attention_mask + ([0] * padding_length)
token_type_ids = token_type_ids + ([0] * padding_length)
elif padding_length < 0:
self.skip = True
return
self.input_ids = input_ids
self.token_type_ids = token_type_ids
self.attention_mask = attention_mask
self.start_token_idx = start_token_idx
self.end_token_idx = end_token_idx
self.context_token_to_char = tokenized_context.offsets
這里的兩個函數(shù)用于準(zhǔn)備數(shù)據(jù)以訓(xùn)練一個使用 BERT 結(jié)構(gòu)的問答模型。
第一個函數(shù) create_squad_examples 接受一個 JSON 文件的原始數(shù)據(jù),將里面的每條數(shù)據(jù)都變成 SquadExample 類所定義的輸入格式。
第二個函數(shù) create_inputs_targets 將 SquadExample 對象列表轉(zhuǎn)換為模型的輸入和目標(biāo)。這個函數(shù)返回兩個列表,一個是模型的輸入,包含了 input_ids 、token_type_ids 、 attention_mask ,另一個是模型的目標(biāo),包含了 start_token_idx 、end_token_idx。
def create_squad_examples(raw_data):
squad_examples = []
for item in raw_data["data"]:
for para in item["paragraphs"]:
context = para["context"]
for qa in para["qas"]:
question = qa["question"]
answer_text = qa["answers"][0]["text"]
all_answers = [_["text"] for _ in qa["answers"]]
start_char_idx = qa["answers"][0]["answer_start"]
squad_eg = SquadExample(question, context, start_char_idx, answer_text, all_answers)
squad_eg.preprocess()
squad_examples.append(squad_eg)
return squad_examples
def create_inputs_targets(squad_examples):
dataset_dict = {
"input_ids": [],
"token_type_ids": [],
"attention_mask": [],
"start_token_idx": [],
"end_token_idx": [],
}
for item in squad_examples:
if item.skip == False:
for key in dataset_dict:
dataset_dict[key].append(getattr(item, key))
for key in dataset_dict:
dataset_dict[key] = np.array(dataset_dict[key])
x = [ dataset_dict["input_ids"], dataset_dict["token_type_ids"], dataset_dict["attention_mask"], ]
y = [dataset_dict["start_token_idx"], dataset_dict["end_token_idx"]]
return x, y
這里主要讀取了 SQuAD 訓(xùn)練集和驗證集的 JSON 文件,并使用create_squad_examples 函數(shù)將原始數(shù)據(jù)轉(zhuǎn)換為 SquadExample 對象列表。然后使用 create_inputs_targets 函數(shù)將這些 SquadExample 對象列表轉(zhuǎn)換為模型輸入和目標(biāo)輸出。最后輸出打印了已創(chuàng)建的訓(xùn)練數(shù)據(jù)樣本數(shù)和評估數(shù)據(jù)樣本數(shù)。
with open(train_path) as f:
raw_train_data = json.load(f)
with open(eval_path) as f:
raw_eval_data = json.load(f)
train_squad_examplesa = create_squad_examples(raw_train_data)
x_train, y_train = create_inputs_targets(train_squad_examples)
print(f"{len(train_squad_examples)} training points created.")
eval_squad_examples = create_squad_examples(raw_eval_data)
x_eval, y_eval = create_inputs_targets(eval_squad_examples)
print(f"{len(eval_squad_examples)} evaluation points created.")
打?。?/p>
87599 training points created.
10570 evaluation points created.
模型搭建
這里定義了一個基于 BERT 的問答模型。在 create_model() 函數(shù)中,首先使用 TFBertModel.from_pretrained() 方法加載預(yù)訓(xùn)練的 BERT 模型。然后創(chuàng)建了三個輸入層(input_ids、token_type_ids 和 attention_mask),每個輸入層的形狀都是(max_len,) 。這些輸入層用于接收模型的輸入數(shù)據(jù)。
接下來使用 encoder() 方法對輸入進行編碼得到 embedding ,然后分別對這些向量表示進行全連接層的操作,得到一個 start_logits 和一個 end_logits 。接著分別對這兩個向量進行扁平化操作,并將其傳遞到激活函數(shù) softmax 中,得到一個 start_probs 向量和一個 end_probs 向量。
最后,將這三個輸入層和這兩個輸出層傳遞給 keras.Model() 函數(shù),構(gòu)建出一個模型。此模型使用 SparseCategoricalCrossentropy 損失函數(shù)進行編譯,并使用 Adam 優(yōu)化器進行訓(xùn)練。
def create_model():
encoder = TFBertModel.from_pretrained("bert-base-uncased")
input_ids = layers.Input(shape=(max_len,), dtype=tf.int32)
token_type_ids = layers.Input(shape=(max_len,), dtype=tf.int32)
attention_mask = layers.Input(shape=(max_len,), dtype=tf.int32)
embedding = encoder(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)[0]
start_logits = layers.Dense(1, name="start_logit", use_bias=False)(embedding)
start_logits = layers.Flatten()(start_logits)
end_logits = layers.Dense(1, name="end_logit", use_bias=False)(embedding)
end_logits = layers.Flatten()(end_logits)
start_probs = layers.Activation(keras.activations.softmax)(start_logits)
end_probs = layers.Activation(keras.activations.softmax)(end_logits)
model = keras.Model( inputs=[input_ids, token_type_ids, attention_mask], outputs=[start_probs, end_probs],)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
optimizer = keras.optimizers.Adam(lr=5e-5)
model.compile(optimizer=optimizer, loss=[loss, loss])
return model
這里主要是展示了一下模型的架構(gòu),可以看到所有的參數(shù)都可以訓(xùn)練,并且主要調(diào)整的部分都幾乎是 bert 中的參數(shù)。
model = create_model() model.summary()
打?。?/p>
Model: "model_1"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_4 (InputLayer) [(None, 384)] 0 []
input_6 (InputLayer) [(None, 384)] 0 []
input_5 (InputLayer) [(None, 384)] 0 []
tf_bert_model_1 (TFBertModel) TFBaseModelOutputWi 109482240 ['input_4[0][0]',
thPoolingAndCrossAt 'input_6[0][0]',
tentions(last_hidde 'input_5[0][0]']
n_state=(None, 384,
768),
pooler_output=(Non
e, 768),
past_key_values=No
ne, hidden_states=N
one, attentions=Non
e, cross_attentions
=None)
start_logit (Dense) (None, 384, 1) 768 ['tf_bert_model_1[0][0]']
end_logit (Dense) (None, 384, 1) 768 ['tf_bert_model_1[0][0]']
flatten_2 (Flatten) (None, 384) 0 ['start_logit[0][0]']
flatten_3 (Flatten) (None, 384) 0 ['end_logit[0][0]']
activation_2 (Activation) (None, 384) 0 ['flatten_2[0][0]']
activation_3 (Activation) (None, 384) 0 ['flatten_3[0][0]']
==================================================================================================
Total params: 109,483,776
Trainable params: 109,483,776
Non-trainable params: 0
自定義驗證回調(diào)函數(shù)
這里定義了一個回調(diào)函數(shù) ExactMatch , 有一個初始化方法 __init__ ,接收驗證集的輸入和目標(biāo) x_eval 和 y_eval 。該類還實現(xiàn)了 on_epoch_end 方法,在每個 epoch 結(jié)束時調(diào)用,計算模型的預(yù)測值,并計算精確匹配分?jǐn)?shù)。
具體地,on_epoch_end 方法首先使用模型對 x_eval 進行預(yù)測,得到預(yù)測的起始位置 pred_start 和結(jié)束位置 pred_end ,并進一步找到對應(yīng)的預(yù)測答案和正確答案標(biāo)準(zhǔn)化為 normalized_pred_ans 和 normalized_true_ans ,如果前者存在于后者,則說明該樣本被正確地回答,最終將精確匹配分?jǐn)?shù)打印出來。
def normalize_text(text):
text = text.lower()
exclude = set(string.punctuation)
text = "".join(ch for ch in text if ch not in exclude)
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
text = re.sub(regex, " ", text)
text = " ".join(text.split())
return text
class ExactMatch(keras.callbacks.Callback):
def __init__(self, x_eval, y_eval):
self.x_eval = x_eval
self.y_eval = y_eval
def on_epoch_end(self, epoch, logs=None):
pred_start, pred_end = self.model.predict(self.x_eval)
count = 0
eval_examples_no_skip = [_ for _ in eval_squad_examples if _.skip == False]
for idx, (start, end) in enumerate(zip(pred_start, pred_end)):
squad_eg = eval_examples_no_skip[idx]
offsets = squad_eg.context_token_to_char
start = np.argmax(start)
end = np.argmax(end)
if start >= len(offsets):
continue
pred_char_start = offsets[start][0]
if end < len(offsets):
pred_char_end = offsets[end][1]
pred_ans = squad_eg.context[pred_char_start:pred_char_end]
else:
pred_ans = squad_eg.context[pred_char_start:]
normalized_pred_ans = normalize_text(pred_ans)
normalized_true_ans = [normalize_text(_) for _ in squad_eg.all_answers]
if normalized_pred_ans in normalized_true_ans:
count += 1
acc = count / len(self.y_eval[0])
print(f"\nepoch={epoch+1}, exact match score={acc:.2f}")
模型訓(xùn)練和驗證
訓(xùn)練模型,并使用驗證集對模型的性能進行測試。這里的 epoch 只設(shè)置了 1 ,如果數(shù)值增大效果會更好。
exact_match_callback = ExactMatch(x_eval, y_eval) model.fit( x_train, y_train, epochs=1, verbose=2, batch_size=16, callbacks=[exact_match_callback],)
打?。?/p>
23/323 [==============================] - 47s 139ms/step
epoch=1, exact match score=0.77
5384/5384 - 1268s - loss: 2.4677 - activation_2_loss: 1.2876 - activation_3_loss: 1.1800 - 1268s/epoch - 236ms/step
以上就是Tensorflow2.10使用BERT從文本中抽取答案實現(xiàn)詳解的詳細內(nèi)容,更多關(guān)于Tensorflow BERT文本抽取的資料請關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
Python學(xué)習(xí)之循環(huán)方法詳解
循環(huán)是有著周而復(fù)始的運動或變化的規(guī)律;在 Python 中,循環(huán)的操作也叫做 遍歷。與現(xiàn)實中一樣,Python 中也同樣存在著無限循環(huán)的方法與有限循環(huán)的方法。本文將通過示例詳細講解Python中的循環(huán)方法,需要的可以參考一下2022-03-03
Python實現(xiàn)的插入排序算法原理與用法實例分析
這篇文章主要介紹了Python實現(xiàn)的插入排序算法原理與用法,簡單描述了插入排序的原理,并結(jié)合實例形式分析了Python實現(xiàn)插入排序的相關(guān)操作技巧,需要的朋友可以參考下2017-11-11
python中dict字典的查詢鍵值對 遍歷 排序 創(chuàng)建 訪問 更新 刪除基礎(chǔ)操作方法
字典的每個鍵值(key=>value)對用冒號(:)分割,每個對之間用逗號(,)分割,整個字典包括在花括號({})中,本文講述了python中dict字典的查詢鍵值對 遍歷 排序 創(chuàng)建 訪問 更新 刪除基礎(chǔ)操作方法2018-09-09
基于Python實現(xiàn)一個自動關(guān)機程序并打包成exe文件
這篇文章主要介紹了通過Python創(chuàng)建一個可以自動關(guān)機的小程序,并打包成exe文件。文中的示例代碼講解詳細,對我們學(xué)習(xí)Python有一定的幫助,感興趣的同學(xué)可以了解一下2021-12-12

