本文是周志华老师的《机器学习》一书中第4章 决策树 的课后题第4.4题的实现。原题是:
试编程实现基于基尼指数进行划分选择的决策树算法,为表4.2中的数据生成预剪枝、后剪枝决策树,并与未剪枝决策树进行比较。
本文主要是不进行剪枝的CART决策树的实现,预剪枝与后剪枝的CART决策树实现分别可见Python编程实现预剪枝的CART决策树和Python编程实现后剪枝的CART决策树。如果发现文章中的任何问题,欢迎通过QQ进行交流。
与ID3算法选择信息增益作为选择最优属性的标准不同,CART决策树选择使划分后基尼指数(Gini index)最小的属性作为最优划分属性。假设当前样本集合中第
类样本所占的比例为
,则
的纯度可以用基尼值来度量:
反映了从数据集
中随机抽取两个样本,这两个样本不属于同一类的概率,因此
越小,则数据集
的纯度越高。
假定离散的属性有
个可能的取值
,若使用
来对样本集
来进行划分,则会产生
个分支结点,其中第
个分支结点包含了
中所有在属性
上取值为
的样本,记为
,则属性
的基尼指数定义为
如果数据集中有取值范围是连续数值的属性,我们仍然需要使用二分法来寻找最佳的分隔点。
西瓜数据集2.0的可用版本如下所示
def watermelon2():
train_data = [
['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'],
['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '是'],
['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'],
['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '是'],
['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '是'],
['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '否'],
['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '否'],
['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '否'],
['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '否'],
['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '否']
]
test_data = [
['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '是'],
['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'],
['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '是'],
['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '否'],
['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '否'],
['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '否'],
['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '否'],
]
labels = ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感']
return train_data, test_data, labels
我在实现CART决策树的时候,使用了和之前Python编程实现基于信息熵进行划分选择的决策树算法中相同的决策树结点结构TreeNode:
class TreeNode:
"""
决策树结点类
"""
current_index = 0
def __init__(self, parent=None, attr_name=None, children=None, judge=None, split=None, data_index=None,
attr_value=None, rest_attribute=None):
"""
决策树结点类初始化方法
:param parent: 父节点
"""
self.parent = parent # 父节点,根节点的父节点为 None
self.attribute_name = attr_name # 本节点上进行划分的属性名
self.attribute_value = attr_value # 本节点上划分属性的值,是与父节点的划分属性名相对应的
self.children = children # 孩子结点列表
self.judge = judge # 如果是叶子结点,需要给出判断
self.split = split # 如果是使用连续属性进行划分,需要给出分割点
self.data_index = data_index # 对应训练数据集的训练索引号
self.index = TreeNode.current_index # 当前结点的索引号,方便输出时查看
self.rest_attribute = rest_attribute # 尚未使用的属性列表
TreeNode.current_index += 1
def to_string(self):
"""用一个字符串来描述当前结点信息"""
this_string = 'current index : ' + str(self.index) + ";\n"
if not (self.parent is None):
parent_node = self.parent
this_string = this_string + 'parent index : ' + str(parent_node.index) + ";\n"
this_string = this_string + str(parent_node.attribute_name) + " : " + str(self.attribute_value) + ";\n"
this_string = this_string + "data : " + str(self.data_index) + ";\n"
if not(self.children is None):
this_string = this_string + 'select attribute is : ' + str(self.attribute_name) + ";\n"
child_list = []
for child in self.children:
child_list.append(child.index)
this_string = this_string + 'children : ' + str(child_list)
if not (self.judge is None):
this_string = this_string + 'label : ' + self.judge
return this_string
以下是不进行剪枝的CART决策树的主要实现代码cart.py:
# CART决策树,使用基尼指数(Gini index)来选择划分属性
# 分别实现预剪枝、后剪枝和不进行剪枝的实现
import math
from Ch04DecisionTree import TreeNode
from Ch04DecisionTree import Dataset
def is_number(s):
"""判断一个字符串是否为数字,如果是数字,我们认为这个属性的值是连续的"""
try:
float(s)
return True
except ValueError:
pass
return False
def gini(labels=[]):
"""
计算数据集的基尼值
:param labels: 数据集的类别标签
:return:
"""
data_count = {}
for label in labels:
if data_count.__contains__(label):
data_count[label] += 1
else:
data_count[label] = 1
n = len(labels)
if n == 0:
return 0
gini_value = 1
for key, value in data_count.items():
gini_value = gini_value - (value/n)*(value/n)
return gini_value
def gini_index_basic(n, attr_labels={}):
gini_value = 0
for attribute, labels in attr_labels.items():
gini_value = gini_value + len(labels) / n * gini(labels)
return gini_value
def gini_index(attributes=[], labels=[], is_value=False):
"""
计算一个属性的基尼指数
:param attributes: 当前数据在该属性上的属性值列表
:param labels:
:param is_value:
:return:
"""
n = len(labels)
attr_labels = {}
gini_value = 0 # 最终要返回的结果
split = None #
if is_value: # 属性值是连续的数值
sorted_attributes = attributes.copy()
sorted_attributes.sort()
split_points = []
for i in range(0, n-1):
split_points.append((sorted_attributes[i+1]+sorted_attributes[i])/2)
split_evaluation = []
for current_split in split_points:
low_labels = []
up_labels = []
for i in range(0, n):
if attributes[i] <= current_split:
low_labels.append(labels[i])
else:
up_labels.append(labels[i])
attr_labels = {'small': low_labels, 'large': up_labels}
split_evaluation.append(gini_index_basic(n, attr_labels=attr_labels))
gini_value = min(split_evaluation)
split = split_points[split_evaluation.index(gini_value)]
else: # 属性值是离散的词
for i in range(0, n):
if attr_labels.__contains__(attributes[i]):
temp_list = attr_labels[attributes[i]]
temp_list.append(labels[i])
else:
temp_list = []
temp_list.append(labels[i])
attr_labels[attributes[i]] = temp_list
gini_value = gini_index_basic(n, attr_labels=attr_labels)
return gini_value, split
def finish_node(current_node=TreeNode.TreeNode(), data=[], label=[]):
"""
完成一个结点上的计算
:param current_node: 当前计算的结点
:param data: 数据集
:param label: 数据集的 label
:return:
"""
n = len(label)
# 判断当前结点中的数据是否属于同一类
one_class = True
this_data_index = current_node.data_index
for i in this_data_index:
for j in this_data_index:
if label[i] != label[j]:
one_class = False
break
if not one_class:
break
if one_class:
current_node.judge = label[this_data_index[0]]
return
rest_title = current_node.rest_attribute # 候选属性
if len(rest_title) == 0: # 如果候选属性为空,则是个叶子结点。需要选择最多的那个类作为该结点的类
label_count = {}
temp_data = current_node.data_index
for index in temp_data:
if label_count.__contains__(label[index]):
label_count[label[index]] += 1
else:
label_count[label[index]] = 1
final_label = max(label_count)
current_node.judge = final_label
return
title_gini = {} # 记录每个属性的基尼指数
title_split_value = {} # 记录每个属性的分隔值,如果是连续属性则为分隔值,如果是离散属性则为None
for title in rest_title:
attr_values = []
current_label = []
for index in current_node.data_index:
this_data = data[index]
attr_values.append(this_data[title])
current_label.append(label[index])
temp_data = data[0]
this_gain, this_split_value = gini_index(attr_values, current_label, is_number(temp_data[title])) # 如果属性值为数字,则认为是连续的
title_gini[title] = this_gain
title_split_value[title] = this_split_value
best_attr = min(title_gini, key=title_gini.get) # 基尼指数最小的属性名
current_node.attribute_name = best_attr
current_node.split = title_split_value[best_attr]
rest_title.remove(best_attr)
a_data = data[0]
if is_number(a_data[best_attr]): # 如果是该属性的值为连续数值
split_value = title_split_value[best_attr]
small_data = []
large_data = []
for index in current_node.data_index:
this_data = data[index]
if this_data[best_attr] <= split_value:
small_data.append(index)
else:
large_data.append(index)
small_str = '<=' + str(split_value)
large_str = '>' + str(split_value)
small_child = TreeNode.TreeNode(parent=current_node, data_index=small_data, attr_value=small_str,
rest_attribute=rest_title.copy())
large_child = TreeNode.TreeNode(parent=current_node, data_index=large_data, attr_value=large_str,
rest_attribute=rest_title.copy())
current_node.children = [small_child, large_child]
else: # 如果该属性的值是离散值
best_titlevalue_dict = {} # key是属性值的取值,value是个list记录所包含的样本序号
for index in current_node.data_index:
this_data = data[index]
if best_titlevalue_dict.__contains__(this_data[best_attr]):
temp_list = best_titlevalue_dict[this_data[best_attr]]
temp_list.append(index)
else:
temp_list = [index]
best_titlevalue_dict[this_data[best_attr]] = temp_list
children_list = []
for key, index_list in best_titlevalue_dict.items():
a_child = TreeNode.TreeNode(parent=current_node, data_index=index_list, attr_value=key,
rest_attribute=rest_title.copy())
children_list.append(a_child)
current_node.children = children_list
# print(current_node.to_string())
for child in current_node.children: # 递归
finish_node(child, data, label)
def cart_tree(Data, title, label):
"""
生成一颗 CART 决策树
:param Data: 数据集,每个样本是一个 dict(属性名:属性值),整个 Data 是个大的 list
:param title: 每个属性的名字,如 色泽、含糖率等
:param label: 存储的是每个样本的类别
:return:
"""
n = len(Data)
rest_title = title.copy()
root_data = []
for i in range(0, n):
root_data.append(i)
root_node = TreeNode.TreeNode(data_index=root_data, rest_attribute=title.copy())
finish_node(root_node, Data, label)
return root_node
def print_tree(root=TreeNode.TreeNode()):
"""
打印输出一颗树
:param root: 根节点
:return:
"""
node_list = [root]
while(len(node_list)>0):
current_node = node_list[0]
print('--------------------------------------------')
print(current_node.to_string())
print('--------------------------------------------')
children_list = current_node.children
if not (children_list is None):
for child in children_list:
node_list.append(child)
node_list.remove(current_node)
def classify_data(decision_tree=TreeNode.TreeNode(), x={}):
"""
使用决策树判断一个数据样本的类别标签
:param decision_tree: 决策树的根节点
:param x: 要进行判断的样本
:return:
"""
current_node = decision_tree
while current_node.judge is None:
if current_node.split is None: # 离散属性
can_judge = False # 如果训练数据集不够大,测试数据集中可能会有在训练数据集中没有出现过的属性值
for child in current_node.children:
if child.attribute_value == x[current_node.attribute_name]:
current_node = child
can_judge = True
break
if not can_judge:
return None
else:
child_list = current_node.children
if x[current_node.attribute_name] <= current_node.split:
current_node = child_list[0]
else:
current_node = child_list[1]
return current_node.judge
def run_test():
train_watermelon, test_watermelon, title = Dataset.watermelon2()
# 先处理数据
train_data = []
test_data = []
train_label = []
test_label = []
for melon in train_watermelon:
a_dict = {}
dim = len(melon) - 1
for i in range(0, dim):
a_dict[title[i]] = melon[i]
train_data.append(a_dict)
train_label.append(melon[dim])
for melon in test_watermelon:
a_dict = {}
dim = len(melon) - 1
for i in range(0, dim):
a_dict[title[i]] = melon[i]
test_data.append(a_dict)
test_label.append(melon[dim])
decision_tree = cart_tree(train_data, title, train_label)
print('训练的决策树是:')
print_tree(decision_tree)
print('\n')
test_judge = []
for melon in test_data:
test_judge.append(classify_data(decision_tree, melon))
print('决策树在测试数据集上的分类结果是:', test_judge)
print('测试数据集的正确类别信息应该是: ', test_label)
accuracy = 0
for i in range(0, len(test_label)):
if test_label[i] == test_judge[i]:
accuracy += 1
accuracy /= len(test_label)
print('决策树在测试数据集上的分类正确率为:'+str(accuracy*100)+"%")
if __name__ == '__main__':
run_test()
在西瓜数据集2.0上的运行结果如下所示:
训练的决策树是:
--------------------------------------------
current index : 3;
data : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
select attribute is : 色泽;
children : [4, 5, 6]
--------------------------------------------
--------------------------------------------
current index : 4;
parent index : 3;
色泽 : 青绿;
data : [0, 3, 5, 9];
select attribute is : 敲声;
children : [7, 8, 9]
--------------------------------------------
--------------------------------------------
current index : 5;
parent index : 3;
色泽 : 乌黑;
data : [1, 2, 4, 7];
select attribute is : 根蒂;
children : [10, 11]
--------------------------------------------
--------------------------------------------
current index : 6;
parent index : 3;
色泽 : 浅白;
data : [6, 8];
label : 否
--------------------------------------------
--------------------------------------------
current index : 7;
parent index : 4;
敲声 : 浊响;
data : [0, 3];
label : 是
--------------------------------------------
--------------------------------------------
current index : 8;
parent index : 4;
敲声 : 清脆;
data : [5];
label : 否
--------------------------------------------
--------------------------------------------
current index : 9;
parent index : 4;
敲声 : 沉闷;
data : [9];
label : 否
--------------------------------------------
--------------------------------------------
current index : 10;
parent index : 5;
根蒂 : 蜷缩;
data : [1, 2];
label : 是
--------------------------------------------
--------------------------------------------
current index : 11;
parent index : 5;
根蒂 : 稍蜷;
data : [4, 7];
select attribute is : 纹理;
children : [12, 13]
--------------------------------------------
--------------------------------------------
current index : 12;
parent index : 11;
纹理 : 稍糊;
data : [4];
label : 是
--------------------------------------------
--------------------------------------------
current index : 13;
parent index : 11;
纹理 : 清晰;
data : [7];
label : 否
--------------------------------------------
决策树在测试数据集上的分类结果是: ['否', '否', '否', '是', '否', '否', '是']
测试数据集的正确类别信息应该是: ['是', '是', '是', '否', '否', '否', '否']
决策树在测试数据集上的分类正确率为:28.57142857142857%
通过与预剪枝的决策树和后剪枝的决策树比较可以看出,剪枝之后决策树在测试数据集上的分类正确率有了明显的提升,但是由于数据的原因,这里并没有体现出后剪枝相比于预剪枝的正确率优势。
版权声明:本文为john_bian原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。