决策树分类模型对葡萄酒数据的训练及预测
一、决策树模型训练
先从训练集说起。做决策树模型,第一步当然是要有一份靠谱的数据。这里直接用 scikit-learn 自带的葡萄酒数据集,这算是分类任务里的经典案例了,拿来练手再合适不过。
先把必要的 Python 包导入进来:
import pandas as pd
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
import graphviz
接着加载数据集,看看里面到底有什么。
wine = load_wine()
加载之后可以确认一下数据概况:一共178条样本,13个特征维度,目标值有3种分类。

如果想把数据看得更清楚,可以用 Pandas 拼一下:
pd.concat([pd.DataFrame(wine.data, columns=wine.feature_names), pd.DataFrame(wine.target)], axis=1)

接下来,把数据集分成训练集和测试集。这里让测试集占30%,训练集占70%,算是比较常见的比例。
Xtrain, Xtest, Ytrain, Ytest = train_test_split(wine.data, wine.target, test_size=0.3)
然后就可以建立并训练决策树模型了。这里选择基于信息熵的 ID3 算法。
# 选择信息熵模式即ID3算法建立决策分类树模型
clf = DecisionTreeClassifier(criterion="entropy")
# 用训练数据建立决策树
clf = clf.fit(Xtrain, Ytrain)
# 用以上训练的决策树,给测试数据返回打分
score = clf.score(Xtest, Ytest)

模型训练完了,光看分数还不够直观。用 graphviz 把决策树的可视化图导出来看看,这棵树的决策逻辑就一目了然了。
feature_name = ['酒精','苹果酸','灰','灰的碱性','镁','总酚','类黄酮','非黄烷类酚类','花青素','颜色强度','色调','od280/od315稀释葡萄酒','脯氨酸']
dot_data = tree.export_graphviz(clf,
out_file = None,
feature_names= feature_name,
class_names=["琴酒","雪莉","贝尔摩德"],
filled=True,
rounded=True)
graph = graphviz.Source(dot_data)
graph.view()

二、决策树模型微调
模型跑起来了,但问题在于——你是不是直接用了默认参数?对于决策树来说,默认参数几乎不可能是最优的,尤其是它天生就爱过拟合。所以,调参才是把模型从“能用”变成“好用”的关键一步。
1、决策树分类树重要参数介绍
手头的分类树参数不少,核心都在下面这个类的定义里。具体可以看官方文档,信息更详细。
class sklearn.tree.DecisionTreeClassifier(*, criterion='gini', splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, class_weight=None, ccp_alpha=0.0, monotonic_cst=None)

2、随机相关参数调整
splitter、random_state 和 max_features 这三个参数控制着树在构建时的随机性。不同数据集的“黄金组合”可能完全不同,下面这张图是 DeepSeek 给出的推荐组合,但实际调优还是要根据自己数据的训练结果来微调,没有万能药方。

3、剪枝相关参数调整
必须警惕的是,决策树是一个天生就会过拟合的模型。它在训练集上往往表现得近乎完美,但一到测试集上就可能原形毕露。所以,剪枝是做决策树绕不开的核心环节。
常用的剪枝参数包括 max_depth、min_samples_leaf、min_samples_split、max_features 和 min_impurity_decrease 等。这些参数的最优值怎么找?最直观的方法就是画学习曲线——以某个超参数的取值作为横坐标,模型的准确率作为纵坐标,找到曲线的峰值。
L = []
L1 = []
# 调节决策树最大深度
for i in range(2, 11):
dct = DecisionTreeClassifier(criterion='entropy'
, random_state=10
, splitter='random'
, max_depth=i
, min_samples_leaf=10
, min_samples_split=10)
dct.fit(Xtrain, Ytrain)
L.append([i, dct.score(Xtest, Ytest)])
# 调节一个节点在分枝后的每个子节点都必须包含的训练样本个数
for j in range(5,15):
dct1 = DecisionTreeClassifier(criterion='entropy'
, random_state=10
, splitter='random'
, max_depth=3
, min_samples_leaf=j
, min_samples_split=10)
dct1.fit(Xtrain, Ytrain)
L1.append([j, dct1.score(Xtest, Ytest)])
a = pd.DataFrame(L, columns = ['max_depth', 'zhunquelv'])
b = pd.DataFrame(L1, columns = ['min_samples_leaf', 'zhunquelv'])
A = [a, b]
plt.figure(figsize=(15, 5), dpi=70)
for k,v in enumerate(A):
plt.subplot(1,2,k+1)
plt.plot(v.iloc[:,0], v.zhunquelv, color='orange')
plt.xticks(v.iloc[:,0])
plt.xlabel(v.columns[0])
plt.ylabel('zhunquelv')
plt.title(f'{v.columns[0]}学习曲线');
plt.sa vefig('learning_curve.png', bbox_inches='tight')
从输出的学习曲线可以看到,max_depth 取 3 的时候准确率最高,min_samples_leaf 取 10 的时候结果最好。

再看看这两个参数组合下的决策树结构:

其余剪枝参数的调节套路是一样的。建议先从 max_depth 开始,确定最优值之后再依次调节其他参数,直到模型达到一个比较理想的状态。这种每次只调一个参数的策略,本质上是一种局部最优的贪心思路,但实践中效果通常都不错。
4、网格搜索选取最佳参数
如果你对模型的精度有更高要求,并且不太在意训练时长,那么网格搜索就是更彻底的办法。它会把预定义的所有参数组合都跑一遍,结合交叉验证来选最优参数。听起来很完美,代价就是计算量可能会大到让你怀疑人生。所以,它在参数空间不大、计算资源充裕的情况下用最合适。
从实践来看,局部最优往往已经能满足大多数场景了。网格搜索更像是“锦上添花”的那一步。
三、训练后的决策树模型文件保存
模型调好之后,总不能每次都重新训练。把它保存下来是最基本的操作。
# 通过pickle来保存训练后的模型文件
import pickle
with open('decision_tree_model.pkl', 'wb') as file:
pickle.dump(clf, file)
四、加载训练的决策树模型文件,以及对新数据的预测
保存好用,以后需要预测新数据的时候,直接加载就完了。
with open('decision_tree_model.pkl', 'rb') as file:
clf_loaded = pickle.load(file)
new_data = [[13.17, 2.59, 2.37, 20, 120, 1.65, 0.68, 0.53, 1.46, 9.3, 0.6, 1.62, 840]]
prediction = clf_loaded.predict(new_data)
print(f'预测类别:{prediction[0]}')

五、小结
以上就是用决策树模型对葡萄酒数据进行分类的完整流程,从训练、调参到保存和预测都过了一遍。但说到底,这只是个入门示例,具体到真实的业务数据上,参数调优仍然需要反复测试。
决策树的应用场景远不止分类,简单总结一下:
- :比如判断一封邮件是不是垃圾邮件,这是最常见的使用场景。
分类问题
- :通过特征预测连续值,比如用房屋属性预测房价。
回归问题
- :决策树天然就能告诉我们哪些特征更重要,这对理解数据非常有帮助。
特征选择
- :比如检测信用卡交易是否存在欺诈。
异常检测
- :在营销场景下,可以根据用户特征预测其购买概率,辅助制定策略。
决策分析