基于Keras的人工神经网络搭建入门实例

人工智能(AI)已经成为一个热门的方向,而神经网络的搭建也是最重要的一环,本篇介绍一个基于Keras的人工神经网络搭建入门级实例,(后端使用Tensorflow)
如果对Python还不是很了解的话,可以花俩小时看看这个python教程
https://pan.baidu.com/s/139i6QEMFdBuG7EmK8SwhFg
提取码:176t
上代码:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense
import matplotlib.pyplot as plt
np.random.seed(1377)

~这一段主要是导入一些库,如果没安装的话百度安装方法
~numpy是一个运行速度非常快的数学库,主要用于数组计算
~keras就是主角啦,一个简便强大的神经网络搭建框架,后端有Tensorflow和Theano以及CNTK
~matplotlib主要是用于绘图的库
~random.seed用于产生随机数

#creat some data
X1=np.linspace(-1,1,200)
np.random.shuffle(X1)
Y=0.5*X1+2+np.random.normal(0,0.05,(200,))
#plot
plt.scatter(X1,Y)
plt.show()

~这一段就是用于产生本次实验的数据集了,X1在-1到1之间取200个数并且打乱次序
~再建立一个Y与X1的关系
~random.normal是正态分布函数,均值为0,标准差为0.05,输出的shape为(200,),结果就是Y≈0.5*X1+2
~最后绘图如下
200个数据集

X1_train,Y_train=X1[:160],Y[:160]
X1_test,Y_test=X1[160:],Y[160:]

~这两句将产生的数据集前80%纳为训练集,后20%纳为测试集(一般都是8,2分)
~下面开始来真正的搭建神经网络:

model=Sequential()
model.add(Dense(output_dim=1,input_dim=1))
model.compile(loss="mean_squared_error",optimizer='sgd')

~该模型采用Sequential顺序结构,一层连接一层
~通过model.add来添加层,此处添加一个Dense(全连接层),output_dim(单次输出个数)=1,input_dim(单次输入个数)=1
~model.compile来定义其他参数,loss(损失函数)=mse(用均方误差测损失值),optimizer(优化器)=‘sgd’(随机梯度下降优化器)
~至此,一个最简单的人工神经网络就搭建完成了,接下来就是训练网络

print("Traing..........")
for i in range(301):
    loss=model.train_on_batch(X1_train,Y_train)
    if i%50==0:
        print('train loss:',loss)

~训练的时候我们把input,output都告诉网络(Supervised learing)
~通过model.train_on_batch训练网络,他会返回该次训练的损失值
~训练301次,每50次打印一次’train loss值

print('\nTesting........')
loss=model.evaluate(X1_test,Y_test,batch_size=160)
print('test loss:',loss)
W,b=model.layers[0].get_weights()
print('W:',W,'\nb:',b)

~这一段是测试部分,评估训练后的网络在测试集的表现情况(用evaluate评估)
~因为我们只建立了一层(input和output不计入其中),实际建立的就是一个Y=W*x+b的function,训练后W越接近0.5,b越接近2说明效果较好

Y_pred=model.predict(X1_test)
plt.scatter(X1_test,Y_test)
plt.plot(X1_test,Y_pred)
plt.show()

~这一段是通过X1的测试集预测输出的Y值(这才是我们常常关心的地方),看是否和Y_test很接近(上面的test loss也有相同作用)
~一起来看看训练效果吧

~W和b都分别接近于0.5和2
在这里插入图片描述
~直线纵坐标为网络预测的输出,离散的点的纵坐标表示真实的输出
~入门实例一就到这了,如有错误请留言
~源于莫烦

Logo

K8S/Kubernetes社区为您提供最前沿的新闻资讯和知识内容

更多推荐