引入

  该数据集收集了杰伦第一张专辑《Jay》到第十张专辑《跨时代》中的歌词。
  下载地址:https://codechina.csdn.net/mirrors/shusentang/dive-into-dl-pytorch/-/blob/master/data/jaychou_lyrics.txt.zip


参考文献:
【1】李沐、Aston Zhang等老师,动手学深度学习


1 原始数据处理

  函数参数包括原始数据集的选取范围及数据集的路径。返回值包括:
  1)idx2char_list:不重复字符列表;
  2)char2idx_dict:字符索引字典;
  3)dict_size:字典大小;
  4)char2idx_list:字符索引列表。
  函数返回值包括以上三个步骤的处理结果:

def load_jaychou_lyrics(tr_range=None, path="../Data/jaychou_lyrics.txt.zip"):
    """
    :param tr_range: 数据集选取范围
    :param path: 数据集存储路径
    """
    with zipfile.ZipFile(path) as zin:
        with zin.open('jaychou_lyrics.txt') as f:
            ori_data = f.read().decode("utf-8")

    ori_data = ori_data.replace("\n", " ").replace("\r", " ")

    """设置原始数据集的选取范围并选取"""
    if tr_range is None:
        tr_range = (0, len(ori_data))
    ori_data = ori_data[tr_range[0]: tr_range[1]]

    # 不重复字符列表
    idx2char_list = list(set(ori_data))

    # 字符索引字典
    char2idx_dict = dict([(char, i) for i, char in enumerate(idx2char_list)])

    # 字典大小,即不重复字符的数量
    dict_size = len(char2idx_dict)

    # 字符索引列表
    char2idx_list = [char2idx_dict[char] for char in ori_data]

    return idx2char_list, char2idx_dict, dict_size, char2idx_list

2 时序数据的采样

  时序数据的一个样本通常包含连续的字符。例如时间步数为 5 5 5时,样本序列相应为 5 5 5个字符。假设样本序列为“想”、“要”、“有”、“直”、“升”,则该样本的标签序列为这些字符分别在训练集中的下一个字符,例如“要”、“有”、“直”、“升”、“机”。
  接下来使用两种方式对时序数据采样。

2.1 随机采样

  在随机采样中,每个样本是原始序列上任意截取的一段序列。相邻的两个随机小批量在原始序列上的位置不一定相邻。因此无法用一个小批量最终时间步的隐藏状态来初始化下一个小批量的隐藏状态。
  在模型训练时,每次随机采样都需要重新初始化隐藏状态:

def load_jaychou_lyrics_iter_random(data_idx, batch_size=2, num_step=5,
                                    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """
    :param data_idx: 数据选取索引
    :param batch_size: 批次大小
    :param num_step: 每个样本的时间步数
    :param device: 设备
    """
    # 减1是因为输出的索引x是相应输入的索引y+1
    num_data = (len(data_idx) - 1) // num_step
    num_epoch = num_data // batch_size
    idx = np.random.permutation(num_data)

    def _data(pos):
        return data_idx[pos: pos + num_step]

    for i in range(num_epoch):
        j = i * batch_size
        batch_idx = idx[j: j + batch_size]
        X = [_data(k * num_step) for k in batch_idx]
        Y = [_data(k * num_step + 1) for k in batch_idx]

        yield (torch.tensor(X, dtype=torch.float32, device=device),
               torch.tensor(Y, dtype=torch.float32, device=device))


if __name__ == '__main__':
    for (a, b) in load_jaychou_lyrics_iter_random(list(range(30))):
        print(a, "\n", b)

  输出如下:

tensor([[10., 11., 12., 13., 14.],
        [15., 16., 17., 18., 19.]]) 
 tensor([[11., 12., 13., 14., 15.],
        [16., 17., 18., 19., 20.]])
tensor([[20., 21., 22., 23., 24.],
        [ 0.,  1.,  2.,  3.,  4.]]) 
 tensor([[21., 22., 23., 24., 25.],
        [ 1.,  2.,  3.,  4.,  5.]])

2.2 相邻采样

  这里的相邻采样是指:两个随机小批量在原始序列上的位置相毗邻。这时,就可以用一个小批量最终时间步的隐藏状态来初始化下一个小批量的隐藏状态,从而使下一个小批量的输出也取决于当前小批量的输入:

def load_jaychou_lyrics_iter_consecutive(data_idx, batch_size=2, num_step=5,
                                         device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """
    :param data_idx: 数据选取索引
    :param batch_size: 批次大小
    :param num_step: 每个样本的时间步数
    :param device: 设备
    """
    data_idx = torch.tensor(data_idx, dtype=torch.float32, device=device)
    num_data = len(data_idx)
    num_batch = num_data // batch_size
    idx = data_idx[0: batch_size * num_batch].view(batch_size, num_batch)
    num_epoch = (num_batch - 1) // num_step
    for i in range(num_epoch):
        j = i * num_step
        X = idx[:, j: j + num_step]
        Y = idx[:, j + 1: j + num_step + 1]
        yield X, Y


if __name__ == '__main__':
    for (a, b) in load_jaychou_lyrics_iter_consecutive(list(range(30))):
        print(a, "\n", b)

  输出如下:

tensor([[ 0.,  1.,  2.,  3.,  4.],
        [15., 16., 17., 18., 19.]]) 
 tensor([[ 1.,  2.,  3.,  4.,  5.],
        [16., 17., 18., 19., 20.]])
tensor([[ 5.,  6.,  7.,  8.,  9.],
        [20., 21., 22., 23., 24.]]) 
 tensor([[ 6.,  7.,  8.,  9., 10.],
        [21., 22., 23., 24., 25.]])
Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐