const.go
package main
var (
	// ColNames lists, in order, the raw CSV columns that dataPrepare selects
	// before it derives the per-row "total" column.
	ColNames = []string{
		"feature",
		"document",
		"machine",
		"load_time",
		"search_time",
		"reduce_and_save",
	}

	// ResColNames lists, in order, the columns of the frame dataPrepare
	// produces.
	ResColNames = []string{"feature", "document", "machine", "total"}
)
fit_classification.go
package main
import (
"fmt"
"log"
"math"
"os"
"github.com/go-gota/gota/dataframe"
"github.com/go-gota/gota/series"
"gonum.org/v1/gonum/optimize"
"gonum.org/v1/plot"
"gonum.org/v1/plot/plotter"
"gonum.org/v1/plot/plotutil"
"gonum.org/v1/plot/vg"
)
// getTotal is a DataFrame.Rapply callback for rows laid out as ColNames.
// It returns a one-element Float series holding
// (load_time + search_time + reduce_and_save) / 60, i.e. the columns at
// indexes 3, 4 and 5 summed and divided by 60 (presumably seconds to
// minutes — confirm against how the CSV was recorded).
//
// Values are read through Element.Float, so the result is correct whether
// the row series passed in holds ints or floats. A comma-ok `.(int)` type
// assertion would silently turn every value into 0 for a float row.
func getTotal(s series.Series) series.Series {
	loadTime := s.Elem(3).Float()
	searchTime := s.Elem(4).Float()
	reduceAndSave := s.Elem(5).Float()
	return series.Floats((loadTime + searchTime + reduceAndSave) / 60)
}
// getDoc is a DataFrame.Rapply callback that rescales the document column
// (index 1) of a row to document*2/1000. It returns a one-element Float
// series.
//
// The value is read through Element.Float, so the result does not depend on
// whether Rapply built the row as Int or Float. A comma-ok `.(float64)`
// assertion would silently yield 0 for an int row.
func getDoc(s series.Series) series.Series {
	document := s.Elem(1).Float()
	return series.Floats(2 * document / 1000)
}
// dataPrepare reshapes *clsDF in place into the ResColNames layout
// (feature, document, machine, total).
//
// Steps:
//   - the ColNames columns are selected;
//   - a "total" column computed row by row by getTotal is appended;
//   - the document column is replaced by the rescaled value from getDoc,
//     keeping the name "document".
//
// Errors from the dataframe operations are not returned; gota records them
// in the frame's Err field, so callers should inspect clsDF.Err afterwards.
func dataPrepare(clsDF *dataframe.DataFrame) {
	df := clsDF.Select(ColNames)

	// Append the per-row total.
	total := df.Rapply(getTotal)
	total.SetNames("total")
	df = df.CBind(total).Select(ResColNames)

	// Replace document with its rescaled version under the same name.
	scaledDoc := df.Rapply(getDoc)
	scaledDoc.SetNames("new_doc")
	df = df.CBind(scaledDoc).
		Drop([]string{"document"}).
		Rename("document", "new_doc").
		Select(ResColNames)

	*clsDF = df
}
// dataOptimize fits the straight line y = fa*x + fb to the prepared data.
//
// clsDF must have the column order produced by dataPrepare
// (feature, document, machine, total). For each row:
//
//	x = document / machine
//	y = total
//
// The coefficients are found with Nelder-Mead, starting from (1, 1), by
// minimising the sum of squared residuals (a least-squares fit).
//
// Return values:
//   - actPoints: the observed (x, y) pairs, one per row, in row order.
//   - expPoints: the same x values, with y taken from the fitted line.
//   - fa, fb: the fitted slope and intercept.
//
// It panics in these cases:
//   - a row has a non-positive machine count or produces a non-finite point;
//   - fewer than two rows are available;
//   - the optimiser returns an error.
func dataOptimize(clsDF *dataframe.DataFrame) (actPoints, expPoints plotter.XYs, fa, fb float64) {
	finite := func(v float64) bool { return !math.IsNaN(v) && !math.IsInf(v, 0) }

	nrow := clsDF.Nrow()
	actPoints = make(plotter.XYs, 0, nrow)
	for i := 0; i < nrow; i++ {
		// Float() accepts both int and float cells, unlike a type assertion.
		document := clsDF.Elem(i, 1).Float()
		machine := clsDF.Elem(i, 2).Float()
		total := clsDF.Elem(i, 3).Float()
		x := document / machine
		// A zero machine count or missing value would poison the whole fit
		// with Inf/NaN, so reject it loudly.
		if machine <= 0 || !finite(x) || !finite(total) {
			panic(fmt.Sprintf("dataOptimize: invalid row %d (document=%v, machine=%v, total=%v)",
				i, document, machine, total))
		}
		actPoints = append(actPoints, plotter.XY{X: x, Y: total})
	}
	if len(actPoints) < 2 {
		panic("dataOptimize: need at least two data points to fit a line")
	}

	// Least-squares objective: sum of squared residuals of y = a*x + b.
	objective := func(p []float64) float64 {
		if len(p) != 2 {
			panic("illegal x")
		}
		a, b := p[0], p[1]
		var sum float64
		for _, pt := range actPoints {
			r := a*pt.X + b - pt.Y
			sum += r * r
		}
		return sum
	}
	result, err := optimize.Minimize(optimize.Problem{Func: objective},
		[]float64{1, 1}, &optimize.Settings{}, &optimize.NelderMead{})
	if err != nil {
		panic(err)
	}
	fa, fb = result.X[0], result.X[1]

	// The fitted points reuse the observed x values.
	expPoints = make(plotter.XYs, len(actPoints))
	for i, pt := range actPoints {
		expPoints[i] = plotter.XY{X: pt.X, Y: fa*pt.X + fb}
	}
	return actPoints, expPoints, fa, fb
}
// dataPlot draws expPoints and actPoints as two line-and-point series on one
// plot and saves it as a 5x5 inch image named classification-fit.png in the
// working directory. Both axis ranges are preset to 0..10. Any error from
// gonum/plot causes a panic.
func dataPlot(actPoints, expPoints plotter.XYs) {
	p, err := plot.New()
	if err != nil {
		panic(err)
	}
	p.X.Min, p.X.Max = 0, 10
	p.Y.Min, p.Y.Max = 0, 10

	err = plotutil.AddLinePoints(p, "expPoints", expPoints, "actPoints", actPoints)
	if err != nil {
		panic(err)
	}

	const side = 5 * vg.Inch
	if err = p.Save(side, side, "classification-fit.png"); err != nil {
		panic(err)
	}
}
// FitClassification runs the whole pipeline:
//  1. it reads classification_data.csv from the working directory;
//  2. it preprocesses the frame with dataPrepare;
//  3. it fits a straight line with dataOptimize and prints the fitted
//     coefficients;
//  4. it writes the plot to classification-fit.png via dataPlot.
//
// gota does not return errors from ReadCSV or from frame transformations.
// It stores them in DataFrame.Err instead, so both are checked here. A
// malformed CSV then stops the run with a clear message rather than fitting
// an empty or broken frame.
func FitClassification() {
	f, err := os.Open("classification_data.csv")
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	clsDF := dataframe.ReadCSV(f)
	if clsDF.Err != nil {
		log.Fatalf("reading classification_data.csv: %v", clsDF.Err)
	}

	// Preprocess.
	dataPrepare(&clsDF)
	if clsDF.Err != nil {
		log.Fatalf("preparing data: %v", clsDF.Err)
	}
	fmt.Println("数据预处理完成...")
	fmt.Println(clsDF)

	// Fit, then report the coefficients.
	actPoints, expPoints, fa, fb := dataOptimize(&clsDF)
	fmt.Println("Fa", fa, "Fb", fb)

	// Plot.
	dataPlot(actPoints, expPoints)
	fmt.Println("绘制完成,图形地址: classification-fit.png")
}
main.go
package main
// main does nothing; the classification fit is driven by
// TestFitClassification in main_test.go (go test -v -run=FitClass).
func main() {}
main_test.go
package main
import (
	"os"
	"testing"
)
// TestFitClassification runs the classification fit end to end.
//
// It needs classification_data.csv in the package directory, which is the
// working directory under go test, and it writes classification-fit.png
// there. The test is skipped when the CSV cannot be found. Otherwise
// FitClassification would call log.Fatal and abort the whole test binary
// instead of reporting a normal test result.
func TestFitClassification(t *testing.T) {
	if _, err := os.Stat("classification_data.csv"); err != nil {
		t.Skipf("classification_data.csv not available: %v", err)
	}
	FitClassification()
}
运行数据
feature,document,machine,load_time,search_time,reduce_and_save
100,5000,4,19,130,67
100,5000,4,12,130,61
100,5000,4,13,127,61
100,5000,4,13,124,63
100,5000,4,13,129,59
100,5000,4,13,125,60
100,5000,4,13,123,63
100,5000,4,13,129,61
100,5000,4,12,127,61
100,5000,4,12,125,62
100,5000,4,13,128,59
100,5000,4,13,128,61
100,5000,4,12,130,60
100,5000,4,12,125,61
100,5000,4,13,127,60
100,5000,4,13,126,63
100,5000,4,13,127,64
100,5000,3,18,160,67
100,5000,3,13,166,59
100,5000,3,12,167,61
100,5000,3,12,168,60
100,5000,3,12,170,61
100,5000,3,12,154,63
100,5000,3,13,168,60
100,5000,3,12,167,60
100,5000,3,12,148,64
100,5000,3,12,167,65
100,5000,3,12,164,60
100,5000,3,12,150,59
100,5000,2,20,217,65
100,5000,2,13,205,63
100,5000,2,14,204,60
100,5000,2,13,205,55
100,5000,2,14,210,59
100,5000,2,13,201,59
100,5000,2,13,211,59
100,5000,2,13,217,59
100,5000,2,14,207,60
100,5000,2,14,209,59
100,5000,2,14,214,61
100,5000,2,13,210,61
100,5000,1,24,376,60
100,5000,1,21,393,58
100,5000,1,20,386,58
100,5000,1,22,384,59
100,5000,1,21,387,59
100,4000,4,18,112,70
100,4000,4,13,118,62
100,4000,4,12,114,63
100,4000,4,14,112,65
100,4000,4,12,113,62
100,4000,4,14,109,61
100,4000,4,13,118,63
100,4000,4,12,112,61
100,4000,4,12,110,61
100,4000,4,11,111,63
100,4000,4,13,112,67
100,4000,4,12,110,60
100,4000,4,12,113,60
100,3000,4,19,100,66
100,3000,4,13,99,64
100,3000,4,12,100,65
100,3000,4,13,103,61
100,3000,4,14,104,63
100,3000,4,14,99,63
100,2000,4,18,90,67
100,2000,4,13,86,65
100,2000,4,13,85,63
100,2000,4,14,87,62
100,2000,4,13,85,61
100,2000,4,13,86,64
100,2000,4,13,85,61
100,2000,4,13,89,58
100,2000,4,13,85,61
100,2000,4,12,85,60
100,2000,4,13,85,66
100,2000,4,12,86,59
100,2000,4,12,86,61
100,2000,4,12,82,60
100,2000,4,13,87,62
100,2000,4,12,83,65
100,2000,4,12,85,60
100,2000,4,13,87,60
100,2000,4,12,86,59
100,1000,4,19,75,63
100,1000,4,14,71,61
100,1000,4,13,75,59
100,1000,4,14,72,61
100,1000,4,13,72,59
100,1000,4,13,71,59
100,1000,4,13,72,59
100,1000,4,14,70,62
100,1000,4,12,72,58
100,1000,4,13,71,59
100,1000,4,13,70,62
100,1000,4,12,72,59
100,1000,4,20,71,58
100,1000,4,13,69,60
100,1000,4,12,73,60
100,1000,4,13,69,59
100,1000,4,13,71,60
100,1000,4,13,73,62
100,1000,4,12,71,59
100,1000,4,12,70,56
100,1000,4,13,70,58
100,1000,4,12,69,57
运行方法
go test -v -run=FitClass
最终输出的拟合参数（Fa、Fb）与 scipy 的拟合结果基本一致。
程序输出
输出的拟合图像如下
有疑问加站长微信联系(非本文作者)