这个时候测试集就要用到了,为了便于观察,我们这里先给测试集重命名,这样子,哪张图片分类错了,我们也比较好找。
重命名代码:

#!/usr/bin/python
# -*- coding:utf-8 -*-
import os

outer_path = 'D:/flowers/test'
folderlist = os.listdir(outer_path)  # 列举文件夹

for folder in folderlist:
    inner_path = os.path.join(outer_path, folder)
    total_num_folder = len(folderlist)  # 文件夹的总数
    print('total have %d folders' % (total_num_folder))

    filelist = os.listdir(inner_path)  # 列举图片
    i = 0
    for item in filelist:
        total_num_file = len(filelist)  # 单个文件夹内图片的总数
        if item.endswith('.jpg'):
            src = os.path.join(os.path.abspath(inner_path), item)  # 原图的地址
            dst = os.path.join(os.path.abspath(inner_path), str(folder) + '_' + str(
                i) + '.jpg')  # 新图的地址(这里可以把str(folder) + '_' + str(i) + '.jpg'改成你想改的名称)
            try:
                os.rename(src, dst)
                print
                'converting %s to %s ...' % (src, dst)
                i += 1
            except:
                continue
    print
    'total %d to rename & converted %d jpgs' % (total_num_file, i)

随便打开测试集的,可以看到结果如下:
在这里插入图片描述

下面说下测试代码:
测试一整个文件夹的,这里的整个文件夹并不是指test文件夹,而是test文件夹下的五个子文件夹的任意一个,代码restall.py如下:

from PIL import Image
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import model
#from input_data import get_files
import input_data
import os
import warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
warnings.filterwarnings('ignore')

#自己改最后面的sunflowers改成别的
img_dir = 'D:/test/sunflowers/'


# 测试图片
def evaluate_one_image(image_array):
    with tf.Graph().as_default():
        BATCH_SIZE = 1
        N_CLASSES = 5
        image = tf.cast(image_array, tf.float32)
        image = tf.image.per_image_standardization(image)
        image = tf.reshape(image, [1, 64, 64, 3])
        logit = model.inference(image, BATCH_SIZE, N_CLASSES,1,0)
        logit = tf.nn.softmax(logit)
        x = tf.placeholder(tf.float32, shape=[64, 64, 3])
        logs_train_dir = 'D:/biye/save/train'
        saver = tf.train.Saver()
        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(logs_train_dir)
            if ckpt and ckpt.model_checkpoint_path:
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print('No checkpoint file found')
            prediction = sess.run(logit, feed_dict={x: image_array})
            max_index = np.argmax(prediction)
            return max_index


# ------------------------------------------------------------------------

if __name__ == '__main__':
    roses=0
    tulips=0
    dandelion=0
    sunflowers=0
    daisy=0
    for name in os.listdir(img_dir):
     i=img_dir+name
     img = Image.open(i)
     imag = img.resize([64, 64])
     image = np.array(imag)
     if evaluate_one_image(image) == 0:
         print(i+'可能是玫瑰花')
         roses+=1
     elif evaluate_one_image(image) == 1:
         print (i+'可能是郁金香')
         tulips+=1
     elif evaluate_one_image(image) == 2:
         print(i+'可能是蒲公英')
         dandelion+=1
     elif evaluate_one_image(image) == 4:
         print(i+'有可能是雏菊')
         daisy+=1
     elif evaluate_one_image(image) == 3:
         print(i+'可能是向日葵')
         sunflowers+=1
    print('玫瑰花:%d 郁金香:%d' %(roses,tulips))
    print('蒲公英:%d 雏菊:%d 向日葵:%d' %(dandelion,daisy,sunflowers))

还有一个测试单张图片的testone.py

#!/bin/python

import wx
# from test import evaluate_one_image
from PIL import Image
import numpy as np
import tensorflow as tf
import model
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


def evaluate_one_image(image_array):
    with tf.Graph().as_default():
        BATCH_SIZE = 1
        N_CLASSES = 5
        image = tf.cast(image_array, tf.float32)
        image = tf.image.per_image_standardization(image)
        image = tf.reshape(image, [1, 64, 64, 3])
        logit = model.inference(image, BATCH_SIZE, N_CLASSES,1,0)
        logit = tf.nn.softmax(logit)

        x = tf.placeholder(tf.float32, shape=[64, 64, 3])
        logs_train_dir = 'D:/biye/save/train'
        saver = tf.train.Saver()
        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(logs_train_dir)
            if ckpt and ckpt.model_checkpoint_path:
                global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print('No checkpoint file found')
            prediction = sess.run(logit, feed_dict={x: image_array})
            max_index = np.argmax(prediction)
            return max_index

class HelloFrame(wx.Frame):

    def __init__(self,*args,**kw):
        super(HelloFrame,self).__init__(*args,**kw)
        pnl = wx.Panel(self)
        self.pnl = pnl
        st = wx.StaticText(pnl, pos=(200, 0))
        font = st.GetFont()
        font.PointSize += 10
        font = font.Bold()
        st.SetFont(font)
        # 选择图像文件按钮
        btn = wx.Button(pnl, -1, "select")
        btn.Bind(wx.EVT_BUTTON, self.OnSelect)
        self.makeMenuBar()
        self.CreateStatusBar()


    def makeMenuBar(self):
        menuBar = wx.MenuBar()
        self.SetMenuBar(menuBar)


    def OnSelect(self, event):
        wildcard = "image source(*.jpg)|*.jpg|" \
                   "Compile Python(*.pyc)|*.pyc|" \
                   "All file(*.*)|*.*"
        dialog = wx.FileDialog(None, "Choose a file", os.getcwd(),
                               "", wildcard, wx.ID_OPEN)
        if dialog.ShowModal() == wx.ID_OK:
            print(dialog.GetPath())
            img = Image.open(dialog.GetPath())
            imag = img.resize([64, 64])
            image = np.array(imag)
            print(evaluate_one_image(image))
            if evaluate_one_image(image) == 0:
                result = ('这可能是玫瑰花')
            elif evaluate_one_image(image) == 1:
                result = ('这可能是郁金香')
            elif evaluate_one_image(image) == 2:
                result = ('这可能是蒲公英')
            elif evaluate_one_image(image) == 4:
                result = ('这有可能是雏菊')
            elif evaluate_one_image(image) == 3:
                result = ('这可能是向日葵')
            # result =evaluate_one_image(image)
            result_text = wx.StaticText(self.pnl, label=result, pos=(320, 0))
            font = result_text.GetFont()
            font.PointSize += 8
            result_text.SetFont(font)
            self.initimage(name= dialog.GetPath())

    # 生成图片控件
    def initimage(self, name):
        imageShow = wx.Image(name, wx.BITMAP_TYPE_ANY)
        sb = wx.StaticBitmap(self.pnl, -1, imageShow.ConvertToBitmap(), pos=(0,30), size=(600,600))
        return sb


if __name__ == '__main__':
    app = wx.App()
    frm = HelloFrame(None, title='基于卷积神经网络的花卉图像识别', size=(600,600))
    frm.Show()
    app.MainLoop()

所有的代码就都列出来了。原本github上的源代码准确率只有70左右,经过数据增强后(也就是训练集由2500张变成7500张后),准确率提高了很多,我的是到了80了。
源码的dropout的代码注释了,也就是没用到,我是用到了,还调了最合适的也就是0.45.这里要注意训练的时候是0.45.但是验证集和测试集上都是1.用了这个准确率到了85.左右了。
其实还可以再高的,如果你还想提升,可以考虑调参,加几层网络,毕竟我的网络结构太简单了。数据增强有很多方式,你也可以试试别的。

运行代码的顺序是先运行数据增强的代码----裁剪大小的代码(两次)------训练的代码----重命名代码(也可以不要)----测试代码

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐