1.首先进行数据处理

import numpy as np  #用于数据处理
from matplotlib import pyplot as plt  #用于显示图像和画图
from sklearn import svm #导入支持向量机
from sklearn.model_selection  import train_test_split #用于数据集划分
from sklearn.metrics import accuracy_score  #用于计算正确率
import cv2  #用于读取图片
import os  #文件读取
import pickle  #用于模型的保存
from PIL import Image

SHAPE = (30, 30) #设置输入图片的大小
1.1文件的结构如下

图片不用太多,一类几张即可
替换成自己的图片以及目录即可
在这里插入图片描述

    def getImageData(self,directory):
        s = 1
        feature_list = list()
        label_list = list()
        num_classes = 0
        for root, dirs, files in os.walk(directory):
            for d in dirs:
                num_classes += 1
                images = os.listdir(root + d)
                for image in images:
                    s += 1
                    label_list.append(d)
                    feature_list.append(Svm_derection.extractFeaturesFromImage(root + d + "/" + image))

        return np.asarray(feature_list), np.asarray(label_list)
1.2接下来图片的预处理函数(上方有调用到)
    def extractFeaturesFromImage(self,image_file):
        img = cv2.imread(image_file)#读取图片
        img = cv2.resize(img, self.SHAPE, interpolation=cv2.INTER_CUBIC)
        #对图片进行risize操作统一大小
        img = img.flatten()#对图像进行降维操作,方便算法计算
        img = img / np.mean(img)#归一化,突出特征
        return img

2.svm模型训练

    def train(self,dir):
    	#数据获取,这里Svm_derection是自定义类的名称
        feature_array, label_array = Svm_derection.getImageData(self.directory)
        #数据的分割
        X_train, X_test, y_train, y_test = train_test_split(feature_array, label_array, test_size=0.2, random_state=42)

        print("shape of raw image data: {0}".format(feature_array.shape))
        print("shape of raw image data: {0}".format(X_train.shape))
        print("shape of raw image data: {0}".format(X_test.shape))
		#模型的选择
        clf = svm.SVC(gamma=0.001, C=100., probability=True)
        #模型的训练
        clf.fit(X_train, y_train);
        #模型测试
        Ypred = clf.predict(X_test);

        print("pre",Ypred)
        print("test",y_test)
		#模型保存
        pickle.dump(clf, open("svm.pkl", "wb"))

3.模型读取使用

    def test(self,path,img_file):
        pkl_file = open(path, 'rb')
        clf=pickle.load(pkl_file)
        Ypred = clf.predict(np.reshape(self.extractFeaturesFromImage(img_file),(1,2700)))
        return Ypred

4.运行代码

path='svm.pkl'#模型保存位置以及名字
img='derection/'#数据集位置
img_file='derection/f1/1.jpg'#测试图片位置
train(img)
t=test(path,img_file)
print(t)
img = Image.open(os.path.join('derection/f1/1.jpg'))
plt.figure("Image") # 图像窗口名称
plt.imshow(img)
plt.axis('off') # 关掉坐标轴为 off
plt.title(t) # 图像题目
plt.show()

5.结果展示

在这里插入图片描述

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐