(一)HDF与h5文件,

基础知识,如何理解h5文件,以及如何读写链接
attetion:

  1. h5文件无法直接存储str型数据,需要先进行类型转换——用data.encode()进行类型转换
  2. 一个h5文件被像linux文件系统一样被组织起来:dataset是文件,group是文件夹,它下面可以包含多个文件夹(group)和多个文件(dataset)。
  3. dataset :简单来讲类似数组组织形式的数据集合,像 numpy 数组一样工作,一个dataset即一个numpy.ndarray。具体的dataset可以是图像、表格,甚至是pdf文件和excel。
  4. group:包含了其它 dataset(数组) 和 其它 group ,像字典一样工作。
    在这里插入图片描述

(二)创建自己的训练集

step1—建立h5格式数据集

1,导入相应库

import numpy as np
import h5py

2,将数据处理为需要的模式

def read_picture(path,n_C):
    import os
    from PIL import Image
    import numpy as np
    import matplotlib.pyplot as plt
    #function:读取path路径下的图片,并转为形状为[m,n_H,n_W,n_C]的数组
    #path:str,图片所在路径
    #n_C:int,图像维数,黑白图像输入1,rgb图像输入3
    #datas:返回维度为(m,n_H,n_W,n_C)的array(数组)矩阵
    datas=[]
    x_dirs=os.listdir(path)
    for x_file in x_dirs:
        fpath=os.path.join(path,x_file)
        if n_C == 1 :
            _x=Image.open(fpath).convert("L")
            #plt.imshow(_x,"gray")   #显示图像(只显示最后一张)
        elif n_C ==3:
            _x=Image.open(fpath)
            #plt.imshow(_x)         #显示图像(只显示最后一张)
        else:
            print("错误:图像维数错误")
        n_W=_x.size[0]
        n_H=_x.size[1]
        #若要对图像进行放大缩小,激活(去掉注释)以下函数
        '''
        rat=0.4          #放大/缩小倍数
        n_W=int(rat*n_W)
        n_H=int(rat*n_H)
        _x=_x.resize((n_W,n_H))  #直接给n_W,n_H赋值可将图像变为任意大小
        '''
        datas.append(np.array(_x))
        _x.close()  
    datas=np.array(datas)
    
    m=datas.shape[0]
    datas=datas.reshape((m,n_H,n_W,n_C))
    #print(datas.shape)
    
    return datas
def read_txt(path,pass_n,model=0):
    import os
    import numpy as np
    #function:读取文件夹内txt文件
    #path:str,txt文件所在文件夹
    #pass_n:int,从txt文件中第pass_n行开始读文件
    #model:int,两个模式,model=0,不按行列顺序存储数据,model=1,按行列顺序存储数据
    
    #datas:array,输出txt内数据到数组
    datas=[]
    #分隔符
    sym1="\n" 
    sym2=","      #此处为txt文件中每个数据的分隔符,可根据需要自己修改
    x_dirs=os.listdir(path)
    for x_file in x_dirs:
        f_path=os.path.join(path,x_file)
        x_text=open(f_path)
        for i in range(pass_n):
            next(x_text)
        
        if model == 0:
            #按txt文件中数据不分行列存储,形状为(m,n)
            #m为文件数,n为每个txt内的数据量
            x_data=x_text.read()
            x_text.close()
            #去除末尾空格与‘\n’
            x_data=x_data.strip()
            x_data=x_data.replace(sym1,sym2)
            #按分隔符划分数据
            x_data=x_data.split(sym2)
            #x_data=np.array(x_data)
            for i in range(len(x_data)):
                x_data[i]=x_data[i].encode()#之后要将其存为h5文件需要用byte编码
                
        elif model == 1:
        
            #txt文件中数据按行存储,形状为(m,n_r,n_c)
            #m为txt文件个数。每个txt文件中有n_r行,n_c列数据
            x_data=x_text.readlines()
            x_text.close()
            #对同一行数据进行处理
            for i in range(len(x_data)):
                #print(type(x_data[i]))
                x_data[i]=x_data[i].strip()
                x_data[i]=x_data[i].replace(sym1,sym2)
                #按分隔符划分数据
                x_data[i]=x_data[i].split(sym2)
        else:
            print("请输入正确的模式:")
            print("model=0:按文件排列")
            print("model=1:按行列排列")
            
        datas.append(x_data)
    #将datas转化为数组
    datas=np.array(datas)
    #print(datas.shape)
    
    return datas
#读取数据
datas=read_picture("0_Pic",1)
labels=read_txt("1_Label",1)
assert(datas.shape[0] == labels.shape[0])
print(datas.shape)
print(labels.shape)
#对数据进行随机处理
index=np.arange(datas.shape[0])
np.random.shuffle(index)
x_data=datas[index,:]
y_data=labels[index,:]

test_range=np.arange(int(datas.shape[0]*0.2))
train_range=np.arange(test_len,datas.shape[0])

x_train=x_data[train_range,:]
y_train=y_data[train_range,:]
x_test=x_data[test_range,:]
y_test=y_data[test_range,:]
print(x_train.shape)
print(type(x_train[1][1][1][0]))
print(type(y_train[1][0]))
print(type(x_train))
print(type(y_train))

#查看图像
import cv2
index=8
data_tmp=np.squeeze(x_test)
img=cv2.cvtColor(np.asarray(data_tmp[index]),cv2.COLOR_RGB2BGR)

point=y_test
point_color=(0,0,255)#BGR
cv2.circle(img,(int(float(point[index][0])),int(float(point[index][1]))),4,point_color,0)
cv2.circle(img,(int(float(point[index][2])),int(float(point[index][3]))),4,point_color,0)
cv2.circle(img,(int(float(point[index][4])),int(float(point[index][5]))),4,point_color,0)
cv2.circle(img,(int(float(point[index][6])),int(float(point[index][7]))),4,point_color,0)
cv2.circle(img,(int(float(point[index][8])),int(float(point[index][9]))),4,point_color,0)

cv2.imshow(str(index),img)
cv2.waitKey(0)
print("显示完成")

3,创建一个h5文件,并将排序好的数据写入

#创建一个新文件 create a new file
f=h5py.File('data.h5','w')
f.create_dataset('train_set_x',data=x_train)
f.create_dataset('train_set_y',data=y_train)
f.create_dataset('test_set_x',data=x_test)
f.create_dataset('test_set_y',data=y_test)
f.close()

这样就成功生成了一个数据集

step2—读取h5文件内数据

import numpy as np
import h5py
def load_dataset():
    train_dataset=h5py.File('data.h5','r')
    train_set_x_orig=np.array(train_dataset["train_set_x"][:])
    train_set_y_orig=np.array(train_dataset["train_set_y"][:])
    
    test_set_x_orig=np.array(train_dataset["test_set_x"][:])
    test_set_y_orig=np.array(train_dataset["test_set_y"][:])
    return train_set_x_orig, train_set_y_orig, test_set_x_orig, test_set_y_orig

x_train,y_train,x_test,y_test=load_dataset()

数据集建立完,可导入结构内进行训练了

Logo

更多推荐