这篇博文将首先在掘金社区发布!

前言

通过前面的介绍,我们大概知道了我们pytorch的一些张量的基本概念,以及我们的梯度和张量复制的一些细节。 Tensor 在很大程度上与 numpy 非常相似。在某些情况下,我们甚至可以直接使用张量进行计算。现在我们来谈谈pytorch的一些基本用途。

毕竟,我们使用 pytorch 来构建我们的深度学习神经网络。嗯,在机器学习的简要概述中,深度学习其实是我们机器学习的一个分支,也就是有一个特殊点的机器学习。 sklearn和aruze云平台之前的机器学习步骤大致分为五个部分。其实他们在pytorch中是类似的,只不过算法换成了更抽象的神经网络。

所以我们可以大致把pytorch分成这几块

在这里,我们将主要关注数据加载和转换。

类型转换

一开始我们说tensor可以转换numpy数据,但有时我们需要处理文本、图片和声音。所以我们需要一个转换器(当然你可以把它转成numpy再转成张量,但那是你着急做的)

在此处使用工具包

张量视觉

例如,我们转换图像。

我们发现工具包下的内容还是很多的。 totensor() 可以直接转换(见源码和说明)

到这里我们就可以轻松完成改造了。

编写“链式改造”

有时我们可能需要多次转换。比如我们需要先改变一张图片的大小,然后再进行转换。这个时候,为了避免代码重复,我们这个时候还是可以这样做的。

从 torchvision 导入转换

张量_to u003d transforms.ToTensor()

compose u003d transforms.Compose([张量_to,])

image u003d Image.open("train/1/0BGHNV6P.jpg")

img u003d 撰写(图像)

打印(图片)

好吧,还有其他方法。我不会谈论它。你pycharm根本就出来了,还有注释。

类型转换其实很简单,对应的情况也很多。这里真的很难解释。

数据处理

众所周知,机器学习离不开数据、数据集。对于一些知名的网络模型或数据集,pytorch 提供了自动下载工具。

自包含数据集

这意味着pytorch会通过爬虫自动下载数据集合,然后打包给我们。

这也是使用 tensorvision

例如,下载 CIFAR10 数据集

训练_set u003d torchvision.datasets.CIFAR10(rootu003d"./dataset",trainu003dTrue,downloadu003dTrue)

tese_set u003d torchvision.datasets.CIFAR10(rootu003d"./dataset",trainu003dFalse,downloadu003dTrue)

直接,但是注意这里得到的数据集不是张量类型的,我们需要进行类型转换

从 torchvision 导入转换

trans u003d transforms.Compose([transforms.ToTensor()])

数据集 u003d torchvision.datasets.CIFAR10(rootu003d"./dataset",trainu003dFalse,transformu003dtrans,downloadu003dTrue)

数据加载

然后我们加载数据

这里用到utils下的工具

从 torch.utils.data 导入 DataLoader

从 torchvision 导入转换

trans u003d transforms.Compose([transforms.ToTensor()])

数据集 u003d torchvision.datasets.CIFAR10(rootu003d"./dataset",trainu003dFalse,transformu003dtrans,downloadu003dTrue)

数据加载器 u003d 数据加载器(数据集,批次_sizeu003d64)

这里主要介绍DataLoader的一些参数。

自定义获取数据

这是更原始的,即有时我们需要自己加载数据集,例如。

这是从网上下载的数据集。现在我们需要将它导入到我们的 pytorch 中。

此文件夹是标签名称,位于此数据集中。 1是1元的图片,100是100元的图片。

我这里直接给代码

从 torch.utils.data 导入数据集、数据加载器

从 torchvision 导入转换

导入我们

从 PIL 导入图像

通过Dataset获取数据

类我的数据集(数据集):

def __init__(self,RootDir,LabelDir):

self.RootDir u003d RootDir

self.LabelDir u003d 标签目录

self.transform u003d transforms.ToTensor()

self.ImagePathDir u003d os.path.join(self.RootDir,self.LabelDir)

self.ImageNameItems u003d os.listdir(self.ImagePathDir)

def __getitem__(自我,项目):

#item 是获取一个数据元素,懒惰模式。如果你想用我给你

项目名称 u003d self.ImageNameItems[项目]

ImagePathItem u003d os.path.join(self.RootDir,self.LabelDir,ItemName)

ItemGet u003d self.transform(Image.open(ImagePathItem).resize((500,500)))

ItemLabel u003d self.LabelDir

返回项目获取、项目标签

定义__len__(自我):

返回 len(self.ImageNameItems)

如果 __name__ u003du003d"__main__":

RootDir u003d "火车"

OneYuanLabel u003d "1"

HandoneYuanLabel u003d "100"

OneYuanData u003d MyDataset(RootDir,OneYuanLabel)

HandoneData u003d MyDataset(RootDir,HandoneYuanLabel)

DataGet u003d OneYuanData+HandoneData

train_data u003d DataLoader(datasetu003dDataGet,batch_sizeu003d18,shuffleu003dTrue,num_workersu003d0,drop_lastu003dTrue)

对于 train_data 中的数据:

imgs,标签 u003d 数据

打印(imgs.shape)

重点是我们继承Dataset,然后实现__ getitem()__ 这个神奇的方法。代码其实很简单。当我们得到我们路径的图片名,然后调用魔术方法,我们读取图片,直接转换成张量。其实这和前面得到的数据差不多,只是我们直接转换了一下。同时,这也是为什么我们使用 DataLoader 取出数据,而不是等待训练模型,它非常慢。

总结

这些都是最基本的操作,所以明天我们就来聊聊如何玩转神经网络和使用pytorch。这里,我们以CNN为例,构建CIFAR10模型。稍后我们将做一个小演示。

其实python的使用很简单,但是有很多前提条件。否则,很难理解。不像python的django和java的ssm springcloud等业务框架,背几个API和注解就可以了。很容易上手。当然,源代码不同。

点击阅读全文
Logo

学AI,认准AI Studio!GPU算力,限时免费领,邀请好友解锁更多惊喜福利 >>>

更多推荐