## python
import h5py
import random
import torch
import numpy as np
import pickle
## pytorch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
import torchvision.transforms.functional as TF
from torch import nn
from tensorflow.keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
def load_image(image_file):
    """Load the ``'img'`` dataset of an HDF5 file as a float32 array.

    Args:
        image_file: Path to the ``.h5`` file.

    Returns:
        The stored array divided by 255 and cast to ``float32``.
    """
    # Context manager closes the HDF5 handle even if reading fails; the
    # previous version left the file open (one leaked handle per sample).
    with h5py.File(image_file, 'r') as f:
        img_np = f['img'][()]
    return (img_np / 255.0).astype('float32')

def load_mask(image_path, img_id, attribute='pigment_network'):
    """Load an attribute mask for one sample from HDF5.

    The file read is ``image_path + '<img_id>_attribute_<attribute>.h5'``.
    For ``attribute == 'all'`` this is ``<img_id>_attribute_all.h5``.
    ``image_path`` is concatenated directly, so it should end with a path
    separator.

    Args:
        image_path: Directory prefix holding the mask files.
        img_id: Sample identifier used in the file name.
        attribute: Attribute name, or ``'all'``.

    Returns:
        The ``'img'`` dataset of the file, cast to ``uint8``.
    """
    # Bug fix: the non-'all' branch previously formatted the undefined name
    # ``mask_attr`` (NameError). Using ``attribute`` for both cases yields
    # exactly the same file name the 'all' branch built by hand.
    mask_file = image_path + '%s_attribute_%s.h5' % (img_id, attribute)
    # Close the HDF5 handle deterministically instead of leaking it.
    with h5py.File(mask_file, 'r') as f:
        mask_np = f['img'][()]
    return mask_np.astype('uint8')
class SkinDataset(Dataset):
    """PyTorch dataset of skin-lesion images and attribute masks stored as HDF5.

    Args:
        train_test_id: Table of all sample ids. It must expose a ``'Split'``
            column and an ``ID`` column. Rows whose Split is ``'train'`` form
            the training set; all other rows form the test set.
        image_path: Directory prefix containing every ``.h5`` file. It is
            concatenated directly with file names, so it should end with a
            path separator.
        train_test_split_file: Path to a pickle file. The unpickled object is
            stored as ``self.mask_ind`` and indexed with
            ``.loc[index, attr_types]`` in ``__getitem__``. It is presumably a
            pandas DataFrame with one column per attribute type (TODO confirm).
        train: If True, use the training split and apply random augmentation
            in ``__getitem__``; otherwise use the non-training split.
        attribute: Mask selector forwarded to ``load_mask`` (``'all'`` or a
            single attribute name).
        transform: Stored on the instance but not used by this class.
        num_classes: ``1`` selects the single-mask augmentation path in
            ``transform_fn``. Any other value augments each slice along the
            mask's last axis separately.
    """

    def __init__(self, train_test_id, image_path, train_test_split_file, train=True,
                 attribute=None, transform=None, num_classes=None):
        self.train_test_id = train_test_id
        self.image_path = image_path
        self.attribute = attribute
        self.attr_types = ['pigment_network', 'negative_network', 'streaks',
                           'milia_like_cyst', 'globules']
        self.train = train
        self.transform = transform
        self.num_classes = num_classes

        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        with open(train_test_split_file, 'rb') as f:
            self.mask_ind = pickle.load(f)

        # Split the samples into training and test sets; keep only the ids.
        if self.train:
            in_split = self.train_test_id['Split'] == 'train'
        else:
            in_split = self.train_test_id['Split'] != 'train'
        self.train_test_id = self.train_test_id[in_split].ID.values
        print('Train =', self.train, 'train_test_id.shape: ', self.train_test_id.shape)
        self.n = self.train_test_id.shape[0]

    def __len__(self):
        """Return the number of samples in the selected split."""
        # Bug fix: the signature previously lacked ``self``, so len() raised TypeError.
        return self.n

    def transform_fn(self, image, mask):
        """Randomly augment an image and its mask(s) with shared geometric params.

        Both inputs are converted to PIL images with keras ``array_to_img``
        (channels last). They are flipped and affine-warped with the same
        random parameters. Only the image gets random brightness and
        saturation jitter. Both are then converted back with ``img_to_array``.

        Args:
            image: Image array, channels last.
            mask: Mask array, channels last. When ``num_classes != 1``, each
                slice along the last axis is augmented as its own
                single-channel image and written back into ``mask`` in place.

        Returns:
            ``(image, mask)``: the image divided by 255 as ``float32``, and the
            mask divided by 255 and cast to ``uint8``.
        """
        if self.num_classes == 1:
            image = array_to_img(image, data_format='channels_last')
            mask = array_to_img(mask, data_format='channels_last')
            # Random horizontal flip.
            if random.random() > 0.5:
                image = TF.hflip(image)
                mask = TF.hflip(mask)
            # Random vertical flip.
            if random.random() > 0.5:
                image = TF.vflip(image)
                mask = TF.vflip(mask)
            # Random rotation, translation, scaling and shear (same params for both).
            angle = random.randint(0, 90)
            translate = (random.uniform(0, 100), random.uniform(0, 100))
            scale = random.uniform(0.5, 2)
            shear = random.uniform(-10, 10)
            image = TF.affine(image, angle, translate, scale, shear)
            mask = TF.affine(mask, angle, translate, scale, shear)
            # Random brightness / saturation jitter (image only).
            # Bug fix: adjust_brightness takes ``brightness_factor``; passing
            # ``saturation_factor`` raised TypeError.
            image = TF.adjust_brightness(image, brightness_factor=random.uniform(0.8, 1.2))
            image = TF.adjust_saturation(image, saturation_factor=random.uniform(0.8, 1.2))
            # Additional random rotation applied to both.
            angle = random.randint(0, 90)
            image = TF.rotate(image, angle)
            mask = TF.rotate(mask, angle)
            # Back to numpy arrays.
            image = img_to_array(image, data_format='channels_last')
            mask = img_to_array(mask, data_format='channels_last')
        else:
            image = array_to_img(image, data_format='channels_last')
            # Wrap each mask channel as its own single-channel PIL image.
            mask_pils = [array_to_img(mask[:, :, i, np.newaxis], data_format='channels_last')
                         for i in range(mask.shape[-1])]
            # Random horizontal flip.
            if random.random() > 0.5:
                image = TF.hflip(image)
                mask_pils = [TF.hflip(m) for m in mask_pils]
            # Random vertical flip.
            if random.random() > 0.5:
                image = TF.vflip(image)
                mask_pils = [TF.vflip(m) for m in mask_pils]
            # Random rotation, translation and scaling (same params for all).
            angle = random.randint(0, 90)
            translate = (random.uniform(0, 100), random.uniform(0, 100))
            scale = random.uniform(0.5, 2)
            # Always 0.0 (no shear on this path); still consumes one RNG draw.
            shear = random.uniform(0, 0)
            image = TF.affine(image, angle, translate, scale, shear)
            mask_pils = [TF.affine(m, angle, translate, scale, shear) for m in mask_pils]
            # Random brightness / saturation jitter (image only).
            # Bug fix: ``brightness_factor`` is the correct keyword.
            image = TF.adjust_brightness(image, brightness_factor=random.uniform(0.8, 1.2))
            image = TF.adjust_saturation(image, saturation_factor=random.uniform(0.8, 1.2))
            # Back to numpy arrays; masks are written back into ``mask`` in place.
            image = img_to_array(image, data_format='channels_last')
            for i, mask_pil in enumerate(mask_pils):
                mask[:, :, i] = img_to_array(mask_pil, data_format='channels_last')[:, :, 0].astype('uint8')

        image = (image / 255).astype('float32')
        mask = (mask / 255).astype('uint8')
        return image, mask

    def __getitem__(self, index):
        """Return ``(image, mask, ind)`` for position ``index`` of this split.

        The image is read from ``image_path + '<id>.h5'`` and the mask via
        ``load_mask``. Augmentation runs only when ``train`` is True. ``ind`` is
        ``mask_ind.loc[index, attr_types]`` cast to ``uint8``.
        """
        img_id = self.train_test_id[index]
        # Bug fix: this used ``self.image_path="%s.h5" % img_id`` (assignment
        # instead of concatenation). That dropped the directory prefix and
        # overwrote ``self.image_path`` for every later call and for load_mask.
        image_file = self.image_path + "%s.h5" % img_id
        img_np = load_image(image_file)
        mask_np = load_mask(self.image_path, img_id, self.attribute)
        if self.train:
            img_np, mask_np = self.transform_fn(img_np, mask_np)

        img_np = img_np.astype('float32')
        # Look up this sample's row in the mask_ind table.
        # NOTE(review): ``index`` is a position inside the filtered split but is
        # used as a ``.loc`` label on ``mask_ind``. Unless mask_ind is indexed
        # per split, this may pick the wrong row; a lookup by ``img_id`` may be
        # what was intended. Confirm.
        ind = self.mask_ind.loc[index, self.attr_types].values.astype('uint8')
        return img_np, mask_np, ind
 
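# NOTE(review): the lines below are residue from the web page (CSDN footer)
# this file was copied from. They are not code and are commented out so the
# module parses.
# Logo
#
# CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!
#
# 更多推荐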