PyTorch + VGG16 convolutional autoencoder (used to train an image generator)

Supports resuming training from a checkpoint

Supports a dynamic learning rate
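
A minimal sketch of one way to realize the dynamic learning rate, assuming PyTorch's built-in ReduceLROnPlateau scheduler rather than the manual loss-threshold scheme that is commented out inside the training loop further down; the model, optimizer and data here are placeholders, not the autoencoder itself.

import torch

# placeholder model/optimizer standing in for cnn_encoder/cnn_decoder below
model = torch.nn.Linear(10, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=0.00005)
loss_func = torch.nn.MSELoss()

# halve the learning rate when the monitored loss has not improved for 10 epochs
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                       factor=0.5, patience=10)

for epoch in range(100):
    x = torch.randn(8, 10)              # dummy batch
    optimizer.zero_grad()
    loss = loss_func(model(x), x)
    loss.backward()
    optimizer.step()
    scheduler.step(loss.item())         # the scheduler monitors the loss
    print(epoch, optimizer.param_groups[0]['lr'])

With this approach the optimizer keeps a single base learning rate and the scheduler lowers it whenever the monitored loss plateaus, which avoids hand-picking thresholds as in the manual version.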

# -*- coding: utf-8 -*-
"""
Created on Wed Apr 21 15:21:29 2021

@author: HUANGYANGLAI
"""
from time import *
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import torch.utils.data as data
import torchvision
from torch.autograd import Variable
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.metrics import accuracy_score
import pickle
import cv2
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import itertools


CNN_embed_dim = 512      # latent dim extracted by 2D CNN
img_x, img_y = 224, 224  # resized 2D frame size (change here for a different image size)
dropout_p = 0        # dropout probability

# training parameters
k = 2         # not used here
epochs = 5000   # number of training epochs
batchsize = 8  # batch size

learning_rate = 0.00005  # can also be adjusted dynamically (see the commented-out block in the training loop)

log_interval = 10
flag=True   # init_weights flag: apply Kaiming/normal initialization as in torchvision's VGG (not pretrained weights)

resume=True  # whether to resume training from the last checkpoint
############################################################################### Custom dataset loading ###############################
root='D://sign_first//video//'
print(root)

# define how a single image file is read
def default_loader(path):
    return Image.open(path).convert('RGB')  # the path must point to a single image, not a directory

# Custom dataset class MyDataset, inheriting from torch.utils.data.Dataset
class MyDataset(Dataset):
    # __init__() stores the annotation list and the transforms to apply
    def __init__(self, txt, imgs=None, transform=None, target_transform=None, loader=default_loader):
        super(MyDataset, self).__init__()
        fh = open(txt, 'r')                # open the annotation text file read-only
        imgs = []
        for line in fh:                    # iterate over the lines of the text file
            line = line.rstrip('\n')       # strip the trailing newline
            words = line.split(',')        # each line is comma-separated
            imgs.append([words[0], int(words[1])])  # words[0]: image path, words[1]: label
        fh.close()
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    # __getitem__() loads and preprocesses one sample by index
    def __getitem__(self, index):
        fn, label = self.imgs[index]
        # fn is the image path, label the class label, i.e. words[0] and words[1] of that line
        img = self.loader(fn)              # read the image from disk
        if self.transform is not None:
            img = self.transform(img)      # convert the image into a tensor
        return img, label  # whatever is returned here is what each training batch yields

    def __len__(self):
        # must be implemented; returns the number of images in the dataset
        # (not the number of batches produced by the DataLoader)
        return len(self.imgs)
     

transform = transforms.Compose([transforms.Resize([img_x, img_y]),  # resize to (img_x, img_y)
                                torchvision.transforms.Grayscale(num_output_channels=1),
                                transforms.ToTensor(),
                                #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),

                                ])
       
train_data=MyDataset(txt=root+'train.txt', transform=transform)
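# Expected train.txt format (an assumption inferred from the parsing in
# MyDataset.__init__): one "image_path,integer_label" pair per line, e.g.
#   D://sign_first//video//frames//img_0001.jpg,0
#   D://sign_first//video//frames//img_0002.jpg,1
# The labels are not used by the reconstruction loss below; only the images are.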
# the dataset itself is ready
# wrap it in a DataLoader
data_loader = torch.utils.data.DataLoader(train_data, batch_size=batchsize, shuffle=True, num_workers=3) 
############################################################################ End of dataset and DataLoader definition ###############################

class VGG16(nn.Module):
    def __init__(self,init_weights=True):
        super(VGG16, self).__init__()
        
        # input: 1 * 224 * 224 (grayscale after the transform above)
        self.conv1_1 = nn.Conv2d(1, 64, 3) # 64 * 222 * 222
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=(1, 1)) # 64 * 222* 222
        self.maxpool1 = nn.MaxPool2d((2, 2), padding=(1, 1),return_indices=True) # pooling 64 * 112 * 112
        
        self.conv2_1 = nn.Conv2d(64, 128, 3) # 128 * 110 * 110
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=(1, 1)) # 128 * 110 * 110
        self.maxpool2 = nn.MaxPool2d((2, 2), padding=(1, 1),return_indices=True) # pooling 128 * 56 * 56
        
        self.conv3_1 = nn.Conv2d(128, 256, 3) # 256 * 54 * 54
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=(1, 1)) # 256 * 54 * 54
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=(1, 1)) # 256 * 54 * 54
        self.maxpool3 = nn.MaxPool2d((2, 2), padding=(1, 1),return_indices=True) # pooling 256 * 28 * 28
        
        self.conv4_1 = nn.Conv2d(256, 512, 3) # 512 * 26 * 26
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=(1, 1)) # 512 * 26 * 26
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=(1, 1)) # 512 * 26 * 26
        self.maxpool4 = nn.MaxPool2d((2, 2), padding=(1, 1),return_indices=True) # pooling 512 * 14 * 14
        
        self.conv5_1 = nn.Conv2d(512, 512, 3) # 512 * 12 * 12
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=(1, 1)) # 512 * 12 * 12
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=(1, 1)) # 512 * 12 * 12
        self.maxpool5 = nn.MaxPool2d((2, 2), padding=(1, 1),return_indices=True) # pooling 512 * 7 * 7
        
        # view
        self.fc1 =nn.Sequential(nn.Linear(512 * 7 * 7, 250),
                                nn.ReLU(inplace=True),
                                nn.Linear(250, 250),
                                nn.ReLU(inplace=True),
                                )
        
        # self.fc1 = nn.Linear(512 * 7 * 7, 2)
        # self.fc2 = nn.Linear(2, )
        # self.fc3 = nn.Linear(4096, 1000)
        
        if init_weights:
            self._initialize_weights()
            
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
                
                
    def forward(self, x):
        
        # x.size(0) is the batch size
        in_size = x.size(0)
        
        out = self.conv1_1(x) # 222
        #print("编码器第一次卷积1",out.size())
        out = F.relu(out)
        out = self.conv1_2(out) # 222
        #print("编码器第一次卷积2",out.size())
        out = F.relu(out)
        out,indices1 = self.maxpool1(out) # 112
        #print("编码器第一次池华",out.size())
        
        out = self.conv2_1(out) # 110
        out = F.relu(out)
        out = self.conv2_2(out) # 110
        out = F.relu(out)
       # print("编码器第二次卷积2",out.size())
        out ,indices2= self.maxpool2(out) # 56
        #print("编码器第二次池华",out.size())
        
        out = self.conv3_1(out) # 54
        out = F.relu(out)
        out = self.conv3_2(out) # 54
        out = F.relu(out)
        out = self.conv3_3(out) # 54
        out = F.relu(out)
        out ,indices3= self.maxpool3(out) # 28
        #print("编码器第三次池华",out.size())
        
        out = self.conv4_1(out) # 26
        out = F.relu(out)
        out = self.conv4_2(out) # 26
        out = F.relu(out)
        out = self.conv4_3(out) # 26
        out = F.relu(out)
        out ,indices4= self.maxpool4(out) # 14
        #print("编码器第四次池华",out.size())
        
        out = self.conv5_1(out) # 12
        out = F.relu(out)
        out = self.conv5_2(out) # 12
        out = F.relu(out)
        out = self.conv5_3(out) # 12
        #print("编码器第五次卷积3",out.size())
        out = F.relu(out)
        out,indices5 = self.maxpool5(out) # 7
        print("编码器第五次池华",out.size())
       # print("indices5",indices5.size())
        # flatten
        #out = F.dropout(out, p=0.5, training=self.training)
        #out = F.relu(out)
        
        out = out.view(in_size, -1)
        
        out = self.fc1(out)

        
        return out,indices1,indices2,indices3,indices4,indices5
    
class DecoderCNN1(nn.Module):
    def __init__(self):
        super(DecoderCNN1,self).__init__()
        self.fc=nn.Sequential(
            nn.Linear(250,512 * 7 * 7),
            nn.ReLU(inplace=True),
            )
        ####### unpool 1
        self.unpool1= nn.MaxUnpool2d(2, stride=2, padding=1)
        # deconv block 1, layer 1
        self.convtran1=nn.Sequential(
            nn.ConvTranspose2d(512,512,3,1,1,0),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            )
        # deconv block 1, layer 2
        self.convtran12=nn.Sequential(
            nn.ConvTranspose2d(512,512,3,1,1,0),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            )
        # deconv block 1, layer 3
        self.convtran13=nn.Sequential(
            nn.ConvTranspose2d(512,512,3,1,0,0),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            )
        ####### unpool 2
        self.unpool2= nn.MaxUnpool2d(2, stride=2, padding=1)
        # deconv block 2, layer 1
        self.convtran2=nn.Sequential(
            nn.ConvTranspose2d(512,512,3,1,1,0),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            )
        # deconv block 2, layer 2
        self.convtran22=nn.Sequential(
            nn.ConvTranspose2d(512,512,3,1,1,0),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            )
        # deconv block 2, layer 3
        self.convtran23=nn.Sequential(
            nn.ConvTranspose2d(512,256,3,1,0,0),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            )
        ####### unpool 3
        self.unpool3= nn.MaxUnpool2d(2, stride=2, padding=1)
        # deconv block 3, layer 1
        self.convtran3=nn.Sequential(
            nn.ConvTranspose2d(256,256,3,1,1,0),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            )
        # deconv block 3, layer 2
        self.convtran32=nn.Sequential(
            nn.ConvTranspose2d(256,256,3,1,1,0),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            )
        # deconv block 3, layer 3
        self.convtran33=nn.Sequential(
            nn.ConvTranspose2d(256,128,3,1,0,0),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            )
        ####### unpool 4
        self.unpool4= nn.MaxUnpool2d(2, stride=2, padding=1)
        # deconv block 4, layer 1
        self.convtran4=nn.Sequential(
            nn.ConvTranspose2d(128,128,3,1,1,0),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            )
        # deconv block 4, layer 2
        self.convtran42=nn.Sequential(
            nn.ConvTranspose2d(128,64,3,1,0,0),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            )
        ####### unpool 5
        self.unpool5= nn.MaxUnpool2d(2, stride=2, padding=1)
        # deconv block 5, layer 1
        self.convtran5=nn.Sequential(
            nn.ConvTranspose2d(64,64,3,1,1,0),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            )
        # deconv block 5, layer 2
        self.convtran52=nn.Sequential(
            nn.ConvTranspose2d(64,1,3,1,0,0),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True),
            )
    def forward(self, x,indices1,indices2,indices3,indices4,indices5):
        in_size = x.size(0)
        x=self.fc(x)
        x=x.view(in_size,512,7,7)
        #print("解码器输入",x.size())
        out=self.unpool1(x,indices5)
        #print("第一次最大池华",out.size())
        out=self.convtran1(out)
        out=self.convtran12(out)
        out=self.convtran13(out)
        
        out=self.unpool2(out,indices4)
        out=self.convtran2(out)
        out=self.convtran22(out)
        out=self.convtran23(out)
        
        out=self.unpool3(out,indices3)
        out=self.convtran3(out)
        out=self.convtran32(out)
        out=self.convtran33(out)
        
        out=self.unpool4(out,indices2)
        out=self.convtran4(out)
        out=self.convtran42(out)
        
        out=self.unpool5(out,indices1)
        out=self.convtran5(out)
        out=self.convtran52(out)
        
        return out
    
############################################################### Device, models, optimizer, and loss #############################################
use_cuda = torch.cuda.is_available()                   
device = torch.device("cuda" if use_cuda else "cpu")   

cnn_encoder = VGG16(init_weights=flag).to(device)
cnn_decoder=DecoderCNN1().to(device)

#c_params=list(cnn_encoder.parameters())+list(cnn_decoder.parameters())

optimizer = torch.optim.Adam(itertools.chain(cnn_encoder.parameters(),cnn_decoder.parameters()), lr=learning_rate)  # jointly optimize the CNN encoder and CNN decoder parameters
loss_func=torch.nn.MSELoss()    

   
if __name__=="__main__":
    
    initepoch = 0   # start from epoch 0 unless a checkpoint is restored below
    if resume:
        # restore the previous training state
        print("Resume from checkpoint...")
        checkpoint=torch.load('D:/sign_first/video_code/checkpoint/check_point.pth')
        cnn_encoder.load_state_dict(checkpoint['model1_state_dict'])
        cnn_decoder.load_state_dict(checkpoint['model2_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        initepoch=checkpoint['epoch']+1
        print("Resuming from epoch", initepoch)


    for epoch in range(initepoch,epochs):
        print("epoch:",epoch)
        
        for step,(b_x,b_y) in enumerate(data_loader):
            print("step:",step)
            b_x = b_x.to(device)   # move the batch onto the same device as the models
            #print("b_x",b_x.size())
            optimizer.zero_grad()  # the usual three steps: zero_grad / backward / step
            out,indices1,indices2,indices3,indices4,indices5=cnn_encoder(b_x)
            weneed=cnn_decoder(out,indices1,indices2,indices3,indices4,indices5)
            #print("weneed",weneed.size())

            loss=loss_func(weneed,b_x)
            print("loss",loss.item())
            loss.backward()
            optimizer.step()

        '''
        # optional manual dynamic learning rate: pick the LR from the current loss
        for p in optimizer.param_groups:
            print("#######################################")
            #p['lr'] *= 0.9
            if(loss>0.1):
                p['lr'] =0.1
            elif((loss>0.01)and(loss<0.09)):
                p['lr'] =0.01
            elif((loss>0.001)and(loss<0.009)):
                p['lr'] =0.001
            elif((loss>0.0001)and(loss<0.0009)):
                p['lr'] =0.0001
            elif((loss>0.00001)and(loss<0.00009)):
                p['lr'] =0.00005
            # elif((loss<0.005)&(loss>0.00001)):
            #     p['lr'] =0.00009

            print("current learning rate:",optimizer.state_dict()['param_groups'][0]['lr'])
        '''
        
        
        # save a checkpoint at the end of every epoch so training can be resumed
        checkpoint={
            'epoch':epoch,
            'model1_state_dict':cnn_encoder.state_dict(),
            'model2_state_dict':cnn_decoder.state_dict(),
            'optimizer_state_dict':optimizer.state_dict(),
            }
        if not os.path.isdir('D:/sign_first/video_code/checkpoint'):
            os.mkdir('D:/sign_first/video_code/checkpoint')
        torch.save(checkpoint,'D:/sign_first/video_code/checkpoint/check_point.pth')
        
        
        
        # also save the full encoder/decoder modules every epoch
        torch.save(cnn_encoder,'encodernet1.pkl')
        torch.save(cnn_decoder,'decodernet2.pkl')
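
For reference, a hedged usage sketch of the trained generator: it assumes the two .pkl files saved above exist and that the VGG16 / DecoderCNN1 class definitions from this script are importable when loading (torch.save stored the full modules); the image path is purely hypothetical.

import torch
from PIL import Image
import torchvision.transforms as transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torch.save() above stored the full modules, so torch.load() returns ready-to-use
# models (on recent PyTorch versions you may need weights_only=False)
encoder = torch.load('encodernet1.pkl', map_location=device).eval()
decoder = torch.load('decodernet2.pkl', map_location=device).eval()

# the same preprocessing as during training
transform = transforms.Compose([transforms.Resize([224, 224]),
                                transforms.Grayscale(num_output_channels=1),
                                transforms.ToTensor()])

img = transform(Image.open('some_frame.jpg').convert('RGB')).unsqueeze(0).to(device)  # hypothetical image

with torch.no_grad():
    code, i1, i2, i3, i4, i5 = encoder(img)
    recon = decoder(code, i1, i2, i3, i4, i5)
print(recon.size())  # expected: torch.Size([1, 1, 224, 224])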
        
        
