14.3. 用于预训练词嵌入的数据集
现在我们已经了解了word2vec模型的技术细节和大致的训练方法,让我们来看看它们的实现。具体地说,我们将以 14.1节的跳元模型和 14.2节的负采样为例。本节从用于预训练词嵌入模型的数据集开始:数据的原始格式将被转换为可以在训练期间迭代的小批量。

import math
import os
import random
from mxnet import gluon, np
from d2l import mxnet as d2l
import math
import os
import random
import torch
from d2l import torch as d2l
import warnings
from d2l import paddle as d2l

import math
import os
import random
import paddle

14.3.1. 读取数据集

我们在这里使用的数据集是Penn Tree Bank(PTB)。该语料库取自“华尔街日报”的文章,分为训练集、验证集和测试集。在原始格式中,文本文件的每一行表示由空格分隔的一句话。在这里,我们将每个单词视为一个词元。

# Register the PTB archive with d2l's download hub as a (URL, SHA-1) pair.
# NOTE(review): the checksum line was lost in extraction; the value below is
# the one published with the upstream d2l package -- confirm before release.
d2l.DATA_HUB['ptb'] = (d2l.DATA_URL + 'ptb.zip',
                       '319d85e578af0cdc590547f26231e4e31cdf1e42')

def read_ptb():
    """Load the PTB training split as a list of tokenized lines.

    Returns:
        A list with one entry per ``'\\n'``-separated line of
        ``ptb.train.txt``; each entry is the list of whitespace-separated
        tokens on that line (an empty list for a blank line).
    """
    data_dir = d2l.download_extract('ptb')
    # Read the training set. The encoding is pinned so that the result does
    # not depend on the platform's default locale.
    with open(os.path.join(data_dir, 'ptb.train.txt'),
              encoding='utf-8') as f:
        raw_text = f.read()
    return [line.split() for line in raw_text.split('\n')]

sentences = read_ptb()
f'# sentences数: {len(sentences)}'
Downloading ../data/ptb.zip from http://d2l-data.s3-accelerate.amazonaws.com/ptb.zip...
'# sentences数: 42069'
# Register the PTB archive with d2l's download hub as a (URL, SHA-1) pair.
# NOTE(review): the checksum line was lost in extraction; the value below is
# the one published with the upstream d2l package -- confirm before release.
d2l.DATA_HUB['ptb'] = (d2l.DATA_URL + 'ptb.zip',
                       '319d85e578af0cdc590547f26231e4e31cdf1e42')

def read_ptb():
    """Load the PTB training split as a list of tokenized lines.

    Returns:
        A list with one entry per ``'\\n'``-separated line of
        ``ptb.train.txt``; each entry is the list of whitespace-separated
        tokens on that line (an empty list for a blank line).
    """
    data_dir = d2l.download_extract('ptb')
    # Read the training set. The encoding is pinned so that the result does
    # not depend on the platform's default locale.
    with open(os.path.join(data_dir, 'ptb.train.txt'),
              encoding='utf-8') as f:
        raw_text = f.read()
    return [line.split() for line in raw_text.split('\n')]

sentences = read_ptb()
f'# sentences数: {len(sentences)}'
Downloading ../data/ptb.zip from http://d2l-data.s3-accelerate.amazonaws.com/ptb.zip...
'# sentences数: 42069'
# Register the PTB archive with d2l's download hub as a (URL, SHA-1) pair.
# NOTE(review): the checksum line was lost in extraction; the value below is
# the one published with the upstream d2l package -- confirm before release.
d2l.DATA_HUB['ptb'] = (d2l.DATA_URL + 'ptb.zip',
                       '319d85e578af0cdc590547f26231e4e31cdf1e42')

def read_ptb():
    """Load the PTB training split as a list of tokenized lines.

    Returns:
        A list with one entry per ``'\\n'``-separated line of
        ``ptb.train.txt``; each entry is the list of whitespace-separated
        tokens on that line (an empty list for a blank line).
    """
    data_dir = d2l.download_extract('ptb')
    # Read the training set. The encoding is pinned so that the result does
    # not depend on the platform's default locale.
    with open(os.path.join(data_dir, 'ptb.train.txt'),
              encoding='utf-8') as f:
        raw_text = f.read()
    return [line.split() for line in raw_text.split('\n')]

sentences = read_ptb()
f'# sentences数: {len(sentences)}'
'# sentences数: 42069'


vocab = d2l.Vocab(sentences, min_freq=10)
f'vocab size: {len(vocab)}'
'vocab size: 6719'
vocab = d2l.Vocab(sentences, min_freq=10)
f'vocab size: {len(vocab)}'
'vocab size: 6719'
vocab = d2l.Vocab(sentences, min_freq=10)
f'vocab size: {len(vocab)}'
'vocab size: 6719'

14.3.2. 下采样

文本数据通常有“the”“a”和“in”等高频词:它们在非常大的语料库中甚至可能出现数十亿次。然而,这些词经常在上下文窗口中与许多不同的词共同出现,提供的有用信息很少。例如,考虑上下文窗口中的词“chip”:直观地说,它与低频单词“intel”的共现比与高频单词“a”的共现在训练中更有用。此外,大量(高频)单词的训练速度很慢。因此,当训练词嵌入模型时,可以对高频单词进行下采样 (Mikolov et al., 2013)。具体地说,数据集中的每个词\(w_i\)将以如下概率被丢弃:

(14.3.1)\[P(w_i) = \max\left(1 - \sqrt{\frac{t}{f(w_i)}}, 0\right),\]

其中\(f(w_i)\)是\(w_i\)的词数与数据集中的总词数的比率,常量\(t\)是超参数(在实验中为\(10^{-4}\))。我们可以看到,只有当相对比率\(f(w_i) > t\)时,(高频)词\(w_i\)才能被丢弃,且该词的相对比率越高,被丢弃的概率就越大。

def subsample(sentences, vocab):
    """Randomly drop high-frequency tokens from a tokenized corpus.

    Args:
        sentences: list of token lists.
        vocab: vocabulary object; a token is treated as unknown (and
            removed) when ``vocab[token] == vocab.unk``.

    Returns:
        A pair ``(subsampled, counter)``: the sentences after subsampling,
        and the token counter computed by ``d2l.count_corpus`` on the
        corpus with unknown tokens removed (before subsampling).
    """
    # Exclude unknown tokens '<unk>'
    sentences = [[token for token in line if vocab[token] != vocab.unk]
                 for line in sentences]
    counter = d2l.count_corpus(sentences)
    num_tokens = sum(counter.values())

    # Return True if `token` is kept during subsampling. A token with
    # relative frequency f is kept with probability sqrt(1e-4 / f), so only
    # tokens with f > 1e-4 can ever be dropped.
    def keep(token):
        return (random.uniform(0, 1) <
                math.sqrt(1e-4 / counter[token] * num_tokens))

    return ([[token for token in line if keep(token)] for line in sentences],
            counter)

subsampled, counter = subsample(sentences, vocab)
def subsample(sentences, vocab):
    """Randomly drop high-frequency tokens from a tokenized corpus.

    Args:
        sentences: list of token lists.
        vocab: vocabulary object; a token is treated as unknown (and
            removed) when ``vocab[token] == vocab.unk``.

    Returns:
        A pair ``(subsampled, counter)``: the sentences after subsampling,
        and the token counter computed by ``d2l.count_corpus`` on the
        corpus with unknown tokens removed (before subsampling).
    """
    # Exclude unknown tokens '<unk>'
    sentences = [[token for token in line if vocab[token] != vocab.unk]
                 for line in sentences]
    counter = d2l.count_corpus(sentences)
    num_tokens = sum(counter.values())

    # Return True if `token` is kept during subsampling. A token with
    # relative frequency f is kept with probability sqrt(1e-4 / f), so only
    # tokens with f > 1e-4 can ever be dropped.
    def keep(token):
        return (random.uniform(0, 1) <
                math.sqrt(1e-4 / counter[token] * num_tokens))

    return ([[token for token in line if keep(token)] for line in sentences],
            counter)

subsampled, counter = subsample(sentences, vocab)
def subsample(sentences, vocab):
    """Randomly drop high-frequency tokens from a tokenized corpus.

    Args:
        sentences: list of token lists.
        vocab: vocabulary object; a token is treated as unknown (and
            removed) when ``vocab[token] == vocab.unk``.

    Returns:
        A pair ``(subsampled, counter)``: the sentences after subsampling,
        and the token counter computed by ``d2l.count_corpus`` on the
        corpus with unknown tokens removed (before subsampling).
    """
    # Exclude unknown tokens '<unk>'
    sentences = [[token for token in line if vocab[token] != vocab.unk]
                 for line in sentences]
    counter = d2l.count_corpus(sentences)
    num_tokens = sum(counter.values())

    # Return True if `token` is kept during subsampling. A token with
    # relative frequency f is kept with probability sqrt(1e-4 / f), so only
    # tokens with f > 1e-4 can ever be dropped.
    def keep(token):
        return (random.uniform(0, 1) <
                math.sqrt(1e-4 / counter[token] * num_tokens))

    return ([[token for token in line if keep(token)] for line in sentences],
            counter)

subsampled, counter = subsample(sentences, vocab)


    ['origin', 'subsampled'], '# tokens per sentence',
    'count', sentences, subsampled);
    ['origin', 'subsampled'], '# tokens per sentence',
    'count', sentences, subsampled);
    ['origin', 'subsampled'], '# tokens per sentence',
    'count', sentences, subsampled);


def compare_counts(token):
    """Report how often `token` occurs before and after subsampling.

    Reads the module-level ``sentences`` (original corpus) and
    ``subsampled`` (corpus after subsampling) and returns a one-line
    summary string with both occurrence counts.
    """
    before = sum(line.count(token) for line in sentences)
    after = sum(line.count(token) for line in subsampled)
    return f'"{token}"的数量:之前={before}, 之后={after}'

'"the"的数量:之前=50770, 之后=2063'
def compare_counts(token):
    """Report how often `token` occurs before and after subsampling.

    Reads the module-level ``sentences`` (original corpus) and
    ``subsampled`` (corpus after subsampling) and returns a one-line
    summary string with both occurrence counts.
    """
    before = sum(line.count(token) for line in sentences)
    after = sum(line.count(token) for line in subsampled)
    return f'"{token}"的数量:之前={before}, 之后={after}'

'"the"的数量:之前=50770, 之后=2056'
def compare_counts(token):
    """Report how often `token` occurs before and after subsampling.

    Reads the module-level ``sentences`` (original corpus) and
    ``subsampled`` (corpus after subsampling) and returns a one-line
    summary string with both occurrence counts.
    """
    before = sum(line.count(token) for line in sentences)
    after = sum(line.count(token) for line in subsampled)
    return f'"{token}"的数量:之前={before}, 之后={after}'

'"the"的数量:之前=50770, 之后=2017'


'"join"的数量:之前=45, 之后=45'
'"join"的数量:之前=45, 之后=45'
'"join"的数量:之前=45, 之后=45'


corpus = [vocab[line] for line in subsampled]
[[], [392, 2115, 145], [5277, 3054, 1580, 95]]
corpus = [vocab[line] for line in subsampled]
[[], [2115, 274, 406], [140, 3, 5277, 3054, 1580]]
corpus = [vocab[line] for line in subsampled]
[[], [2115], [140, 5277, 3054, 1580, 95]]

14.3.3. 中心词和上下文词的提取


def get_centers_and_contexts(corpus, max_window_size):
    """Return the center words and context words used by skip-gram.

    Args:
        corpus: list of sentences, each a list of token indices.
        max_window_size: upper bound (inclusive) of the context window
            radius; the radius is drawn uniformly from
            ``1..max_window_size`` independently for every center word.

    Returns:
        A pair ``(centers, contexts)`` of equal length: ``contexts[k]`` is
        the list of tokens surrounding ``centers[k]``, not including the
        center position itself.
    """
    centers, contexts = [], []
    for line in corpus:
        # A "center word - context word" pair needs at least 2 words in
        # the sentence, so shorter sentences are skipped.
        if len(line) < 2:
            continue
        centers += line
        for i in range(len(line)):  # Context window centered at i
            window_size = random.randint(1, max_window_size)
            start = max(0, i - window_size)
            stop = min(len(line), i + 1 + window_size)
            # Exclude the center word itself from its context words
            contexts.append([line[idx] for idx in range(start, stop)
                             if idx != i])
    return centers, contexts
def get_centers_and_contexts(corpus, max_window_size):
    """Return the center words and context words used by skip-gram.

    Args:
        corpus: list of sentences, each a list of token indices.
        max_window_size: upper bound (inclusive) of the context window
            radius; the radius is drawn uniformly from
            ``1..max_window_size`` independently for every center word.

    Returns:
        A pair ``(centers, contexts)`` of equal length: ``contexts[k]`` is
        the list of tokens surrounding ``centers[k]``, not including the
        center position itself.
    """
    centers, contexts = [], []
    for line in corpus:
        # A "center word - context word" pair needs at least 2 words in
        # the sentence, so shorter sentences are skipped.
        if len(line) < 2:
            continue
        centers += line
        for i in range(len(line)):  # Context window centered at i
            window_size = random.randint(1, max_window_size)
            start = max(0, i - window_size)
            stop = min(len(line), i + 1 + window_size)
            # Exclude the center word itself from its context words
            contexts.append([line[idx] for idx in range(start, stop)
                             if idx != i])
    return centers, contexts
def get_centers_and_contexts(corpus, max_window_size):
    """Return the center words and context words used by skip-gram.

    Args:
        corpus: list of sentences, each a list of token indices.
        max_window_size: upper bound (inclusive) of the context window
            radius; the radius is drawn uniformly from
            ``1..max_window_size`` independently for every center word.

    Returns:
        A pair ``(centers, contexts)`` of equal length: ``contexts[k]`` is
        the list of tokens surrounding ``centers[k]``, not including the
        center position itself.
    """
    centers, contexts = [], []
    for line in corpus:
        # A "center word - context word" pair needs at least 2 words in
        # the sentence, so shorter sentences are skipped.
        if len(line) < 2:
            continue
        centers += line
        for i in range(len(line)):  # Context window centered at i
            window_size = random.randint(1, max_window_size)
            start = max(0, i - window_size)
            stop = min(len(line), i + 1 + window_size)
            # Exclude the center word itself from its context words
            contexts.append([line[idx] for idx in range(start, stop)
                             if idx != i])
    return centers, contexts


tiny_dataset = [list(range(7)), list(range(7, 10))]
print('数据集', tiny_dataset)
for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):
    print('中心词', center, '的上下文词是', context)
数据集 [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]
中心词 0 的上下文词是 [1, 2]
中心词 1 的上下文词是 [0, 2, 3]
中心词 2 的上下文词是 [1, 3]
中心词 3 的上下文词是 [2, 4]
中心词 4 的上下文词是 [2, 3, 5, 6]
中心词 5 的上下文词是 [3, 4, 6]
中心词 6 的上下文词是 [5]
中心词 7 的上下文词是 [8]
中心词 8 的上下文词是 [7, 9]
中心词 9 的上下文词是 [8]
tiny_dataset = [list(range(7)), list(range(7, 10))]
print('数据集', tiny_dataset)
for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):
    print('中心词', center, '的上下文词是', context)
数据集 [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]
中心词 0 的上下文词是 [1]
中心词 1 的上下文词是 [0, 2]
中心词 2 的上下文词是 [0, 1, 3, 4]
中心词 3 的上下文词是 [2, 4]
中心词 4 的上下文词是 [3, 5]
中心词 5 的上下文词是 [4, 6]
中心词 6 的上下文词是 [5]
中心词 7 的上下文词是 [8, 9]
中心词 8 的上下文词是 [7, 9]
中心词 9 的上下文词是 [7, 8]
tiny_dataset = [list(range(7)), list(range(7, 10))]
print('数据集', tiny_dataset)
for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):
    print('中心词', center, '的上下文词是', context)
数据集 [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]
中心词 0 的上下文词是 [1]
中心词 1 的上下文词是 [0, 2, 3]
中心词 2 的上下文词是 [0, 1, 3, 4]
中心词 3 的上下文词是 [2, 4]
中心词 4 的上下文词是 [2, 3, 5, 6]
中心词 5 的上下文词是 [4, 6]
中心词 6 的上下文词是 [5]
中心词 7 的上下文词是 [8, 9]
中心词 8 的上下文词是 [7, 9]
中心词 9 的上下文词是 [7, 8]


all_centers, all_contexts = get_centers_and_contexts(corpus, 5)
f'# “中心词-上下文词对”的数量: {sum([len(contexts) for contexts in all_contexts])}'
'# “中心词-上下文词对”的数量: 1502639'
all_centers, all_contexts = get_centers_and_contexts(corpus, 5)
f'# “中心词-上下文词对”的数量: {sum([len(contexts) for contexts in all_contexts])}'
'# “中心词-上下文词对”的数量: 1499984'
all_centers, all_contexts = get_centers_and_contexts(corpus, 5)
f'# “中心词-上下文词对”的数量: {sum([len(contexts) for contexts in all_contexts])}'
'# “中心词-上下文词对”的数量: 1500181'

14.3.4. 负采样


class RandomGenerator:
    """Draw random integers from {1, ..., n} according to n sampling weights.

    Samples are produced in batches of 10000 with ``random.choices`` and
    handed out one at a time by ``draw``.
    """

    def __init__(self, sampling_weights):
        n = len(sampling_weights)
        # Candidate values start at 1, so 0 is never drawn.
        self.population = list(range(1, n + 1))
        self.sampling_weights = sampling_weights
        self.candidates = []  # cached samples
        self.i = 0  # position of the next unread cached sample

    def _refill(self):
        """Replace the cache with 10000 fresh weighted draws and rewind."""
        self.candidates = random.choices(
            self.population, self.sampling_weights, k=10000)
        self.i = 0

    def draw(self):
        """Return the next cached sample, refilling the cache when it is used up."""
        if self.i == len(self.candidates):
            self._refill()
        sample = self.candidates[self.i]
        self.i += 1
        return sample
class RandomGenerator:
    """Draw random integers from {1, ..., n} according to n sampling weights.

    Samples are produced in batches of 10000 with ``random.choices`` and
    handed out one at a time by ``draw``.
    """

    def __init__(self, sampling_weights):
        n = len(sampling_weights)
        # Candidate values start at 1, so 0 is never drawn.
        self.population = list(range(1, n + 1))
        self.sampling_weights = sampling_weights
        self.candidates = []  # cached samples
        self.i = 0  # position of the next unread cached sample

    def _refill(self):
        """Replace the cache with 10000 fresh weighted draws and rewind."""
        self.candidates = random.choices(
            self.population, self.sampling_weights, k=10000)
        self.i = 0

    def draw(self):
        """Return the next cached sample, refilling the cache when it is used up."""
        if self.i == len(self.candidates):
            self._refill()
        sample = self.candidates[self.i]
        self.i += 1
        return sample
class RandomGenerator:
    """Draw random integers from {1, ..., n} according to n sampling weights.

    Samples are produced in batches of 10000 with ``random.choices`` and
    handed out one at a time by ``draw``.
    """

    def __init__(self, sampling_weights):
        n = len(sampling_weights)
        # Candidate values start at 1, so 0 is never drawn.
        self.population = list(range(1, n + 1))
        self.sampling_weights = sampling_weights
        self.candidates = []  # cached samples
        self.i = 0  # position of the next unread cached sample

    def _refill(self):
        """Replace the cache with 10000 fresh weighted draws and rewind."""
        self.candidates = random.choices(
            self.population, self.sampling_weights, k=10000)
        self.i = 0

    def draw(self):
        """Return the next cached sample, refilling the cache when it is used up."""
        if self.i == len(self.candidates):
            self._refill()
        sample = self.candidates[self.i]
        self.i += 1
        return sample

例如,我们可以在索引1、2和3中绘制10个随机变量\(X\),采样概率为\(P(X=1)=2/9, P(X=2)=3/9\)和\(P(X=3)=4/9\),如下所示。

generator = RandomGenerator([2, 3, 4])
[generator.draw() for _ in range(10)]
[2, 2, 2, 3, 2, 1, 1, 2, 2, 1]
generator = RandomGenerator([2, 3, 4])
[generator.draw() for _ in range(10)]
[1, 2, 2, 3, 3, 3, 3, 2, 1, 2]
generator = RandomGenerator([2, 3, 4])
[generator.draw() for _ in range(10)]
[1, 2, 2, 1, 1, 1, 1, 1, 1, 3]

对于一对中心词和上下文词,我们随机抽取了\(K\)个(实验中为5个)噪声词。根据word2vec论文中的建议,将噪声词\(w\)的采样概率\(P(w)\)设置为其在字典中的相对频率,其幂为0.75 (Mikolov et al., 2013)。

def get_negatives(all_contexts, vocab, counter, K):
    """Return the noise words used for negative sampling.

    Args:
        all_contexts: list of context-word index lists, one per center word.
        vocab: vocabulary; ``vocab.to_tokens(i)`` maps an index to its token.
        counter: mapping from token to its count in the corpus.
        K: number of noise words drawn per context word.

    Returns:
        A list parallel to ``all_contexts``; entry ``k`` holds
        ``len(all_contexts[k]) * K`` noise-word indices, none of which
        appear in ``all_contexts[k]``.
    """
    # Weights for indices 1, 2, ... (index 0 is the excluded unknown token
    # in the vocabulary); each weight is the token count raised to 0.75.
    sampling_weights = [counter[vocab.to_tokens(i)]**0.75
                        for i in range(1, len(vocab))]
    all_negatives, generator = [], RandomGenerator(sampling_weights)
    for contexts in all_contexts:
        # Set lookup keeps the rejection test O(1) per draw.
        context_set = set(contexts)
        num_needed = len(contexts) * K
        negatives = []
        while len(negatives) < num_needed:
            neg = generator.draw()
            # A noise word must not be one of the context words
            if neg not in context_set:
                negatives.append(neg)
        all_negatives.append(negatives)
    return all_negatives

all_negatives = get_negatives(all_contexts, vocab, counter, 5)
def get_negatives(all_contexts, vocab, counter, K):
    """Return the noise words used for negative sampling.

    Args:
        all_contexts: list of context-word index lists, one per center word.
        vocab: vocabulary; ``vocab.to_tokens(i)`` maps an index to its token.
        counter: mapping from token to its count in the corpus.
        K: number of noise words drawn per context word.

    Returns:
        A list parallel to ``all_contexts``; entry ``k`` holds
        ``len(all_contexts[k]) * K`` noise-word indices, none of which
        appear in ``all_contexts[k]``.
    """
    # Weights for indices 1, 2, ... (index 0 is the excluded unknown token
    # in the vocabulary); each weight is the token count raised to 0.75.
    sampling_weights = [counter[vocab.to_tokens(i)]**0.75
                        for i in range(1, len(vocab))]
    all_negatives, generator = [], RandomGenerator(sampling_weights)
    for contexts in all_contexts:
        # Set lookup keeps the rejection test O(1) per draw.
        context_set = set(contexts)
        num_needed = len(contexts) * K
        negatives = []
        while len(negatives) < num_needed:
            neg = generator.draw()
            # A noise word must not be one of the context words
            if neg not in context_set:
                negatives.append(neg)
        all_negatives.append(negatives)
    return all_negatives

all_negatives = get_negatives(all_contexts, vocab, counter, 5)
def get_negatives(all_contexts, vocab, counter, K):
    """Return the noise words used for negative sampling.

    Args:
        all_contexts: list of context-word index lists, one per center word.
        vocab: vocabulary; ``vocab.to_tokens(i)`` maps an index to its token.
        counter: mapping from token to its count in the corpus.
        K: number of noise words drawn per context word.

    Returns:
        A list parallel to ``all_contexts``; entry ``k`` holds
        ``len(all_contexts[k]) * K`` noise-word indices, none of which
        appear in ``all_contexts[k]``.
    """
    # Weights for indices 1, 2, ... (index 0 is the excluded unknown token
    # in the vocabulary); each weight is the token count raised to 0.75.
    sampling_weights = [counter[vocab.to_tokens(i)]**0.75
                        for i in range(1, len(vocab))]
    all_negatives, generator = [], RandomGenerator(sampling_weights)
    for contexts in all_contexts:
        # Set lookup keeps the rejection test O(1) per draw.
        context_set = set(contexts)
        num_needed = len(contexts) * K
        negatives = []
        while len(negatives) < num_needed:
            neg = generator.draw()
            # A noise word must not be one of the context words
            if neg not in context_set:
                negatives.append(neg)
        all_negatives.append(negatives)
    return all_negatives

all_negatives = get_negatives(all_contexts, vocab, counter, 5)

14.3.5. 小批量加载训练实例


在小批量中,第\(i\)个样本包括中心词及其\(n_i\)个上下文词和\(m_i\)个噪声词。由于上下文窗口大小不同,\(n_i+m_i\)对于不同的\(i\)是不同的。因此,对于每个样本,我们在contexts_negatives变量中将其上下文词和噪声词连结起来,并填充零,直到连结长度达到\(\max_i (n_i+m_i)\)(max_len)。为了在计算损失时排除填充,我们定义了掩码变量masks。在masks中的元素和contexts_negatives中的元素之间存在一一对应关系,其中masks中的0(否则为1)对应于contexts_negatives中的填充。



def batchify(data):
    """Collate skip-gram examples with negative sampling into one minibatch.

    Args:
        data: sequence of ``(center, context, negative)`` triples, where
            ``context`` and ``negative`` are lists of token indices.

    Returns:
        ``(centers, contexts_negatives, masks, labels)`` arrays.
        ``centers`` is reshaped to one column; the other three have one row
        per example, zero-padded to the longest ``context + negative``
        length. ``masks`` is 1 on real entries and 0 on padding; ``labels``
        is 1 on context-word positions and 0 elsewhere.
    """
    max_len = max(len(ctx) + len(neg) for _, ctx, neg in data)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, context, negative in data:
        num_context = len(context)
        num_valid = num_context + len(negative)
        num_pad = max_len - num_valid
        centers.append(center)
        contexts_negatives.append(context + negative + [0] * num_pad)
        masks.append([1] * num_valid + [0] * num_pad)
        labels.append([1] * num_context + [0] * (max_len - num_context))
    return (np.array(centers).reshape((-1, 1)),
            np.array(contexts_negatives),
            np.array(masks),
            np.array(labels))
def batchify(data):
    """Collate skip-gram examples with negative sampling into one minibatch.

    Args:
        data: sequence of ``(center, context, negative)`` triples, where
            ``context`` and ``negative`` are lists of token indices.

    Returns:
        ``(centers, contexts_negatives, masks, labels)`` tensors.
        ``centers`` is reshaped to one column; the other three have one row
        per example, zero-padded to the longest ``context + negative``
        length. ``masks`` is 1 on real entries and 0 on padding; ``labels``
        is 1 on context-word positions and 0 elsewhere.
    """
    max_len = max(len(ctx) + len(neg) for _, ctx, neg in data)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, context, negative in data:
        num_context = len(context)
        num_valid = num_context + len(negative)
        num_pad = max_len - num_valid
        centers.append(center)
        contexts_negatives.append(context + negative + [0] * num_pad)
        masks.append([1] * num_valid + [0] * num_pad)
        labels.append([1] * num_context + [0] * (max_len - num_context))
    return (torch.tensor(centers).reshape((-1, 1)),
            torch.tensor(contexts_negatives),
            torch.tensor(masks),
            torch.tensor(labels))
def batchify(data):
    """Collate skip-gram examples with negative sampling into one minibatch.

    Args:
        data: sequence of ``(center, context, negative)`` triples, where
            ``context`` and ``negative`` are lists of token indices.

    Returns:
        ``(centers, contexts_negatives, masks, labels)`` tensors.
        ``centers`` is reshaped to one column; the other three have one row
        per example, zero-padded to the longest ``context + negative``
        length. ``masks`` is 1 on real entries and 0 on padding; ``labels``
        is 1 on context-word positions and 0 elsewhere.
    """
    max_len = max(len(ctx) + len(neg) for _, ctx, neg in data)
    centers, contexts_negatives, masks, labels = [], [], [], []
    for center, context, negative in data:
        num_context = len(context)
        num_valid = num_context + len(negative)
        num_pad = max_len - num_valid
        centers.append(center)
        contexts_negatives.append(context + negative + [0] * num_pad)
        masks.append([1] * num_valid + [0] * num_pad)
        labels.append([1] * num_context + [0] * (max_len - num_context))
    return (paddle.to_tensor(centers).reshape((-1, 1)),
            paddle.to_tensor(contexts_negatives),
            paddle.to_tensor(masks),
            paddle.to_tensor(labels))


x_1 = (1, [2, 2], [3, 3, 3, 3])
x_2 = (1, [2, 2, 2], [3, 3])
batch = batchify((x_1, x_2))

names = ['centers', 'contexts_negatives', 'masks', 'labels']
for name, data in zip(names, batch):
    print(name, '=', data)
centers = [[1.]
 [1.]]
contexts_negatives = [[2. 2. 3. 3. 3. 3.]
 [2. 2. 2. 3. 3. 0.]]
masks = [[1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 0.]]
labels = [[1. 1. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0.]]
x_1 = (1, [2, 2], [3, 3, 3, 3])
x_2 = (1, [2, 2, 2], [3, 3])
batch = batchify((x_1, x_2))

names = ['centers', 'contexts_negatives', 'masks', 'labels']
for name, data in zip(names, batch):
    print(name, '=', data)
centers = tensor([[1],
        [1]])
contexts_negatives = tensor([[2, 2, 3, 3, 3, 3],
        [2, 2, 2, 3, 3, 0]])
masks = tensor([[1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 0]])
labels = tensor([[1, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0]])
x_1 = (1, [2, 2], [3, 3, 3, 3])
x_2 = (1, [2, 2, 2], [3, 3])
batch = batchify((x_1, x_2))

names = ['centers', 'contexts_negatives', 'masks', 'labels']
for name, data in zip(names, batch):
    print(name, '=', data)
centers = Tensor(shape=[2, 1], dtype=int64, place=Place(cpu), stop_gradient=True,
       [[1],
        [1]])
contexts_negatives = Tensor(shape=[2, 6], dtype=int64, place=Place(cpu), stop_gradient=True,
       [[2, 2, 3, 3, 3, 3],
        [2, 2, 2, 3, 3, 0]])
masks = Tensor(shape=[2, 6], dtype=int64, place=Place(cpu), stop_gradient=True,
       [[1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 0]])
labels = Tensor(shape=[2, 6], dtype=int64, place=Place(cpu), stop_gradient=True,
       [[1, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0]])

14.3.6. 整合代码


def load_data_ptb(batch_size, max_window_size, num_noise_words):
    """Download the PTB dataset and build a minibatch iterator over it.

    Args:
        batch_size: number of examples per minibatch.
        max_window_size: maximum context window size, forwarded to
            ``get_centers_and_contexts``.
        num_noise_words: number of noise words per context word, forwarded
            to ``get_negatives``.

    Returns:
        A pair ``(data_iter, vocab)``: a shuffling Gluon ``DataLoader``
        whose batches are assembled by ``batchify``, and the vocabulary
        built from the PTB training set (tokens with frequency < 10 are
        not kept as separate entries).
    """
    sentences = read_ptb()
    vocab = d2l.Vocab(sentences, min_freq=10)
    subsampled, counter = subsample(sentences, vocab)
    corpus = [vocab[line] for line in subsampled]
    all_centers, all_contexts = get_centers_and_contexts(
        corpus, max_window_size)
    all_negatives = get_negatives(
        all_contexts, vocab, counter, num_noise_words)
    dataset = gluon.data.ArrayDataset(
        all_centers, all_contexts, all_negatives)
    data_iter = gluon.data.DataLoader(
        dataset, batch_size, shuffle=True, batchify_fn=batchify,
        num_workers=d2l.get_dataloader_workers())
    return data_iter, vocab
def load_data_ptb(batch_size, max_window_size, num_noise_words):
    """Download the PTB dataset and build a minibatch iterator over it.

    Args:
        batch_size: number of examples per minibatch.
        max_window_size: maximum context window size, forwarded to
            ``get_centers_and_contexts``.
        num_noise_words: number of noise words per context word, forwarded
            to ``get_negatives``.

    Returns:
        A pair ``(data_iter, vocab)``: a shuffling PyTorch ``DataLoader``
        whose batches are assembled by ``batchify``, and the vocabulary
        built from the PTB training set.
    """
    num_workers = d2l.get_dataloader_workers()
    sentences = read_ptb()
    vocab = d2l.Vocab(sentences, min_freq=10)
    subsampled, counter = subsample(sentences, vocab)
    corpus = [vocab[line] for line in subsampled]
    all_centers, all_contexts = get_centers_and_contexts(
        corpus, max_window_size)
    all_negatives = get_negatives(
        all_contexts, vocab, counter, num_noise_words)

    class PTBDataset(torch.utils.data.Dataset):
        """Index-aligned view over centers, contexts and negatives."""

        def __init__(self, centers, contexts, negatives):
            assert len(centers) == len(contexts) == len(negatives)
            self.centers = centers
            self.contexts = contexts
            self.negatives = negatives

        def __getitem__(self, index):
            # One example is the (center, context, negative) triple that
            # batchify unpacks.
            return (self.centers[index], self.contexts[index],
                    self.negatives[index])

        def __len__(self):
            return len(self.centers)

    dataset = PTBDataset(all_centers, all_contexts, all_negatives)

    data_iter = torch.utils.data.DataLoader(
        dataset, batch_size, shuffle=True,
        collate_fn=batchify, num_workers=num_workers)
    return data_iter, vocab
def load_data_ptb(batch_size, max_window_size, num_noise_words):
    """Download the PTB dataset and build a minibatch iterator over it.

    Args:
        batch_size: number of examples per minibatch.
        max_window_size: maximum context window size, forwarded to
            ``get_centers_and_contexts``.
        num_noise_words: number of noise words per context word, forwarded
            to ``get_negatives``.

    Returns:
        A pair ``(data_iter, vocab)``: a shuffling Paddle ``DataLoader``
        whose batches are assembled by ``batchify``, and the vocabulary
        built from the PTB training set.
    """
    num_workers = d2l.get_dataloader_workers()
    sentences = read_ptb()
    vocab = d2l.Vocab(sentences, min_freq=10)
    subsampled, counter = subsample(sentences, vocab)
    corpus = [vocab[line] for line in subsampled]
    all_centers, all_contexts = get_centers_and_contexts(
        corpus, max_window_size)
    all_negatives = get_negatives(
        all_contexts, vocab, counter, num_noise_words)

    class PTBDataset(paddle.io.Dataset):
        """Index-aligned view over centers, contexts and negatives."""

        def __init__(self, centers, contexts, negatives):
            assert len(centers) == len(contexts) == len(negatives)
            self.centers = centers
            self.contexts = contexts
            self.negatives = negatives

        def __getitem__(self, index):
            # One example is the (center, context, negative) triple that
            # batchify unpacks.
            return (self.centers[index], self.contexts[index],
                    self.negatives[index])

        def __len__(self):
            return len(self.centers)

    dataset = PTBDataset(all_centers, all_contexts, all_negatives)

    data_iter = paddle.io.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, return_list=True,
        collate_fn=batchify, num_workers=num_workers)
    return data_iter, vocab


data_iter, vocab = load_data_ptb(512, 5, 5)
# Inspect only the first minibatch: print the shape of each field, then stop
# instead of iterating over the whole dataset.
for batch in data_iter:
    for name, data in zip(names, batch):
        print(name, 'shape:', data.shape)
    break
centers shape: (512, 1)
contexts_negatives shape: (512, 60)
masks shape: (512, 60)
labels shape: (512, 60)
data_iter, vocab = load_data_ptb(512, 5, 5)
# Inspect only the first minibatch: print the shape of each field, then stop
# instead of iterating over the whole dataset.
for batch in data_iter:
    for name, data in zip(names, batch):
        print(name, 'shape:', data.shape)
    break
centers shape: torch.Size([512, 1])
contexts_negatives shape: torch.Size([512, 60])
masks shape: torch.Size([512, 60])
labels shape: torch.Size([512, 60])
data_iter, vocab = load_data_ptb(512, 5, 5)
# Inspect only the first minibatch: print the shape of each field, then stop
# instead of iterating over the whole dataset.
for batch in data_iter:
    for name, data in zip(names, batch):
        print(name, 'shape:', data.shape)
    break
centers shape: [512, 1]
contexts_negatives shape: [512, 60]
masks shape: [512, 60]
labels shape: [512, 60]

14.3.7. 小结

  • 高频词在训练中可能不是那么有用。我们可以对它们进行下采样,以便在训练中加快速度。

  • 为了提高计算效率,我们以小批量方式加载样本。我们可以定义其他变量来区分填充标记和非填充标记,以及正例和负例。

14.3.8. 练习

  1. 如果不使用下采样,本节中代码的运行时间会发生什么变化?

  2. RandomGenerator类缓存k个随机采样结果。将k设置为其他值,看看它如何影响数据加载速度。

  3. 本节代码中的哪些其他超参数可能会影响数据加载速度?