改动点:

(1)把传统的卷积改造成深度可分离卷积;

(2)使用PyTorch自带的CTC损失(torch.nn.CTCLoss),不再使用百度开源的warp-ctc,主要原因是本人使用Windows开发调试,而在Windows上编译warp-ctc比较麻烦;

 

crnn网络实现代码(文件开头需要先 import torch.nn as nn):

class BidirectionalLSTM(nn.Module):
    """Bidirectional LSTM followed by a linear projection applied per timestep.

    Args:
        nInput_size: feature size of every input timestep.
        nHidden: hidden size of each LSTM direction.
        nOut: feature size produced by the linear projection.
    """

    def __init__(self, nInput_size, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()

        self.lstm = nn.LSTM(nInput_size, nHidden, bidirectional=True)
        # Both directions are concatenated, hence the 2 * nHidden input width.
        self.linear = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        """Run the LSTM and project every timestep.

        Args:
            input: tensor laid out as [seq, batch, nInput_size].

        Returns:
            Tensor laid out as [seq, batch, nOut].
        """
        # The final (hidden, cell) state pair is not needed.
        features, _ = self.lstm(input)
        seq_len, batch_size, feat_dim = features.size()

        # Fold time and batch together so the linear layer sees a 2-D matrix.
        flat = features.view(seq_len * batch_size, feat_dim)
        projected = self.linear(flat)  # [seq_len * batch_size, nOut]

        # Restore the [seq, batch, nOut] layout.
        return projected.view(seq_len, batch_size, -1)



class CNN(nn.Module):
    """Convolutional feature extractor built from depthwise-separable convolutions.

    Each stage is a depthwise convolution (``groups`` equal to the channel
    count) followed by a 1x1 pointwise convolution, then optionally batch
    normalisation, a ReLU and optionally max pooling.

    Args:
        imageHeight: height of the input images; must be a multiple of 32.
        nChannel: number of channels of the input images.
    """

    def __init__(self, imageHeight, nChannel):
        super(CNN, self).__init__()
        assert imageHeight % 32 == 0, 'image Height has to be a multiple of 32'

        def depthwise(channels, kernel_size=3, padding=1):
            # One spatial filter per channel.
            return nn.Conv2d(in_channels=channels, out_channels=channels,
                             kernel_size=kernel_size, stride=1,
                             padding=padding, groups=channels)

        def pointwise(in_channels, out_channels):
            # 1x1 convolution that mixes channels.
            return nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                             kernel_size=1, stride=1, padding=0, groups=1)

        def activation():
            return nn.ReLU(inplace=True)

        # NOTE: attribute names and creation order are kept stable so that
        # existing checkpoints (state_dict keys) keep loading.

        # Stage 0: nChannel -> 64, halves height and width.
        self.depth_conv0 = depthwise(nChannel)
        self.point_conv0 = pointwise(nChannel, 64)
        self.relu0 = activation()
        self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 1: 64 -> 128, halves height and width.
        self.depth_conv1 = depthwise(64)
        self.point_conv1 = pointwise(64, 128)
        self.relu1 = activation()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2: 128 -> 256 with batch norm, no pooling.
        self.depth_conv2 = depthwise(128)
        self.point_conv2 = pointwise(128, 256)
        self.batchNorm2 = nn.BatchNorm2d(256)
        self.relu2 = activation()

        # Stage 3: 256 -> 256; pooling uses stride 2 vertically, 1 horizontally.
        self.depth_conv3 = depthwise(256)
        self.point_conv3 = pointwise(256, 256)
        self.relu3 = activation()
        self.pool3 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1))

        # Stage 4: 256 -> 512 with batch norm, no pooling.
        self.depth_conv4 = depthwise(256)
        self.point_conv4 = pointwise(256, 512)
        self.batchNorm4 = nn.BatchNorm2d(512)
        self.relu4 = activation()

        # Stage 5: 512 -> 512; same asymmetric pooling as stage 3.
        self.depth_conv5 = depthwise(512)
        self.point_conv5 = pointwise(512, 512)
        self.relu5 = activation()
        self.pool5 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1))

        # Stage 6: 2x2 unpadded depthwise convolution plus pointwise
        # (separable replacement of a plain 2x2 512->512 convolution).
        self.depth_conv6 = depthwise(512, kernel_size=2, padding=0)
        self.point_conv6 = pointwise(512, 512)
        self.batchNorm6 = nn.BatchNorm2d(512)
        self.relu6 = activation()

    def forward(self, input):
        """Return the feature map produced by the seven stages for ``input``."""
        x = self.pool0(self.relu0(self.point_conv0(self.depth_conv0(input))))
        x = self.pool1(self.relu1(self.point_conv1(self.depth_conv1(x))))
        x = self.relu2(self.batchNorm2(self.point_conv2(self.depth_conv2(x))))
        x = self.pool3(self.relu3(self.point_conv3(self.depth_conv3(x))))
        x = self.relu4(self.batchNorm4(self.point_conv4(self.depth_conv4(x))))
        x = self.pool5(self.relu5(self.point_conv5(self.depth_conv5(x))))
        x = self.relu6(self.batchNorm6(self.point_conv6(self.depth_conv6(x))))
        return x

class CRNN(nn.Module):
    """CRNN recogniser: a CNN feature extractor followed by two BiLSTM layers.

    Args:
        imgHeight: input image height, forwarded to ``CNN``.
        nChannel: number of input image channels, forwarded to ``CNN``.
        nClass: per-timestep output size of the last recurrent layer.
        nHidden: hidden size used by both ``BidirectionalLSTM`` layers.
    """

    def __init__(self, imgHeight, nChannel, nClass, nHidden):
        super(CRNN, self).__init__()

        backbone = CNN(imgHeight, nChannel)
        self.cnn = nn.Sequential(backbone)

        # 512 is the channel count produced by the last CNN stage.
        encoder = BidirectionalLSTM(512, nHidden, nHidden)
        classifier = BidirectionalLSTM(nHidden, nHidden, nClass)
        self.lstm = nn.Sequential(encoder, classifier)

    def forward(self, input):
        """Return per-timestep class scores laid out as (seq, batch, nClass)."""
        features = self.cnn(input)
        # PyTorch convolution outputs are laid out as (batch, channel, height, width).
        _, _, height, _ = features.size()
        assert height == 1, "the output height must be 1."

        # Drop the singleton height axis, (B, C, 1, W) -> (B, C, W), then move
        # width to the front, (B, C, W) -> (W, B, C), which is the
        # (seq, batch, input_size) layout the LSTM expects.
        sequence = features.squeeze(dim=2).permute(2, 0, 1)

        return self.lstm(sequence)

训练网络代码:

import os
import torch
import cv2
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader
from crnn_new import crnn
import time


class resizeAndNormalize:
    """Callable that resizes an image and converts it to a normalised tensor.

    Args:
        size: target size handed to ``cv2.resize``; OpenCV expects
            ``(width, height)``.
        interpolation: OpenCV interpolation flag used when resizing.
    """

    def __init__(self, size, interpolation=cv2.INTER_LINEAR):
        self.size, self.interpolation = size, interpolation
        # ToTensor converts a PIL Image or numpy.ndarray to a tensor.
        self.toTensor = transforms.ToTensor()

    def __call__(self, image):
        """Resize ``image`` to ``self.size`` and return it as a tensor."""
        # For OpenCV the x axis is the image width and the y axis its height.
        resized = cv2.resize(image, self.size, interpolation=self.interpolation)
        tensor = self.toTensor(resized)
        # In-place normalisation: (x - 0.5) / 0.5.
        return tensor.sub_(0.5).div_(0.5)

class CRNNDataSet(Dataset):
    """Dataset of text-line images paired with integer label sequences.

    The label file holds one sample per line: the image file name followed by
    whitespace-separated integer labels.

    Args:
        imageRoot: directory containing the image files.
        labelRoot: path of the label file described above.
    """

    def __init__(self, imageRoot, labelRoot):
        self.image_root = imageRoot
        # Maps image file name -> list of label strings.
        self.image_dict = self.readfile(labelRoot)
        self.image_name = list(self.image_dict)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is loaded as grayscale, resized to a height of 32 pixels
        with its aspect ratio preserved, and normalised by
        ``resizeAndNormalize``. The label is returned as an ``IntTensor``.

        Raises:
            OSError: if OpenCV cannot read or decode the image file.
        """
        file_name = self.image_name[index]
        image_path = os.path.join(self.image_root, file_name)
        label = [int(x) for x in self.image_dict[file_name]]

        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than on `None.shape`.
            raise OSError("failed to read image: %s" % image_path)
        (height, width) = image.shape

        # The CRNN network expects inputs of height 32, so rescale the image
        # to that height while keeping the aspect ratio.
        size_height = 32
        ratio = size_height / float(height)
        # Never ask OpenCV for a zero-width image (very tall, narrow inputs).
        size_width = max(1, int(ratio * width))
        transform = resizeAndNormalize((size_width, size_height))
        image = transform(image)
        label = torch.IntTensor(label)

        return image, label

    def __len__(self):
        """Return the number of usable samples."""
        return len(self.image_name)

    def readfile(self, fileName):
        """Parse the label file into ``{image file name: [label strings]}``.

        Blank lines are skipped and any run of whitespace separates fields.
        Entries whose image file is missing are printed and left out, so a
        missing file cannot break training later; the number of such entries
        is printed at the end.
        """
        dic = {}
        total = 0
        with open(fileName, 'r') as f:
            for line in f:
                part = line.split()
                if not part:
                    continue
                image_path = os.path.join(self.image_root, part[0])
                # isfile (not exists): a directory is not a usable sample.
                if not os.path.isfile(image_path):
                    print(image_path)
                    total += 1
                else:
                    dic[part[0]] = part[1:]
        print(total)

        return dic

# Training and validation splits of the Synthetic Chinese String Dataset.
# Backslashes in the Windows paths are escaped explicitly: sequences such as
# "\B", "\S" or "\l" are invalid escapes that newer Python versions warn about.
trainData = CRNNDataSet(imageRoot="D:\\BaiduNetdiskDownload\\Synthetic_Chinese_String_Dataset\\images\\",
                        labelRoot="D:\\BaiduNetdiskDownload\\Synthetic_Chinese_String_Dataset\\lables\\data.txt")

trainLoader = DataLoader(dataset=trainData, batch_size=30, shuffle=True, num_workers=0)

valData = CRNNDataSet(imageRoot="D:\\BaiduNetdiskDownload\\Synthetic_Chinese_String_Dataset\\images\\",
                      labelRoot="D:\\BaiduNetdiskDownload\\Synthetic_Chinese_String_Dataset\\lables\\data_t.txt")

# Validation runs one image per batch.
valLoader = DataLoader(dataset=valData, batch_size=1, shuffle=True, num_workers=1)

def decode(preds, blank=5989):
    """Greedy CTC decoding: collapse repeated labels, then drop blanks.

    Args:
        preds: 1-D sequence (list or tensor) of per-timestep class indices.
        blank: index of the CTC blank class.

    Returns:
        list of int: the decoded label indices.
    """
    values = preds.tolist() if hasattr(preds, "tolist") else list(preds)
    if not isinstance(values, list):
        # A 0-dim tensor converts to a bare scalar.
        values = [values]

    decoded = []
    # No previous symbol exists at the first timestep, so it can never be
    # treated as a repeat (the old code compared it with the last element).
    previous = None
    for value in values:
        value = int(value)
        if value != blank and value != previous:
            decoded.append(value)
        previous = value
    return decoded


def val(model, loss_function, max_iteration=None, use_gpu=True):
    """Evaluate ``model`` on ``valLoader`` and print the mean loss and accuracy.

    Accuracy is position-wise: each greedily decoded sequence is compared
    element by element with its label, and the number of matches is divided
    by the total number of decoded symbols.

    Args:
        model: network whose output is laid out as (seq, batch, n_class).
        loss_function: CTC loss taking
            ``(log_probs, targets, input_lengths, target_lengths)``.
        max_iteration: maximum number of validation batches to evaluate;
            ``None`` evaluates the whole loader.
        use_gpu: move the input batches to CUDA when it is available.
    """
    # Switch to evaluation mode.
    model.eval()

    num_batches = len(valLoader)
    if max_iteration is not None:
        num_batches = min(max_iteration, num_batches)

    total_loss = 0.0
    correct_num = 0
    total_num = 0
    val_iter = iter(valLoader)

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for _ in range(num_batches):
            # Built-in next(): DataLoader iterators no longer expose .next().
            data, label = next(val_iter)
            if torch.cuda.is_available() and use_gpu:
                data = data.cuda()

            output = model(data)
            # CTCLoss expects log-probabilities, not raw scores.
            log_probs = output.log_softmax(2)

            batch_size = label.size(0)
            # Concatenate all label rows into one 1-D target tensor and keep
            # it on the same device as the log-probabilities.
            targets = label.reshape(-1).to(log_probs.device)
            input_lengths = torch.IntTensor([output.size(0)] * int(output.size(1)))
            target_lengths = torch.IntTensor([label.size(1)] * batch_size)
            loss = loss_function(log_probs, targets, input_lengths, target_lengths) / batch_size
            total_loss += float(loss)

            # Best class per timestep, rearranged to (batch, seq).
            best_path = output.max(2)[1].transpose(1, 0).cpu()
            # Decode each sample on its own so repeats are never collapsed
            # across sample boundaries.
            for sample_pred, sample_label in zip(best_path, label):
                pred = decode(sample_pred)
                total_num += len(pred)
                correct_num += sum(
                    1 for x, y in zip(pred, sample_label) if int(x) == int(y))

    # Guard both averages: early in training every prediction may be blank,
    # and the validation loader may be empty.
    accuracy = correct_num / float(total_num) * 100 if total_num else 0.0
    test_loss = total_loss / num_batches if num_batches else 0.0
    print('Test loss : %.3f , accuary : %.3f%%' % (test_loss, accuracy))


def train():
    """Train the CRNN with CTC loss on ``trainLoader``.

    Resumes from the checkpoint at ``modelpath`` when it exists. Every
    ``printInterval`` iterations the running mean loss is printed and the
    checkpoint is saved; every ``valinterval`` iterations ``val`` is run.
    """
    use_gpu = True
    learning_rate = 0.0005
    weight_decay = 1e-4
    max_epoch = 10
    # Backslashes escaped explicitly ("\c" and "\p" are invalid escapes).
    modelpath = 'F:\\crnn_model\\pytorch-crnn.pth'
    printInterval = 100
    valinterval = 1000
    # Upper bound on the number of validation batches per validation run.
    val_max_iteration = 100

    # The first line of the character file is skipped and '卍' is appended as
    # the last class, which is used as the CTC blank below.
    with open('../train/char_std_5990.txt', 'r', encoding='utf-8') as f:
        char_lines = f.readlines()
    char_set = ''.join([ch.strip('\n') for ch in char_lines[1:]] + ['卍'])
    n_class = len(char_set)

    # is_available must be *called*; the bare function object is always truthy.
    use_cuda = use_gpu and torch.cuda.is_available()

    model = crnn.CRNN(imgHeight=32, nChannel=1, nClass=n_class, nHidden=256)
    if use_cuda:
        model.cuda()

    loss_func = torch.nn.CTCLoss(blank=n_class - 1)
    if use_cuda:
        loss_func = loss_func.cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    if os.path.exists(modelpath):
        print("load model from %s" % modelpath)
        # map_location lets a checkpoint saved from the GPU load on any
        # machine; load_state_dict copies it onto the model's device.
        model.load_state_dict(torch.load(modelpath, map_location='cpu'))
        print("done!")

    lossTotal = 0.0
    k = 0
    start_time = time.time()
    for epoch in range(max_epoch):

        for i, (data, label) in enumerate(trainLoader):

            k = k + 1
            # val() leaves the model in eval mode, so re-enable training mode.
            model.train()

            # Concatenate all label rows into one 1-D target tensor.
            labels = label.reshape(-1)

            if use_cuda:
                data = data.cuda()
                labels = labels.cuda()

            output = model(data)

            # CTCLoss expects log-probabilities. The result must stay attached
            # to the graph: detaching it would stop every gradient from
            # reaching the model, which would then never learn.
            log_probs = output.log_softmax(2)
            input_lengths = torch.IntTensor([output.size(0)] * int(output.size(1)))
            target_lengths = torch.IntTensor([label.size(1)] * int(label.size(0)))

            # forward(self, log_probs, targets, input_lengths, target_lengths)
            loss = loss_func(log_probs, labels, input_lengths, target_lengths) / label.size(0)
            lossTotal += float(loss)

            if k % printInterval == 0:

                print("[%d/%d] [%d/%d] loss:%f" % (
                    epoch, max_epoch, i + 1, len(trainLoader), lossTotal / printInterval))
                lossTotal = 0.0
                torch.save(model.state_dict(), modelpath)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if k % valinterval == 0:
                val(model, loss_func, val_max_iteration, use_gpu)

    end_time = time.time()
    print("takes {}s".format((end_time - start_time)))


# Start training only when this file is executed as a script.
if __name__ == '__main__':
    train()

测试代码:

import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '7'
import torch
from config import opt
from crnn import crnn
from PIL import Image
from torchvision import transforms

class resizeNormalize(object):
	"""Callable that resizes a PIL image and converts it to a normalised tensor.

	Args:
		size: target size handed to ``Image.resize``.
		interpolation: PIL resampling filter used when resizing.
	"""

	def __init__(self, size, interpolation=Image.BILINEAR):
		self.size, self.interpolation = size, interpolation
		self.toTensor = transforms.ToTensor()

	def __call__(self, img):
		"""Resize ``img`` to ``self.size`` and return it as a tensor."""
		resized = img.resize(self.size, self.interpolation)
		tensor = self.toTensor(resized)
		# In-place normalisation: (x - 0.5) / 0.5.
		return tensor.sub_(0.5).div_(0.5)

def decode(preds, char_set, blank=5989):
	"""Greedy CTC decoding of class indices into text.

	Repeated indices are collapsed and blanks are dropped; every remaining
	index ``k`` is mapped to ``char_set[k - 1]``.

	Args:
		preds: 1-D sequence (list or tensor) of per-timestep class indices.
		char_set: string of characters indexed as described above.
		blank: index of the CTC blank class.

	Returns:
		str: the decoded text.
	"""
	values = preds.tolist() if hasattr(preds, "tolist") else list(preds)
	if not isinstance(values, list):
		# A 0-dim tensor (single timestep after squeeze) converts to a scalar.
		values = [values]

	chars = []
	# No previous symbol exists at the first timestep, so it can never be
	# treated as a repeat (the old code compared it with the last element).
	previous = None
	for value in values:
		value = int(value)
		if value != blank and value != previous:
			chars.append(char_set[value - 1])
		previous = value

	return ''.join(chars)

# test if crnn work

if __name__ == '__main__':
	# Recognise the text in a single image with a trained CRNN checkpoint.

	imagepath = './12.jpg'

	img_h = opt.img_h
	# is_available must be *called*; the bare function object is always truthy.
	use_gpu = opt.use_gpu and torch.cuda.is_available()
	# Backslashes escaped explicitly ("\c" and "\p" are invalid escapes).
	modelpath = 'F:\\crnn_model\\pytorch-crnn-Copy68.pth'
	# The first line of the character file is skipped and '卍' is appended as
	# the last class.
	with open('char_std_5990.txt', 'r', encoding='utf-8') as f:
		char_lines = f.readlines()
	char_set = ''.join([ch.strip('\n') for ch in char_lines[1:]] + ['卍'])
	n_class = len(char_set)
	print(n_class)

	from crnn_new import crnn
	model = crnn.CRNN(img_h, 1, n_class, 256)

	if os.path.exists(modelpath):
		print('Load model from "%s" ...' % modelpath)
		# map_location lets a checkpoint saved from the GPU load on any
		# machine; load_state_dict copies it onto the model's device.
		model.load_state_dict(torch.load(modelpath, map_location='cpu'))
		print('Done!')

	if use_gpu:
		model.cuda()

	image = Image.open(imagepath).convert('L')
	(w, h) = image.size
	size_h = 32
	ratio = size_h / float(h)
	size_w = int(w * ratio)
	# keep the ratio
	transform = resizeNormalize((size_w, size_h))
	image = transform(image)
	# Add the batch dimension.
	image = image.unsqueeze(0)
	if use_gpu:
		image = image.cuda()

	model.eval()
	# No gradients are needed for inference.
	with torch.no_grad():
		preds = model(image)
	# Best class per timestep; output is (seq, batch, n_class) with batch == 1.
	preds = preds.max(2)[1]
	# Drop only the batch axis so a one-timestep output stays 1-D.
	preds = preds.squeeze(1)
	pred_text = decode(preds, char_set)
	print('predict == >', pred_text)

 

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐