编写网络结构,并完成模型的训练与保存(test1.py)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2021/7/16 14:33
# @Author : wutiande

import tensorflow as tf #tf_version: 2.5.0


class MyModel(tf.keras.Model):
    """Small CNN classifier built by subclassing ``tf.keras.Model``.

    The network is two identical stages of convolution -> batch
    normalization -> ReLU -> 2x2 max pooling, followed by a flatten and a
    dense layer whose output is passed through softmax.

    Args:
        num_classes: Number of units in the final dense layer, i.e. the
            length of the probability vector produced for each sample.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Both stages use the same convolution and pooling settings.
        conv_kwargs = dict(filters=32, kernel_size=3, strides=(1, 1), padding="same")
        pool_kwargs = dict(pool_size=(2, 2), strides=2)

        # Stage 1
        self.conv1 = tf.keras.layers.Conv2D(**conv_kwargs)
        self.b1 = tf.keras.layers.BatchNormalization()
        self.a1 = tf.keras.activations.relu
        self.pool1 = tf.keras.layers.MaxPooling2D(**pool_kwargs)

        # Stage 2
        self.conv2 = tf.keras.layers.Conv2D(**conv_kwargs)
        self.b2 = tf.keras.layers.BatchNormalization()
        self.a2 = tf.keras.activations.relu
        self.pool2 = tf.keras.layers.MaxPooling2D(**pool_kwargs)

        # Classification head
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(units=num_classes)

    def call(self, inputs):
        """Run the forward pass and return softmax probabilities.

        Args:
            inputs: Batch of inputs fed to the first convolution.

        Returns:
            The softmax of the dense layer's output.
        """
        stage1 = self.pool1(self.a1(self.b1(self.conv1(inputs))))
        stage2 = self.pool2(self.a2(self.b2(self.conv2(stage1))))

        logits = self.dense1(self.flatten(stage2))
        return tf.nn.softmax(logits, name="output")


# Load the MNIST train/test splits that ship with Keras.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Divide pixel values by 255 and append a trailing axis to each image batch.
x_train = tf.expand_dims(x_train / 255.0, axis=-1)
x_test = tf.expand_dims(x_test / 255.0, axis=-1)
print(x_train.shape, y_train.shape)

# Instantiate the model with one output unit per digit class.
model = MyModel(num_classes=10)

# Compile: Adam optimizer, sparse categorical cross-entropy, accuracy metric.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=['accuracy'],
)

# Train, evaluating on the test split after every epoch.
model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=50,
    validation_data=(x_test, y_test),
)

# Export the trained model to the "model1" directory.
tf.saved_model.save(model, "model1")

   运行后,`tf.saved_model.save` 会自动创建名为 model1 的文件夹,并将训练好的模型保存在其中。

模型导入与预测(test2.py)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2021/7/16 15:56
# @Author : wutiande

import tensorflow as tf
import cv2
import numpy as np

"""导入待预测图片并进行数据预处理"""
# Read the image to classify as a single-channel grayscale array (flag 0).
image_path = "five.jpg"
inputs = cv2.imread(image_path, 0)
if inputs is None:
    # cv2.imread does not raise on failure; it returns None when the file is
    # missing or cannot be decoded. Fail here with a clear message instead of
    # letting the inversion below raise an unrelated TypeError.
    raise FileNotFoundError(f"could not read image: {image_path!r}")

# Invert the pixel intensities.
inputs = ~inputs

# Preview the inverted image; block until a key is pressed, then close it.
cv2.imshow('w', inputs)
cv2.waitKey(0)
cv2.destroyAllWindows()

# Resize to 28x28, divide by 255, cast to float32, then add a leading and a
# trailing axis so the array matches the input expected by the saved network.
inputs = (cv2.resize(inputs, (28, 28)) / 255.0).astype(np.float32)
inputs = tf.expand_dims(inputs, axis=0)
inputs = tf.expand_dims(inputs, axis=-1)

# Load the model exported by the training script.
model = tf.saved_model.load("model1")

# (_,_),(inputs,label)= tf.keras.datasets.mnist.load_data()
# inputs = (inputs[0]/255.0).astype(np.float32)
# label = label[0]
# inputs = tf.expand_dims(inputs,axis=0)
# inputs = tf.expand_dims(inputs,axis=-1)
# print(inputs.shape,label)

# The commented-out lines above feed an official MNIST sample instead of the
# custom image; they are kept as a sanity check that the code is correct.

# Predict: the batch holds a single image, so take the first result directly.
outputs = model(inputs)
outputs = outputs[0]

# Print the index of the highest-scoring class.
print(np.argmax(outputs))

   这是用于测试的图片 five.jpg(即 test2.py 中 `cv2.imread` 读取的文件)

 这是最后输出的结果

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐