使用Go 机器学习库来进行数据分析 2 (决策树)

smallnest · · 4904 次点击 · · 开始浏览    
这是一个创建于 的文章,其中的信息可能已经有所发展或是发生改变。

目录 [−]

  1. 决策树和随机森林
  2. 代码
  3. 评估结果

这篇文章, 继续使用golearn库分析鸢尾花的数据集。 这一次,我们会使用决策树和随机森林来分析。

决策树和随机森林

决策树是机器学习中最接近人类思考问题的过程的一种算法,通过若干个节点,对特征进行提问并分类(可以是二分类也可以使多分类),直至最后生成叶节点(也就是只剩下一种属性)。

每个决策树都表述了一种树型结构,它由它的分支来对该类型的对象依靠属性进行分类。每个决策树可以依靠对源数据库的分割进行数据测试。这个过程可以递归式的对树进行修剪。 当不能再进行分割或一个单独的类可以被应用于某一分支时,递归过程就完成了。另外,随机森林分类器将许多决策树结合起来以提升分类的正确率。

golearn支持两种决策树算法。ID3和RandomTree。

  • ID3: 以信息增益为准则选择信息增益最大的属性。

    ID3 is a decision tree induction algorithm which splits on the Attribute which gives the greatest Information Gain (entropy gradient). It performs well on categorical data. Numeric datasets will need to be discretised before using ID3

  • RandomTree: 与ID3类似,但是选择的属性的时候随机选择。

    Random Trees are structurally identical to those generated by ID3, but the split Attribute is chosen randomly. Golearn's implementation allows you to choose up to k nodes for consideration at each split.

可以参考 ChongmingLiu的介绍: 决策树(ID3 & C4.5 & CART)

维基百科中对随机森林的介绍:

在机器学习中,随机森林是一个包含多个决策树的分类器,并且其输出的类别是由个别树输出的类别的众数而定。 Leo Breiman和Adele Cutler发展出推论出随机森林的算法。而"Random Forests"是他们的商标。这个术语是1995年由贝尔实验室的Tin Kam Ho所提出的随机决策森林(random decision forests)而来的。这个方法则是结合Breimans的"Bootstrap aggregating"想法和Ho的"random subspace method" 以建造决策树的集合。

在机器学习中,随机森林由许多的决策树组成,因为这些决策树的形成采用了随机的方法,因此也叫做随机决策树。随机森林中的树之间是没有关联的。当测试数据进入随机森林时,其实就是让每一颗决策树进行分类,最后取所有决策树中分类结果最多的那类为最终的结果。因此随机森林是一个包含多个决策树的分类器,并且其输出的类别是由个别树输出的类别的众数而定。

代码

下面是使用决策树和随机森林预测鸢尾花分类的代码,来自golearn:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
// Demonstrates decision tree classification
package main
import (
"fmt"
"github.com/sjwhitworth/golearn/base"
"github.com/sjwhitworth/golearn/ensemble"
"github.com/sjwhitworth/golearn/evaluation"
"github.com/sjwhitworth/golearn/filters"
"github.com/sjwhitworth/golearn/trees"
"math/rand"
)
func main() {
var tree base.Classifier
rand.Seed(44111342)
// Load in the iris dataset
iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
// Discretise the iris dataset with Chi-Merge
filt := filters.NewChiMergeFilter(iris, 0.999)
for _, a := range base.NonClassFloatAttributes(iris) {
filt.AddAttribute(a)
}
filt.Train()
irisf := base.NewLazilyFilteredInstances(iris, filt)
// Create a 60-40 training-test split
trainData, testData := base.InstancesTrainTestSplit(irisf, 0.60)
//
// First up, use ID3
//
tree = trees.NewID3DecisionTree(0.6)
// (Parameter controls train-prune split.)
// Train the ID3 tree
err = tree.Fit(trainData)
if err != nil {
panic(err)
}
// Generate predictions
predictions, err := tree.Predict(testData)
if err != nil {
panic(err)
}
// Evaluate
fmt.Println("ID3 Performance (information gain)")
cf, err := evaluation.GetConfusionMatrix(testData, predictions)
if err != nil {
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
}
fmt.Println(evaluation.GetSummary(cf))
tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.InformationGainRatioRuleGenerator))
// (Parameter controls train-prune split.)
// Train the ID3 tree
err = tree.Fit(trainData)
if err != nil {
panic(err)
}
// Generate predictions
predictions, err = tree.Predict(testData)
if err != nil {
panic(err)
}
// Evaluate
fmt.Println("ID3 Performance (information gain ratio)")
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
if err != nil {
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
}
fmt.Println(evaluation.GetSummary(cf))
tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.GiniCoefficientRuleGenerator))
// (Parameter controls train-prune split.)
// Train the ID3 tree
err = tree.Fit(trainData)
if err != nil {
panic(err)
}
// Generate predictions
predictions, err = tree.Predict(testData)
if err != nil {
panic(err)
}
// Evaluate
fmt.Println("ID3 Performance (gini index generator)")
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
if err != nil {
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
}
fmt.Println(evaluation.GetSummary(cf))
//
// Next up, Random Trees
//
// Consider two randomly-chosen attributes
tree = trees.NewRandomTree(2)
err = tree.Fit(trainData)
if err != nil {
panic(err)
}
predictions, err = tree.Predict(testData)
if err != nil {
panic(err)
}
fmt.Println("RandomTree Performance")
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
if err != nil {
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
}
fmt.Println(evaluation.GetSummary(cf))
//
// Finally, Random Forests
//
tree = ensemble.NewRandomForest(70, 3)
err = tree.Fit(trainData)
if err != nil {
panic(err)
}
predictions, err = tree.Predict(testData)
if err != nil {
panic(err)
}
fmt.Println("RandomForest Performance")
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
if err != nil {
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
}
fmt.Println(evaluation.GetSummary(cf))
}

首选使用ChiMerge方法进行数据离散,ChiMerge 是监督的、自底向上的(即基于合并的)数据离散化方法。它依赖于卡方分析:具有最小卡方值的相邻区间合并在一起,直到满足确定的停止准则。
基本思想:对于精确的离散化,相对类频率在一个区间内应当完全一致。因此,如果两个相邻的区间具有非常类似的类分布,则这两个区间可以合并;否则,它们应当保持分开。而低卡方值表明它们具有相似的类分布。可以参考"Principles of Data Mining"第二版中的第105页--第115页。

接着调用base.NewLazilyFilteredInstances应用filter得到FixedDataGrid。

之后将数据集分成训练数据和测试数据两部分。

接下来就是训练数据、预测与评估了。

分别使用ID3ID3 with InformationGainRatioRuleGeneratorID3 with GiniCoefficientRuleGeneratorRandomTreeRandomForest算法进行处理。

评估结果

以下是各种算法的评估结果,可以和 kNN进行比较,看起来比不过kNN的预测。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
ID3 Performance (information gain)
Reference Class True Positives False Positives True Negatives Precision Recall F1 Score
--------------- -------------- --------------- -------------- --------- ------ --------
Iris-virginica 32 5 46 0.8649 0.9697 0.9143
Iris-versicolor 4 1 61 0.8000 0.1818 0.2963
Iris-setosa 29 13 42 0.6905 1.0000 0.8169
Overall accuracy: 0.7738
ID3 Performance (information gain ratio)
Reference Class True Positives False Positives True Negatives Precision Recall F1 Score
--------------- -------------- --------------- -------------- --------- ------ --------
Iris-virginica 29 3 48 0.9062 0.8788 0.8923
Iris-versicolor 5 3 59 0.6250 0.2273 0.3333
Iris-setosa 29 15 40 0.6591 1.0000 0.7945
Overall accuracy: 0.7500
ID3 Performance (gini index generator)
Reference Class True Positives False Positives True Negatives Precision Recall F1 Score
--------------- -------------- --------------- -------------- --------- ------ --------
Iris-virginica 26 5 46 0.8387 0.7879 0.8125
Iris-versicolor 17 36 26 0.3208 0.7727 0.4533
Iris-setosa 0 0 55 NaN 0.0000 NaN
Overall accuracy: 0.5119
RandomTree Performance
Reference Class True Positives False Positives True Negatives Precision Recall F1 Score
--------------- -------------- --------------- -------------- --------- ------ --------
Iris-virginica 30 3 48 0.9091 0.9091 0.9091
Iris-versicolor 9 3 59 0.7500 0.4091 0.5294
Iris-setosa 29 10 45 0.7436 1.0000 0.8529
Overall accuracy: 0.8095
RandomForest Performance
Reference Class True Positives False Positives True Negatives Precision Recall F1 Score
--------------- -------------- --------------- -------------- --------- ------ --------
Iris-virginica 31 8 43 0.7949 0.9394 0.8611
Iris-versicolor 0 0 62 NaN 0.0000 NaN
Iris-setosa 29 16 39 0.6444 1.0000 0.7838
Overall accuracy: 0.7143

有疑问加站长微信联系(非本文作者)

本文来自:鸟窝

感谢作者:smallnest

查看原文:使用Go 机器学习库来进行数据分析 2 (决策树)

入群交流(和以上内容无关):加入Go大咖交流群,或添加微信:liuxiaoyan-s 备注:入群;或加QQ群:692541889

4904 次点击  
加入收藏 微博
添加一条新回复 (您需要 登录 后才能回复 没有账号 ?)
  • 请尽量让自己的回复能够对别人有帮助
  • 支持 Markdown 格式, **粗体**、~~删除线~~、`单行代码`
  • 支持 @ 本站用户;支持表情(输入 : 提示),见 Emoji cheat sheet
  • 图片支持拖拽、截图粘贴等方式上传