Python 医疗心脏病预测分类实战:逻辑回归 / 决策树 / 随机森林 / XGBoost/SVM 全模型对比
一、项目简介
本项目基于心脏病医疗数据集,构建多分类机器学习模型实现心肌梗死二分类预测(positive 患病 /negative 健康)。完整覆盖数据清洗、探索性数据分析 EDA、特征标准化、多模型训练、混淆矩阵 + ROC-AUC 评估五大流程,对比逻辑回归、决策树、随机森林、XGBoost、SVM 五种经典分类算法,最终得出树集成模型在医疗诊断场景效果最优的结论。
数据集信息:共 1319 条患者样本,9 个特征:年龄、性别、心率、收缩压、舒张压、血糖、肌酸激酶同工酶、肌钙蛋白、诊断结果,无缺失值,存在极端异常值与逻辑错误血压数据。
二、完整环境依赖导入
python
运行
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import scipy.stats as stats
from scipy.stats import chi2_contingency
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from xgboost import XGBClassifier
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
import warnings
# 配置全局设置
warnings.filterwarnings('ignore')
mpl.rcParams['font.family'] = 'SimHei'
plt.rcParams['axes.unicode_minus'] = False
三、数据读取与基础预处理
3.1 加载数据、修改中文列名
python
运行
# 读取医疗数据集
df = pd.read_csv("Medicaldataset.csv")
# 自定义中文列名
col_name = ['年龄','性别','心率','收缩压','舒张压','血糖','肌酸激酶同工酶','肌钙蛋白','结果']
df.columns = col_name
# 查看前5行
print(df.head())
# 查看数据类型、缺失值
print(df.info())
# 统计描述
print(df.describe())
输出关键信息:
- 1319 条数据,无缺失值;
- 5 个 int 整型特征、3 个 float 浮点特征、1 个字符串标签列;
- 心率最大值 1111,存在极端异常;部分样本舒张压>收缩压,存在逻辑错误。
3.2 异常值清洗
- 过滤极端生理异常:心率>1000、收缩压<50、舒张压>140
- 修正血压逻辑错误:舒张压大于收缩压的数据,两列数值互换
python
运行
# 1. 删除极端异常样本
df = df.loc[(df['心率']<1000) & (df['收缩压']>50) & (df['舒张压']<140)]
print("清洗后样本数量:", df.shape[0])
# 2. 找出舒张压大于收缩压的错误数据
wrong_data = df[df['舒张压']>df['收缩压']]
print("血压逻辑错误样本数量:", len(wrong_data))
# 交换两列数值修复
df.loc[wrong_data.index,['收缩压','舒张压']] = df.loc[wrong_data.index,['舒张压','收缩压']].values.copy()
# 检查重复值
print("重复样本数量:", df.duplicated().sum())
# 绘制箱线图查看整体分布
df.plot(kind='box', figsize=(12,6))
plt.title("各特征箱线图")
plt.show()
清洗后样本量降至 1314 条,血压逻辑错误共 6 条全部修复,无重复数据。
四、EDA 探索性数据分析
4.1 基础人口分布:年龄 + 性别
python
运行
plt.figure(figsize=(12,8))
# 子图1:年龄直方图+核密度曲线
plt.subplot(211)
plt.hist(df['年龄'],bins=20,alpha=0.7)
plt.title('患者年龄分布直方图')
df['年龄'].plot(kind='kde',secondary_y=True)
# 子图2:性别占比饼图
plt.subplot(212)
gender_data = df['性别'].value_counts()
plt.pie(gender_data,autopct='%.2f%%',labels=['男','女'],shadow=True,explode=(0.05,0))
plt.title('患者性别比例')
plt.tight_layout()
plt.show()
分析结论:
- 年龄近似正态分布,50-75 岁中老年患者占比最高;
- 男性样本 65.98%,女性仅 34.02%,数据集男性占主导。
4.2 各特征与心脏病诊断关系(阳性 / 阴性分组)
python
运行
plt.figure(figsize=(20,16))
# 1. 年龄与患病关系箱线图
plt.subplot(3,3,1)
pos_age = df[df['结果']=='positive']['年龄']
neg_age = df[df['结果']=='negative']['年龄']
plt.boxplot([pos_age,neg_age],labels=['阳性(患病)','阴性(健康)'])
plt.title('年龄与心脏病诊断关系')
# 2. 心率与患病关系
plt.subplot(332)
pos_heart = df[df['结果']=='positive']['心率']
neg_heart = df[df['结果']=='negative']['心率']
plt.boxplot([pos_heart,neg_heart],labels=['阳性','阴性'])
plt.title('心率与心脏病诊断关系')
# 3. 确诊患者性别饼图
plt.subplot(333)
pos_data = df[df['结果']=='positive']
gender_pos = pos_data['性别'].value_counts()
plt.pie(gender_pos,labels=['男','女'],autopct='%.2f%%')
plt.title('确诊心脏病患者性别比例')
# 4. 收缩压分组箱线图
ax4 = plt.subplot(334)
df.boxplot(column='收缩压',by='结果',ax=ax4)
ax4.set_title('收缩压分布(患病/健康分组)')
# 5. 舒张压分组箱线图
ax5 = plt.subplot(335)
df.boxplot(column='舒张压',by='结果',ax=ax5)
ax5.set_title('舒张压分布(患病/健康分组)')
# 6. 血糖分组箱线图
ax6 = plt.subplot(336)
df.boxplot(column='血糖',by='结果',ax=ax6)
ax6.set_title('血糖分布(患病/健康分组)')
plt.tight_layout()
plt.show()
# 多特征按性别+患病双分组对比
df.boxplot(column=['血糖','年龄','收缩压','舒张压'],by=['性别','结果'],figsize=(16,8))
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()
通过分组箱线图可直观观察:患病群体肌钙蛋白、肌酸激酶同工酶指标显著高于健康人群,是区分心脏病的核心生物标志物。
五、特征工程与数据集划分
5.1 标签二值化
将文本标签positive/negative转为 0、1 数值,方便模型训练:
python
运行
# 1=患病阳性,0=健康阴性
df['Result_Binary'] = (df['结果'] == 'positive').astype(int)
5.2 筛选核心特征 + 标准化
根据医学先验知识,选取区分度最高 4 个特征:性别、年龄、肌酸激酶同工酶、肌钙蛋白;仅对连续数值特征标准化,性别为二分类特征无需标准化。
python
运行
# 定义输入特征与目标标签
features = ['性别','年龄', '肌酸激酶同工酶', '肌钙蛋白']
X = df[features].copy()
y = df['Result_Binary']
# 需要标准化的连续变量
continuous_vars = ['年龄', '肌酸激酶同工酶', '肌钙蛋白']
X[continuous_vars] = X[continuous_vars].astype(float)
# 标准化处理
scaler = StandardScaler()
X.loc[:, continuous_vars] = scaler.fit_transform(X[continuous_vars])
# 划分训练集80%、测试集20%,分层抽样保证正负样本比例一致
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=15, stratify=y
)
六、通用模型评估函数(统一输出分类报告、混淆矩阵、ROC 曲线)
封装评估函数,一次性完成训练、预测、可视化,减少重复代码:
python
运行
def evaluate_model(model, model_name, X_train, X_test, y_train, y_test):
# 训练模型
model.fit(X_train, y_train)
# 测试集预测
y_pred = model.predict(X_test)
# 打印分类报告(精确率、召回率、F1、准确率)
print(f"===== {model_name} 模型评估报告 =====")
print(classification_report(y_test, y_pred))
# 绘制混淆矩阵热力图
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8,6))
plt.imshow(cm, cmap='hot', interpolation='nearest')
plt.colorbar()
# 填充数字
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
plt.text(j, i, f"{cm[i,j]}", ha="center", va="center", color="white", fontsize=14)
plt.title(f"{model_name} 混淆矩阵")
plt.xlabel("预测标签")
plt.ylabel("真实标签")
plt.show()
# 获取预测概率/置信度,绘制ROC曲线
if hasattr(model, "predict_proba"):
y_prob = model.predict_proba(X_test)[:, 1]
else:
y_prob = model.decision_function(X_test)
fpr, tpr, _ = roc_curve(y_test, y_prob)
roc_auc = auc(fpr, tpr)
plt.figure(figsize=(8,6))
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC曲线(AUC={roc_auc:.2f})')
plt.plot([0,1], [0,1], color='navy', lw=2, linestyle='--')
plt.xlabel('假阳性率 FPR')
plt.ylabel('真阳性率 TPR')
plt.title(f"{model_name} ROC曲线")
plt.legend(loc="lower right")
plt.show()
return model, roc_auc
七、五大分类模型训练与对比
7.1 逻辑回归
python
运行
lr_model = LogisticRegression(random_state=15, class_weight='balanced')
lr_model, lr_auc = evaluate_model(lr_model, "逻辑回归", X_train, X_test, y_train, y_test)
print(f"逻辑回归AUC值:{lr_auc}\n")
结果:AUC=0.92,线性模型拟合能力有限,区分效果一般。
7.2 决策树
python
运行
dt_model = DecisionTreeClassifier(random_state=15)
dt_model, dt_auc = evaluate_model(dt_model, "决策树", X_train, X_test, y_train, y_test)
print(f"决策树AUC值:{dt_auc}\n")
结果:AUC=0.98,单棵树拟合效果大幅提升,但易过拟合。
7.3 随机森林(集成树)
python
运行
rf_model = RandomForestClassifier(random_state=15, class_weight='balanced')
rf_model, rf_auc = evaluate_model(rf_model, "随机森林", X_train, X_test, y_train, y_test)
print(f"随机森林AUC值:{rf_auc}\n")
结果:AUC=0.99,集成树抗过拟合,泛化能力优秀。
7.4 XGBoost 梯度提升树
python
运行
# 正负样本权重平衡
scale_weight = sum(y_train==0) / sum(y_train==1)
xgb_model = XGBClassifier(random_state=15, scale_pos_weight=scale_weight)
xgb_model, xgb_auc = evaluate_model(xgb_model, "XGBoost", X_train, X_test, y_train, y_test)
print(f"XGBoost AUC值:{xgb_auc}\n")
结果:AUC=0.99,本项目最优模型,医疗指标非线性关系捕捉最强。
7.5 SVM 支持向量机
python
运行
svm_model = SVC(random_state=15, class_weight='balanced', probability=True)
svm_model, svm_auc = evaluate_model(svm_model, "支持向量机SVM", X_train, X_test, y_train, y_test)
print(f"SVM AUC值:{svm_auc}\n")
结果:AUC=0.92,与逻辑回归持平,高维医疗特征下优势不明显。
八、模型综合对比总结
表格
| 模型 | AUC 得分 | 优缺点 |
|---|---|---|
| 逻辑回归 | 0.92 | 可解释性强,线性假设限制精度 |
| SVM | 0.92 | 小样本友好,大数据训练慢 |
| 决策树 | 0.98 | 单树易过拟合,泛化弱 |
| 随机森林 | 0.99 | 并行训练,稳定,不易过拟合 |
| XGBoost | 0.99 | 梯度提升,预测精度最高,适合医疗诊断 |
项目结论
- 医疗心脏病数据中,肌钙蛋白、肌酸激酶同工酶是判断患病的核心特征;
- 线性模型(逻辑回归、SVM)仅能捕捉简单线性关系,效果较差;
- 树集成模型(随机森林、XGBoost)能挖掘生理指标非线性关联,AUC 达到 0.99,适合落地心脏病辅助诊断系统;
- 数据预处理至关重要:极端生理异常值、血压逻辑错误会严重干扰模型预测精度。
九、项目拓展方向
- 特征工程:增加特征交叉、分箱,筛选特征重要性;
- 超参数调优:GridSearchCV/RandomSearchCV 优化树模型参数;
- 多模型融合:Stacking 堆叠集成进一步提升精度;
- 模型部署:导出 XGBoost 模型,搭建简易 Web 诊断接口;
- 数据增强:扩充样本,解决男女样本不均衡问题。
更多推荐
所有评论(0)