pytorch之dataloader,enumerate
·
pytorch之dataloader,enumerate
from torch.utils.data import TensorDataset
import torch
from torch.utils.data import DataLoader
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3],
[4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
train_ids = TensorDataset(a, b)#封装数据a与标签b
# 切片输出
print(train_ids[0:2])
print('='* 80)
# 循环取数据
for x_train, y_label in train_ids:
print(x_train, y_label)
# DataLoader进行数据封装
print('=' * 80)
train_loader = DataLoader(dataset=train_ids, batch_size=4, shuffle=True)
for i, data in enumerate(train_loader):
# 注意enumerate返回值有两个,一个是序号,一个是数据(包含训练数据和标签)
x_data, label = data
print(' batch:{0}\n x_data:{1}\nlabel: {2}'.format(i, x_data, label))
for i, data in enumerate(train_loader,5):
# 注意enumerate返回值有两个,一个是序号,一个是数据(包含训练数据和标签)
x_data, label = data
print(' batch:{0}\n x_data:{1}\nlabel: {2}'.format(i, x_data, label))
Dataloader:传入数据(这个数据包括:训练数据和标签),
batchsize代表的是每次取出4个样本数据。本例题中一共12个样本,因此迭代3次即可全部取出,迭代结束。
enumerate:返回值有两个:一个是序号,一个是数据train_ids
输出结果如下图:
也可如下代码,进行迭代:
for i, data in enumerate(train_loader,5):
# 注意enumerate返回值有两个,一个是序号,一个是数据(包含训练数据和标签)
x_data, label = data
print(' batch:{0}\n x_data:{1}\nlabel: {2}'.format(i, x_data, label))
for i, data in enumerate(train_loader,1):此代码中5,是batch从5开始,batch仍然是3个。运行结果如下:
推荐内容
更多推荐
相关推荐
查看更多
ai-hedge-fund

AI 对冲基金原理验证项目,多智能体协作模拟交易决策,用于教育目的
fastapi_mcp

一种零配置工具,用于自动将 FastAPI 端点公开为模型上下文协议 (MCP) 工具。
fumadocs

用于在 Next.js 中构建文档网站的框架。
热门开源项目
活动日历
查看更多
直播时间 2025-04-09 14:34:18

樱花限定季|G-Star校园行&华中师范大学专场
直播时间 2025-04-07 14:51:20

樱花限定季|G-Star校园行&华中农业大学专场
直播时间 2025-03-26 14:30:09

开源工业物联实战!
直播时间 2025-03-25 14:30:17

Heygem.ai数字人超4000颗星火燎原!
直播时间 2025-03-13 18:32:35

全栈自研企业级AI平台:Java核心技术×私有化部署实战
所有评论(0)