.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    #@save
    def _pad_bert_inputs(examples, max_len, vocab):
        max_num_mlm_preds = round(max_len * 0.15)
        all_token_ids, all_segments, valid_lens = [], [], []
        all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], []
        nsp_labels = []
        for (token_ids, pred_positions, mlm_pred_label_ids, segments,
             is_next) in examples:
            all_token_ids.append(torch.tensor(token_ids + [vocab['<pad>']] * (
                max_len - len(token_ids)), dtype=torch.long))
            all_segments.append(torch.tensor(segments + [0] * (
                max_len - len(segments)), dtype=torch.long))
            # `valid_lens` excludes count of '<pad>' tokens
            valid_lens.append(torch.tensor(len(token_ids), dtype=torch.float32))
            all_pred_positions.append(torch.tensor(pred_positions + [0] * (
                max_num_mlm_preds - len(pred_positions)), dtype=torch.long))
            # Predictions of padded tokens will be filtered out in the loss via
            # multiplication of 0 weights
            all_mlm_weights.append(
                torch.tensor([1.0] * len(mlm_pred_label_ids) + [0.0] * (
                    max_num_mlm_preds - len(pred_positions)),
                    dtype=torch.float32))
            all_mlm_labels.append(torch.tensor(mlm_pred_label_ids + [0] * (
                max_num_mlm_preds - len(mlm_pred_label_ids)), dtype=torch.long))
            nsp_labels.append(torch.tensor(is_next, dtype=torch.long))
        return (all_token_ids, all_segments, valid_lens, all_pred_positions,
                all_mlm_weights, all_mlm_labels, nsp_labels)
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    #@save
    def _pad_bert_inputs(examples, max_len, vocab):
        max_num_mlm_preds = round(max_len * 0.15)
        all_token_ids, all_segments, valid_lens = [], [], []
        all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], []
        nsp_labels = []
        for (token_ids, pred_positions, mlm_pred_label_ids, segments,
             is_next) in examples:
            all_token_ids.append(np.array(token_ids + [vocab['<pad>']] * (
                max_len - len(token_ids)), dtype='int32'))
            all_segments.append(np.array(segments + [0] * (
                max_len - len(segments)), dtype='int32'))
            # `valid_lens` excludes count of '<pad>' tokens
            valid_lens.append(np.array(len(token_ids), dtype='float32'))
            all_pred_positions.append(np.array(pred_positions + [0] * (
                max_num_mlm_preds - len(pred_positions)), dtype='int32'))
            # Predictions of padded tokens will be filtered out in the loss via
            # multiplication of 0 weights
            all_mlm_weights.append(
                np.array([1.0] * len(mlm_pred_label_ids) + [0.0] * (
                    max_num_mlm_preds - len(pred_positions)), dtype='float32'))
            all_mlm_labels.append(np.array(mlm_pred_label_ids + [0] * (
                max_num_mlm_preds - len(mlm_pred_label_ids)), dtype='int32'))
            nsp_labels.append(np.array(is_next))
        return (all_token_ids, all_segments, valid_lens, all_pred_positions,
                all_mlm_weights, all_mlm_labels, nsp_labels)
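To make the padding behavior concrete, the following minimal sketch (using the
PyTorch version of the function and a hypothetical toy example whose token
ids, prediction positions, and labels are made up for illustration) pads a
single example to ``max_len=8``, so every masked language modeling field is
padded to ``round(8 * 0.15) = 1`` position.

.. code:: python

    # Each example is (token_ids, pred_positions, mlm_pred_label_ids,
    # segments, is_next); the values below are hypothetical
    toy_vocab = d2l.Vocab([['the', 'cat']],
                          reserved_tokens=['<pad>', '<mask>', '<cls>', '<sep>'])
    toy_examples = [([2, 5, 6, 3], [2], [6], [0, 0, 0, 0], True)]
    (token_ids, segments, valid_lens, pred_positions, mlm_weights,
     mlm_labels, nsp_labels) = _pad_bert_inputs(toy_examples, max_len=8,
                                                vocab=toy_vocab)
    # The token ids are padded with the index of '<pad>' up to length 8
    token_ids[0], valid_lens[0], mlm_weights[0]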
Putting together the helper functions for generating training examples of the
two pretraining tasks and the helper function for padding inputs, we customize
the following ``_WikiTextDataset`` class as the WikiText-2 dataset for
pretraining BERT. By implementing the ``__getitem__`` function, we can
arbitrarily access the pretraining (masked language modeling and next sentence
prediction) examples generated from a pair of sentences from the WikiText-2
corpus.

The original BERT model uses WordPiece embeddings whose vocabulary size is
30000 :cite:`Wu.Schuster.Chen.ea.2016`. The tokenization method of WordPiece
is a slight modification of the original byte pair encoding algorithm in
:numref:`subsec_Byte_Pair_Encoding`. For simplicity, we use the
``d2l.tokenize`` function for tokenization. Infrequent tokens that appear
fewer than five times are filtered out.
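As a quick illustration of this tokenization and frequency filtering (on a
tiny made-up paragraph rather than WikiText-2), ``d2l.tokenize`` splits each
sentence string into word tokens, and ``d2l.Vocab`` with ``min_freq=5`` maps
any token that appears fewer than five times to the unknown token:

.. code:: python

    # A tiny hypothetical paragraph: a list of sentence strings
    toy_paragraph = ['the cat sat on the mat by the door',
                     'the cat saw the dog']
    toy_tokens = d2l.tokenize(toy_paragraph, token='word')
    toy_vocab = d2l.Vocab(toy_tokens, min_freq=5,
                          reserved_tokens=['<pad>', '<mask>', '<cls>', '<sep>'])
    # 'the' appears five times and is kept; 'cat' appears only twice,
    # so it is mapped to the index of '<unk>'
    toy_vocab['the'], toy_vocab['cat'] == toy_vocab['<unk>']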
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    #@save
    class _WikiTextDataset(torch.utils.data.Dataset):
        def __init__(self, paragraphs, max_len):
            # Input `paragraphs[i]` is a list of sentence strings representing
            # a paragraph; while output `paragraphs[i]` is a list of sentences
            # representing a paragraph, where each sentence is a list of tokens
            paragraphs = [d2l.tokenize(
                paragraph, token='word') for paragraph in paragraphs]
            sentences = [sentence for paragraph in paragraphs
                         for sentence in paragraph]
            self.vocab = d2l.Vocab(sentences, min_freq=5, reserved_tokens=[
                '<pad>', '<mask>', '<cls>', '<sep>'])
            # Get data for the next sentence prediction task
            examples = []
            for paragraph in paragraphs:
                examples.extend(_get_nsp_data_from_paragraph(
                    paragraph, paragraphs, self.vocab, max_len))
            # Get data for the masked language model task
            examples = [(_get_mlm_data_from_tokens(tokens, self.vocab)
                         + (segments, is_next))
                        for tokens, segments, is_next in examples]
            # Pad inputs
            (self.all_token_ids, self.all_segments, self.valid_lens,
             self.all_pred_positions, self.all_mlm_weights,
             self.all_mlm_labels, self.nsp_labels) = _pad_bert_inputs(
                examples, max_len, self.vocab)

        def __getitem__(self, idx):
            return (self.all_token_ids[idx], self.all_segments[idx],
                    self.valid_lens[idx], self.all_pred_positions[idx],
                    self.all_mlm_weights[idx], self.all_mlm_labels[idx],
                    self.nsp_labels[idx])

        def __len__(self):
            return len(self.all_token_ids)
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    #@save
    class _WikiTextDataset(gluon.data.Dataset):
        def __init__(self, paragraphs, max_len):
            # Input `paragraphs[i]` is a list of sentence strings representing
            # a paragraph; while output `paragraphs[i]` is a list of sentences
            # representing a paragraph, where each sentence is a list of tokens
            paragraphs = [d2l.tokenize(
                paragraph, token='word') for paragraph in paragraphs]
            sentences = [sentence for paragraph in paragraphs
                         for sentence in paragraph]
            self.vocab = d2l.Vocab(sentences, min_freq=5, reserved_tokens=[
                '<pad>', '<mask>', '<cls>', '<sep>'])
            # Get data for the next sentence prediction task
            examples = []
            for paragraph in paragraphs:
                examples.extend(_get_nsp_data_from_paragraph(
                    paragraph, paragraphs, self.vocab, max_len))
            # Get data for the masked language model task
            examples = [(_get_mlm_data_from_tokens(tokens, self.vocab)
                         + (segments, is_next))
                        for tokens, segments, is_next in examples]
            # Pad inputs
            (self.all_token_ids, self.all_segments, self.valid_lens,
             self.all_pred_positions, self.all_mlm_weights,
             self.all_mlm_labels, self.nsp_labels) = _pad_bert_inputs(
                examples, max_len, self.vocab)

        def __getitem__(self, idx):
            return (self.all_token_ids[idx], self.all_segments[idx],
                    self.valid_lens[idx], self.all_pred_positions[idx],
                    self.all_mlm_weights[idx], self.all_mlm_labels[idx],
                    self.nsp_labels[idx])

        def __len__(self):
            return len(self.all_token_ids)
By using the ``_read_wiki`` function and the ``_WikiTextDataset`` class, we
define the following ``load_data_wiki`` function to download the WikiText-2
dataset and generate pretraining examples from it.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    #@save
    def load_data_wiki(batch_size, max_len):
        """Load the WikiText-2 dataset."""
        num_workers = d2l.get_dataloader_workers()
        data_dir = d2l.download_extract('wikitext-2', 'wikitext-2')
        paragraphs = _read_wiki(data_dir)
        train_set = _WikiTextDataset(paragraphs, max_len)
        train_iter = torch.utils.data.DataLoader(train_set, batch_size,
                                                 shuffle=True,
                                                 num_workers=num_workers)
        return train_iter, train_set.vocab
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    #@save
    def load_data_wiki(batch_size, max_len):
        """Load the WikiText-2 dataset."""
        num_workers = d2l.get_dataloader_workers()
        data_dir = d2l.download_extract('wikitext-2', 'wikitext-2')
        paragraphs = _read_wiki(data_dir)
        train_set = _WikiTextDataset(paragraphs, max_len)
        train_iter = gluon.data.DataLoader(train_set, batch_size, shuffle=True,
                                           num_workers=num_workers)
        return train_iter, train_set.vocab
Setting the batch size to 512 and the maximum length of a BERT input sequence
to 64, we print out the shapes of a minibatch of BERT pretraining examples.
Note that in each BERT input sequence, :math:`10`
(:math:`= \operatorname{round}(64 \times 0.15)`) positions are predicted for
the masked language modeling task.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    batch_size, max_len = 512, 64
    train_iter, vocab = load_data_wiki(batch_size, max_len)

    for (tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X,
         mlm_Y, nsp_y) in train_iter:
        print(tokens_X.shape, segments_X.shape, valid_lens_x.shape,
              pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape,
              nsp_y.shape)
        break

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Downloading ../data/wikitext-2-v1.zip from https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip...
    torch.Size([512, 64]) torch.Size([512, 64]) torch.Size([512]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512])
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    batch_size, max_len = 512, 64
    train_iter, vocab = load_data_wiki(batch_size, max_len)

    for (tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X,
         mlm_Y, nsp_y) in train_iter:
        print(tokens_X.shape, segments_X.shape, valid_lens_x.shape,
              pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape,
              nsp_y.shape)
        break

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    Downloading ../data/wikitext-2-v1.zip from https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip...
    [21:52:02] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
    (512, 64) (512, 64) (512,) (512, 10) (512, 10) (512, 10) (512,)
In the end, let’s take a look at the vocabulary size. Even after filtering out
infrequent tokens, it is still more than twice as large as that of the PTB
dataset.
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    len(vocab)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    20256
.. raw:: latex

   \diilbookstyleinputcell

.. code:: python

    len(vocab)

.. raw:: latex

   \diilbookstyleoutputcell

.. parsed-literal::
    :class: output

    20256
Summary
-------

- Compared with the PTB dataset, the WikiText-2 dataset retains the original
  punctuation, case, and numbers, and is more than twice as large.
- We can arbitrarily access the pretraining (masked language modeling and next
  sentence prediction) examples generated from a pair of sentences from the
  WikiText-2 corpus.
Exercises
---------

1. For simplicity, the period is used as the only delimiter for splitting
   sentences. Try other sentence splitting techniques, such as spaCy and
   NLTK. Take NLTK as an example. You need to install NLTK first:
   ``pip install nltk``. In the code, first ``import nltk``. Then, download
   the Punkt sentence tokenizer: ``nltk.download('punkt')``. To split
   sentences such as ``sentences = 'This is great ! Why not ?'``, invoking
   ``nltk.tokenize.sent_tokenize(sentences)`` will return a list of two
   sentence strings: ``['This is great !', 'Why not ?']`` (see the short
   sketch after the exercises).
2. What is the vocabulary size if we do not filter out any infrequent token?
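As a starting point for the first exercise, here is a minimal sketch of the
NLTK-based sentence splitting described above; replacing the ``' . '``-based
splitting inside ``_read_wiki`` with ``sent_tokenize`` is left as the
exercise.

.. code:: python

    import nltk

    nltk.download('punkt')  # the Punkt tokenizer used by sent_tokenize

    sentences = 'This is great ! Why not ?'
    # Splits the string into a list of sentence strings:
    # ['This is great !', 'Why not ?']
    nltk.tokenize.sent_tokenize(sentences)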