pytorch的基本使用(数据加载、类型转换)
这篇博文将首先在掘金社区发布! 前言 通过前面的介绍,我们大概知道了我们pytorch的一些张量的基本概念,以及我们的梯度和张量复制的一些细节。 Tensor 在很大程度上与 numpy 非常相似。在某些情况下,我们甚至可以直接使用张量进行计算。现在我们来谈谈pytorch的一些基本用途。 毕竟,我们使用 pytorch 来构建我们的深度学习神经网络。嗯,在机器学习的简要概述中,深度学习其实是我们
这篇博文将首先在掘金社区发布!
前言
通过前面的介绍,我们大概知道了我们pytorch的一些张量的基本概念,以及我们的梯度和张量复制的一些细节。 Tensor 在很大程度上与 numpy 非常相似。在某些情况下,我们甚至可以直接使用张量进行计算。现在我们来谈谈pytorch的一些基本用途。
毕竟,我们使用 pytorch 来构建我们的深度学习神经网络。嗯,在机器学习的简要概述中,深度学习其实是我们机器学习的一个分支,也就是有一个特殊点的机器学习。 sklearn和aruze云平台之前的机器学习步骤大致分为五个部分。其实他们在pytorch中是类似的,只不过算法换成了更抽象的神经网络。
所以我们可以大致把pytorch分成这几块
在这里,我们将主要关注数据加载和转换。
类型转换
一开始我们说tensor可以转换numpy数据,但有时我们需要处理文本、图片和声音。所以我们需要一个转换器(当然你可以把它转成numpy再转成张量,但那是你着急做的)
在此处使用工具包
张量视觉
例如,我们转换图像。
我们发现工具包下的内容还是很多的。 totensor() 可以直接转换(见源码和说明)
到这里我们就可以轻松完成改造了。
编写“链式改造”
有时我们可能需要多次转换。比如我们需要先改变一张图片的大小,然后再进行转换。这个时候,为了避免代码重复,我们这个时候还是可以这样做的。
从 torchvision 导入转换
张量_to u003d transforms.ToTensor()
compose u003d transforms.Compose([张量_to,])
image u003d Image.open("train/1/0BGHNV6P.jpg")
img u003d 撰写(图像)
打印(图片)
好吧,还有其他方法。我不会谈论它。你pycharm根本就出来了,还有注释。
类型转换其实很简单,对应的情况也很多。这里真的很难解释。
数据处理
众所周知,机器学习离不开数据、数据集。对于一些知名的网络模型或数据集,pytorch 提供了自动下载工具。
自包含数据集
这意味着pytorch会通过爬虫自动下载数据集合,然后打包给我们。
这也是使用 tensorvision
例如,下载 CIFAR10 数据集
训练_set u003d torchvision.datasets.CIFAR10(rootu003d"./dataset",trainu003dTrue,downloadu003dTrue)
tese_set u003d torchvision.datasets.CIFAR10(rootu003d"./dataset",trainu003dFalse,downloadu003dTrue)
直接,但是注意这里得到的数据集不是张量类型的,我们需要进行类型转换
从 torchvision 导入转换
trans u003d transforms.Compose([transforms.ToTensor()])
数据集 u003d torchvision.datasets.CIFAR10(rootu003d"./dataset",trainu003dFalse,transformu003dtrans,downloadu003dTrue)
数据加载
然后我们加载数据
这里用到utils下的工具
从 torch.utils.data 导入 DataLoader
从 torchvision 导入转换
trans u003d transforms.Compose([transforms.ToTensor()])
数据集 u003d torchvision.datasets.CIFAR10(rootu003d"./dataset",trainu003dFalse,transformu003dtrans,downloadu003dTrue)
数据加载器 u003d 数据加载器(数据集,批次_sizeu003d64)
这里主要介绍DataLoader的一些参数。
自定义获取数据
这是更原始的,即有时我们需要自己加载数据集,例如。
这是从网上下载的数据集。现在我们需要将它导入到我们的 pytorch 中。
此文件夹是标签名称,位于此数据集中。 1是1元的图片,100是100元的图片。
我这里直接给代码
从 torch.utils.data 导入数据集、数据加载器
从 torchvision 导入转换
导入我们
从 PIL 导入图像
通过Dataset获取数据
类我的数据集(数据集):
def __init__(self,RootDir,LabelDir):
self.RootDir u003d RootDir
self.LabelDir u003d 标签目录
self.transform u003d transforms.ToTensor()
self.ImagePathDir u003d os.path.join(self.RootDir,self.LabelDir)
self.ImageNameItems u003d os.listdir(self.ImagePathDir)
def __getitem__(自我,项目):
#item 是获取一个数据元素,懒惰模式。如果你想用我给你
项目名称 u003d self.ImageNameItems[项目]
ImagePathItem u003d os.path.join(self.RootDir,self.LabelDir,ItemName)
ItemGet u003d self.transform(Image.open(ImagePathItem).resize((500,500)))
ItemLabel u003d self.LabelDir
返回项目获取、项目标签
定义__len__(自我):
返回 len(self.ImageNameItems)
如果 __name__ u003du003d"__main__":
RootDir u003d "火车"
OneYuanLabel u003d "1"
HandoneYuanLabel u003d "100"
OneYuanData u003d MyDataset(RootDir,OneYuanLabel)
HandoneData u003d MyDataset(RootDir,HandoneYuanLabel)
DataGet u003d OneYuanData+HandoneData
train_data u003d DataLoader(datasetu003dDataGet,batch_sizeu003d18,shuffleu003dTrue,num_workersu003d0,drop_lastu003dTrue)
对于 train_data 中的数据:
imgs,标签 u003d 数据
打印(imgs.shape)
重点是我们继承Dataset,然后实现__ getitem()__ 这个神奇的方法。代码其实很简单。当我们得到我们路径的图片名,然后调用魔术方法,我们读取图片,直接转换成张量。其实这和前面得到的数据差不多,只是我们直接转换了一下。同时,这也是为什么我们使用 DataLoader 取出数据,而不是等待训练模型,它非常慢。
总结
这些都是最基本的操作,所以明天我们就来聊聊如何玩转神经网络和使用pytorch。这里,我们以CNN为例,构建CIFAR10模型。稍后我们将做一个小演示。
其实python的使用很简单,但是有很多前提条件。否则,很难理解。不像python的django和java的ssm springcloud等业务框架,背几个API和注解就可以了。很容易上手。当然,源代码不同。
更多推荐
所有评论(0)