每张 MNIST 手写数字图片包含 28×28 个灰度像素。使用下面的例子,可以加载训练集或测试集,并选择查看某张图片对应的像素数据或图像。

代码如下(数据读取依赖作者自己编写的 MyMnistData 模块,数据文件默认从 d:\data\MNIST 读取,使用前请按需修改):

​
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Created on Thu Jan  4 16:24:07 2018
@author: wangd
"""
from tkinter import Tk, BOTH, RIGHT, RAISED,Text,X,SW,LEFT,INSERT,BOTTOM
from tkinter.ttk import Frame, Button, Style, Label, Entry
from tkinter import Listbox, StringVar, END
from tkinter import messagebox as msgbox

from MyMnistData import MyMnistData
import matplotlib.pyplot as plt

#Tk class is used to create root windows.
#frame is a container for other widgets.

class MnistView(Frame):  # inherit from ttk.Frame
    """Tk front-end for browsing the MNIST data set.

    The window contains:
    - a control row: a 'Test'/'Train' selector, the Load/View/Next/By Label
      buttons and two numeric entries;
    - a summary text box;
    - a data area with two text views of the current image (hex pixel values).

    The data itself is read through the project-local ``MyMnistData`` helper.
    ``imgs`` is indexed by image number and each entry is reshaped to
    ``(imgRows, imgCols)``. ``labels`` is compared element-wise with
    ``==`` to build a boolean mask. It is presumably a numpy array; confirm
    against MyMnistData.
    """

    # Directory handed to MyMnistData.setPath(); raw string so that the
    # backslashes are not treated as (invalid) escape sequences.
    dataPath = r"d:\data\MNIST"

    def __init__(self):
        super().__init__()  # call the constructor of the inherited ttk.Frame

        self.summary = ''  # summary text shown in the info box
        self.imageNum = 0  # the total number of images
        self.imgRows = 0  # the number of rows in each image
        self.imgCols = 0  # the number of columns in each image
        self.imgPoint = 0  # the number of points in each image
        self.imgs = None  # images, indexed by image number
        self.labels = None  # labels, one per image

        # StringVar's first positional parameter is ``master``, so the
        # initial value must be passed by keyword (StringVar('') raises
        # AttributeError on Python >= 3.10).
        self.imageNumInput = StringVar(value='')
        self.labelNumInput = StringVar(value='')
        self.dataSetType = StringVar(value='')
        self.currentImgNum = 0
        self.currentLb = 0

        self.mainFrame = None
        self.ctrFrame = None
        self.infoFrame = None
        self.infoTxt = None
        self.dataFrame = None
        self.imgArea = None
        self.imgArea2 = None

        self.initUI()

    # ------------------------------------------------------------------ UI

    def initUI(self):
        """Apply the ttk theme, set the window title and build all widgets."""
        # Tkinter supports theming of widgets through the ttk module.
        self.style = Style()
        self.style.theme_use("default")

        # The master attribute gives access to the root window.
        self.master.title("Zhouwa MNIST Data View")

        self.pack(fill=BOTH, expand=1)

        self.addFrame()
        self.addAbutton()
        self.addControlInfo()

    def addFrame(self):
        """Build the static layout: header, control row, summary box, data area."""
        # Header frame with the copyright/author line.
        self.mainFrame = Frame(self, relief=RAISED, borderwidth=1)
        self.mainFrame.pack(fill=X, expand=False)
        lbl0 = Label(self.mainFrame,
                     text="This is a tool used to explore MNIST "
                          "data, by Daoyi Wang, 2018-10-18",
                     width=600)
        lbl0.pack(side=LEFT, padx=5, pady=5)

        # Control row; buttons and entries are added by addAbutton() and
        # addControlInfo().
        ctrFrame = Frame(self, height=10)
        ctrFrame.pack(fill=X)
        lbl1 = Label(ctrFrame, text="Control Area", width=10)
        lbl1.pack(side=LEFT, padx=5, pady=5)
        self.ctrFrame = ctrFrame

        # Summary row: a label plus a text box filled by updateSummary().
        infoFrame = Frame(self, height=10)
        infoFrame.pack(fill=BOTH, expand=True)
        lbl2 = Label(infoFrame, text="Summary", width=10)
        lbl2.pack(side=LEFT, padx=5, pady=5)
        self.infoFrame = infoFrame

        infoTxt = Text(infoFrame, height=5)
        infoTxt.pack(side=LEFT, fill=X, pady=5, padx=5, expand=True)
        self.infoTxt = infoTxt

        # Data area: the ViewImg button and two text views of the image.
        self.dataFrame = Frame(self, relief=RAISED, borderwidth=1, height=80)
        self.dataFrame.pack(fill=BOTH, expand=True)

        lbl3 = Label(self.dataFrame, text="Data Area", width=10)
        lbl3.pack(side=LEFT, padx=5, pady=5)

        viewImgBT = Button(self.dataFrame, text="ViewImg",
                           command=self.onViewSingleImg)
        viewImgBT.pack(side=BOTTOM, padx=5, pady=5, anchor=SW)

        # Left view: every pixel as two hex digits.
        area = Text(self.dataFrame, width=50, height=70)
        area.pack(side=LEFT, padx=5, pady=5)
        self.imgArea = area
        self.imgArea.insert(END, "The image will be displayed in this area!")

        # Right view: same rendering with zero pixels blanked out.
        area2 = Text(self.dataFrame, width=50, height=70)
        area2.pack(side=RIGHT, padx=5, pady=5)
        self.imgArea2 = area2
        self.imgArea2.insert(END, "The simple image will be displayed in this area!")

    def addListbox(self):
        """Add the 'Test'/'Train' data-set selector to the control row."""
        dataSetType = ['Test', 'Train']

        # exportselection=False keeps the chosen entry selected when the user
        # selects text in one of the Text widgets.
        lb = Listbox(self.ctrFrame, width=8, height=2, exportselection=False)
        for name in dataSetType:
            lb.insert(END, name)
        lb.bind("<<ListboxSelect>>", self.onSelect)
        lb.pack(side=LEFT, padx=5, pady=5)

    def onSelect(self, val):
        """Store the data-set name chosen in the listbox.

        ``val`` is the <<ListboxSelect>> virtual event; its ``widget`` is the
        listbox. The event also fires when the selection is cleared, in which
        case ``curselection()`` is empty and the previous choice is kept.
        """
        sender = val.widget
        idx = sender.curselection()  # tuple of selected indices
        if not idx:
            return
        self.dataSetType.set(sender.get(idx[0]))

    def addAbutton(self):
        """Add the selector and the Load/View/Next/By Label buttons."""
        self.addListbox()
        loadButton = Button(self.ctrFrame, text="Load", command=self.onLoad)
        loadButton.pack(side=LEFT, padx=10, pady=5)

        viewImgBT = Button(self.ctrFrame, text="View", command=self.onViewImgBT)
        viewImgBT.pack(side=LEFT, padx=5, pady=5)

        nextImgBT = Button(self.ctrFrame, text="Next", command=self.onNextImgBT)
        nextImgBT.pack(side=LEFT, padx=5, pady=5)

        byLabelImgBT = Button(self.ctrFrame, text="By Label",
                              command=self.onByLabelImgBT)
        byLabelImgBT.pack(side=LEFT, padx=5, pady=5)

    def addControlInfo(self):
        """Add the image-number and label-number entry fields."""
        lbl1 = Label(self.ctrFrame, text="ImgNum", width=8)
        lbl1.pack(side=LEFT, padx=10, pady=2)
        entry1 = Entry(self.ctrFrame, textvariable=self.imageNumInput, width=8)
        entry1.pack(side=LEFT)

        lbl2 = Label(self.ctrFrame, text="LabelNum", width=8)
        lbl2.pack(side=LEFT, padx=2, pady=2)
        entry2 = Entry(self.ctrFrame, textvariable=self.labelNumInput, width=4)
        entry2.pack(side=LEFT)

    # ---------------------------------------------------------------- data

    def setimgPoints(self):
        """Recompute the number of pixels per image from rows x columns."""
        self.imgPoint = self.imgRows * self.imgCols

    def updateSummary(self):
        """Replace the summary box content with ``summary`` plus the pixel count."""
        self.infoTxt.delete("1.0", END)
        self.infoTxt.insert(INSERT, self.summary)
        tempStr = "\nthe number of points in each image is: " + str(self.imgPoint)
        self.infoTxt.insert(INSERT, tempStr)

    def initDataSet(self):
        """Forget any loaded data set and reset the current position."""
        self.imageNum = 0
        self.imgRows = 0
        self.imgCols = 0
        self.imgs = None
        self.labels = None
        # A previous position may be out of range for the next data set.
        self.currentImgNum = 0
        self.currentLb = 0
        self.setimgPoints()
        self.summary = ''
        self.updateSummary()

    def onLoad(self):
        """Load the selected data set and refresh the summary.

        'Train' loads the training set; anything else (including no selection
        yet) loads the test set.
        """
        myMnistData = MyMnistData()
        myMnistData.setPath(self.dataPath)
        print(self.dataSetType.get())
        self.initDataSet()

        if self.dataSetType.get() == 'Train':
            myMnistData.loadTrainData()
            tempstr = myMnistData.displaySumInfoOfTrainData()
            self.imageNum = myMnistData.trainImgNum
            self.imgRows = myMnistData.trainImgRows
            self.imgCols = myMnistData.trainImgCols
            self.imgs = myMnistData.trainImages
            self.labels = myMnistData.trainLabels
        else:
            myMnistData.loadTestData()
            tempstr = myMnistData.displaySumInfoOfTestData()
            self.imageNum = myMnistData.testImgNum
            self.imgRows = myMnistData.testImgRows
            self.imgCols = myMnistData.testImgCols
            self.imgs = myMnistData.testImages
            self.labels = myMnistData.testLabels

        self.setimgPoints()
        self.summary = tempstr
        self.updateSummary()

    def _isDataLoaded(self):
        """Return True once both images and labels are available."""
        return self.imgs is not None and self.labels is not None

    def _isValidImgIndex(self, num):
        """Return True if ``num`` addresses an image of the loaded data set."""
        return self._isDataLoaded() and 0 <= num < len(self.imgs)

    def _warnBadImgIndex(self):
        """Tell the user why the requested image cannot be shown."""
        if not self._isDataLoaded():
            msgbox.showinfo('', "please load a data set first!")
        else:
            msgbox.showinfo('', "image number must be between 0 and %d!"
                            % (len(self.imgs) - 1))

    # ------------------------------------------------------------- buttons

    def onViewImgBT(self):
        """Show the image whose number is typed in the ImgNum entry."""
        try:
            num = int(self.imageNumInput.get().strip())
        except ValueError:
            msgbox.showinfo('', "image number is not a integer!")
            return
        if not self._isValidImgIndex(num):
            self._warnBadImgIndex()
            return
        self.setCurrentImgNum(num)
        self.displayImage()

    def onNextImgBT(self):
        """Advance to the next image and show it."""
        num = self.getCurrentImgNum() + 1
        if not self._isValidImgIndex(num):
            self._warnBadImgIndex()
            return
        self.setCurrentImgNum(num)
        self.imageNumInput.set(str(num))
        self.displayImage()

    def onByLabelImgBT(self):
        """Show a grid of images carrying the label typed in LabelNum."""
        try:
            num = int(self.labelNumInput.get().strip())
        except ValueError:
            msgbox.showinfo('', "label number is not a integer!")
            return
        if not self._isDataLoaded():
            msgbox.showinfo('', "please load a data set first!")
            return
        self.setCurrentLb(num)
        self.displayImgByLabel()

    # ------------------------------------------------------------- display

    def _currentImgData(self):
        """Return the current image reshaped to ``(imgRows, imgCols)``."""
        return self.imgs[self.getCurrentImgNum()].reshape(self.imgRows, self.imgCols)

    def _renderImgText(self, area, blankZero):
        """Write the current image into Text widget ``area`` as hex digits.

        Each image row becomes one text line, indented by two spaces. With
        ``blankZero`` set, zero pixels ('00') are rendered as two spaces so
        that only the stroke remains visible.
        """
        area.delete("1.0", END)
        imgData = self._currentImgData()
        area.tag_config('link', foreground='red', font=('Courier', 7, 'bold'))

        for row in imgData:
            cells = [self.getColorStr(pixel) for pixel in row]
            if blankZero:
                cells = ['  ' if cell == '00' else cell for cell in cells]
            area.insert(INSERT, '\n  ')
            # One tagged insert per row instead of one per pixel.
            area.insert(INSERT, ''.join(cells), 'link')

    def displayImage(self):
        """Render the current image in both text views."""
        self._renderImgText(self.imgArea, blankZero=False)
        self.displaySimpleImage()

    def displaySimpleImage(self):
        """Render the current image in the right view with zeros blanked."""
        self._renderImgText(self.imgArea2, blankZero=True)

    def onViewSingleImg(self):
        """Plot the current image with matplotlib."""
        imgIndex = self.getCurrentImgNum()
        if not self._isValidImgIndex(imgIndex):
            self._warnBadImgIndex()
            return
        imgData = self._currentImgData()
        # Use the image's own label: currentLb may have been overwritten by
        # the "By Label" button since this image was selected.
        plt.figure("the image for lable %s" % (self.labels[imgIndex]))
        ax = plt.subplot(111)
        ax.imshow(imgData, cmap='Greys', interpolation='nearest')
        plt.show()

    def displayImgByLabel(self):
        """Plot up to 25 images whose label equals the current label (5x5 grid)."""
        labelNum = self.getCurrentLb()
        # Boolean mask computed once; assumes labels supports element-wise ==.
        matches = self.imgs[self.labels == labelNum]
        if len(matches) == 0:
            msgbox.showinfo('', "no image has label %d!" % labelNum)
            return

        fig = plt.figure("All the image with label %d" % (labelNum))
        ax = fig.subplots(nrows=5, ncols=5, sharex=True, sharey=True).flatten()

        # zip stops early when fewer than 25 images match; the rest stay empty.
        for axis, img in zip(ax, matches[:25]):
            axis.imshow(img.reshape(self.imgRows, self.imgCols),
                        cmap='Greys', interpolation='nearest')
        # Axes are shared, so clearing ticks on one clears them on all.
        ax[0].set_xticks([])
        ax[0].set_yticks([])
        plt.tight_layout()
        plt.show()

    # ------------------------------------------------------------- helpers

    def getColorStr(self, color):
        """Return pixel value ``color`` as lowercase hex, zero-padded to 2 digits."""
        return format(int(color), '02x')

    def setCurrentImgNum(self, num):
        """Make image ``num`` current and adopt its label as the current label."""
        self.currentImgNum = num
        self.setCurrentLb(self.labels[num])

    def getCurrentImgNum(self):
        return self.currentImgNum

    def setCurrentLb(self, label):
        self.currentLb = label

    def getCurrentLb(self):
        return self.currentLb

def centerWindow(root, w=1140, h=850):
    """Size ``root`` to ``w`` x ``h`` pixels and centre it on the screen.

    Offsets are clamped at 0. On a screen smaller than the window, the
    title bar therefore stays reachable instead of being pushed off the
    top/left edge. The geometry string is also set as the window title.
    """
    sw = root.winfo_screenwidth()
    sh = root.winfo_screenheight()
    x = max(0, (sw - w) // 2)
    y = max(0, (sh - h) // 2)

    s = '%dx%d+%d+%d' % (w, h, x, y)
    root.title(s)
    root.geometry(s)

def main():
    """Open the root window, attach the MNIST viewer and run the Tk event loop."""
    window = Tk()
    centerWindow(window)
    # MnistView is created without an explicit master, so it attaches to the
    # default root window created just above.
    MnistView()
    # Receive events and dispatch them to the widgets until the window closes.
    window.mainloop()


if __name__ == '__main__':
    main()
    

​

运行效果如下(显示第 7 个测试样本):

点击 ViewImg之后的效果如下:

 

 

 

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐