机器学习实验:使用sklearn的决策树算法对葡萄酒数据集进行分类
1、使用sklearn的决策树算法对葡萄酒数据集进行分类,要求:划分训练集和测试集(测试集占20%)对测试集的预测类别标签和真实标签进行对比输出分类的准确率调整参数比较不同算法(ID3,C4.5,CART)的分类效果。2、把ID3算法修改为CART,并实现以下例子的分类。
·
机器学习实验:使用sklearn的决策树算法对葡萄酒数据集进行分类
问题如下:
使用sklearn的决策树算法对葡萄酒数据集进行分类,要求:
①划分训练集和测试集(测试集占20%)
②对测试集的预测类别标签和真实标签进行对比
③输出分类的准确率
④调整参数比较不同算法(ID3,C4.5,CART)的分类效果。
代码实现:
-
导入依赖包
#导入相关库 import sklearn from sklearn.model_selection import train_test_split from sklearn import tree #导入tree模块 from sklearn.datasets import load_wine from math import log2 import pandas as pd import graphviz import treePlotter
-
导入数据集
#导入数据集 wine = load_wine() X = wine.data #X Y = wine.target #Y features_name = wine.feature_names print(features_name) pd.concat([pd.DataFrame(X),pd.DataFrame(Y)],axis=1) #打印数据
-
划分数据集,数据集划分为测试集占20%;
#划分数据集,数据集划分为测试集占20%; x_train, x_test, y_train, y_test = train_test_split( X, Y,test_size=0.2) # print(x_train.shape) #(142, 13) # print(x_test.shape) #(36, 13)
-
导入模型,进行训练
#采用C4.5算法进行计算 #获取模型 model = tree.DecisionTreeClassifier(criterion="entropy",splitter="best",max_depth=None,min_samples_split=2, min_samples_leaf=1,min_weight_fraction_leaf=0.0,max_features=None, random_state=None,max_leaf_nodes=None,class_weight=None); model.fit(x_train,y_train) score = model.score(x_test,y_test) y_predict = model.predict(x_test) print('准确率为:',score) #准确率为: 0.9444444444444444
-
对测试集的预测类别标签和真实标签进行对比
pd.concat([pd.DataFrame(x_test),pd.DataFrame(y_test),pd.DataFrame(y_predict)],axis=1) #打印数据,对测试集的预测类别标签和真实标签进行对比
最后两列为真实标签和预测类别标签
-
调整参数比较不同算法(ID3,C4.5,CART)的分类效果
#采用CART算法进行计算 #获取模型 model = tree.DecisionTreeClassifier(criterion="gini",splitter="best",max_depth=None,min_samples_split=2, min_samples_leaf=1,min_weight_fraction_leaf=0.0,max_features=None, random_state=None,max_leaf_nodes=None,class_weight=None); model.fit(x_train,y_train) score = model.score(x_test,y_test) y_predict = model.predict(x_test) print('准确率为:',score) #准确率为: 1.0
-
画出最后预测的树
feature_name = ['酒精','苹果酸','灰','灰的碱性','镁','总酚','类黄酮','非黄烷类酚类','花青素','颜色强度','色调','od280/od315稀释葡萄酒','脯氨酸'] dot_data = tree.export_graphviz(model ,out_file=None ,feature_names=feature_name ,class_names=['二锅头','苦荞','江小白'] ,filled=True ,rounded=True) graph = graphviz.Source(dot_data) graph #graph.render('tree')
更多推荐
已为社区贡献2条内容
所有评论(0)