1. pin_memory参数解析

由于从 CPU 数据转移至 GPU 时,位于pinned(或叫做page-locked) memory上的 Tensor 会更快,因此DataLoader里设置了这一选项, 如果pin_memory=True, 则在提供数据时, 调用Tensor的 .pinmemory() 方法提高数据转移速度. 但是, 该方法只对普通 Tensor 和包含 Tensor 的映射与容器等数据结构成立, 如果是自定义的数据 batch, 则需要特殊实现其 .pinmemory() 方法.

主机中的内存,有两种存在方式,一是锁页(page-locked),二是不锁页,锁页内存存放的内容在任何情况下都不会与主机的虚拟内存进行交换(注:虚拟内存就是硬盘),而不锁页内存在主机内存不足时,数据会存放在虚拟内存中。显卡中的显存全部是锁页内存,当计算机的内存充足的时候,可以设置 pin_memory=True。当系统卡住,或者交换内存使用过多的时候,设置 pin_memory=False。因为 pin_memory 与电脑硬件性能有关,pytorch 开发者不能确保每一个炼丹玩家都有高端设备,因此 pin_memory 默认为False

2. num_workers参数解析

num_workers 是服务于多进程(multiprocessing)数据加载的, 用于设置有多少个子进程负责数据加载. num_workers并不是越大越好, 因为过多的子进程会占据 CPU 计算资源, 使得程序中其他在CPU上的计算部分变慢, 导致整体运行时间增加. 

一般来说是通过逐步增加尝试来进行设置, 比如当GPU计算利用率已经很饱和时, 说明数据读取足够满足计算需求, 则不必再增加worker数量.

3. collate_fn参数解析

默认情况下DataLoader将调用预置的 default_collate_fn, 将 Dataset 的返回的多个数据样本整理(collate)成为一个 batch. 在 collate 时, 会添加一个维度, 即批样本维度在数据的第一维, 可以看做这个操作即是 torch.stack 运算.

Logo

权威|前沿|技术|干货|国内首个API全生命周期开发者社区

更多推荐