# 使用golang做最小二乘法的线性拟合

FredricZhu · 216 次点击 · 开始浏览

const.go

``````go
package main

var (
	// ColNames lists the raw CSV columns, in the order the row-wise helpers
	// index them (getTotal reads positions 3, 4 and 5).
	ColNames = []string{
		"feature",
		"document",
		"machine",
		"load_time",
		"search_time",
		"reduce_and_save",
	}

	// ResColNames lists the columns kept after preprocessing, in final order.
	ResColNames = []string{
		"feature",
		"document",
		"machine",
		"total",
	}
)
``````

fit_classification.go

``````go
package main

import (
"fmt"
"log"
"math"
"os"

"github.com/go-gota/gota/dataframe"
"github.com/go-gota/gota/series"
"gonum.org/v1/gonum/optimize"
"gonum.org/v1/plot"
"gonum.org/v1/plot/plotter"
"gonum.org/v1/plot/plotutil"
"gonum.org/v1/plot/vg"
)

// getTotal returns a one-element Float series holding the row's total run
// time: (load_time + search_time + reduce_and_save) / 60.
//
// The row must follow ColNames order, so indices 3, 4 and 5 are load_time,
// search_time and reduce_and_save. The division by 60 presumably converts
// seconds to minutes (NOTE(review): confirm the unit of the source data).
//
// Values are read through Element.Float so both Int and Float columns work;
// an int type assertion with an ignored ok flag would silently produce 0
// whenever the row type is Float.
func getTotal(s series.Series) series.Series {
	loadTime := s.Elem(3).Float()
	searchTime := s.Elem(4).Float()
	reduceAndSave := s.Elem(5).Float()

	total := (loadTime + searchTime + reduceAndSave) / 60
	return series.Floats(total)
}

// getDoc returns a one-element Float series holding the row's document
// count scaled by 2/1000.
//
// Index 1 is the "document" column in ResColNames order. Reading it through
// Element.Float accepts both Int and Float columns instead of silently
// yielding 0 when a float64 type assertion fails.
func getDoc(s series.Series) series.Series {
	document := s.Elem(1).Float()
	return series.Floats(2 * document / 1000)
}

// dataPrepare reshapes the raw CSV frame in place into ResColNames form.
//
// It appends a "total" column computed by getTotal and replaces "document"
// with the value computed by getDoc (document*2/1000). It leaves the columns
// ordered as ResColNames (feature, document, machine, total).
//
// gota reports failures through the DataFrame's Err field rather than a
// returned error, and SetNames returns its own error. Both are checked after
// every step. The function has no error result, so it panics with the
// failing step named.
func dataPrepare(clsDF *dataframe.DataFrame) {
	// must panics, naming the step, if gota recorded an error on df.
	must := func(step string, df dataframe.DataFrame) dataframe.DataFrame {
		if df.Err != nil {
			panic(fmt.Errorf("dataPrepare: %s: %w", step, df.Err))
		}
		return df
	}

	// Add total = (load_time + search_time + reduce_and_save) / 60.
	df := must("select raw columns", clsDF.Select(ColNames))
	totalSeries := must("compute total", df.Rapply(getTotal))
	if err := totalSeries.SetNames("total"); err != nil {
		panic(fmt.Errorf("dataPrepare: name total column: %w", err))
	}
	df = must("append total", df.CBind(totalSeries))

	// Replace document with document * 2 / 1000.
	df = must("select result columns", df.Select(ResColNames))
	newDocSeries := must("compute new_doc", df.Rapply(getDoc))
	if err := newDocSeries.SetNames("new_doc"); err != nil {
		panic(fmt.Errorf("dataPrepare: name new_doc column: %w", err))
	}
	df = must("append new_doc", df.CBind(newDocSeries))
	df = must("drop document", df.Drop([]string{"document"}))
	df = must("rename new_doc", df.Rename("document", "new_doc")) // Rename(newname, oldname)
	*clsDF = must("reorder columns", df.Select(ResColNames))
}

// dataOptimize 数据优化和拟合函数
func dataOptimize(clsDF *dataframe.DataFrame) (actPoints, expPoints plotter.XYs, fa, fb float64) {
// 开始数据拟合

// 实际观测点
actPoints = plotter.XYs{}
// N行数据产生N个点
for i := 0; i < clsDF.Nrow(); i++ {
document := clsDF.Elem(i, 1).Val().(float64)
machine := clsDF.Elem(i, 2).Val().(int)
val := clsDF.Elem(i, 3).Val().(float64)

actPoints = append(actPoints, plotter.XY{
X: float64(document) / float64(machine),
Y: val,
})
}

result, err := optimize.Minimize(optimize.Problem{
Func: func(x []float64) float64 {
if len(x) != 2 {
panic("illegal x")
}
a := x[0]
b := x[1]
var sum float64
for _, point := range actPoints {
y := a*point.X + b
sum += math.Abs(y - point.Y)
}
return sum
},
if err != nil {
panic(err)
}

// 最小二乘法拟合出来的k和b值
fa, fb = result.X[0], result.X[1]
expPoints = plotter.XYs{}
for i := 0; i < clsDF.Nrow(); i++ {
document := clsDF.Elem(i, 1).Val().(float64)
machine := clsDF.Elem(i, 2).Val().(int)
x := float64(document) / float64(machine)
expPoints = append(expPoints, plotter.XY{
X: x,
Y: fa*float64(x) + fb,
})
}

return
}

func dataPlot(actPoints, expPoints plotter.XYs) {
plt, err := plot.New()
if err != nil {
panic(err)
}
plt.Y.Min, plt.X.Min, plt.Y.Max, plt.X.Max = 0, 0, 10, 10

"expPoints", expPoints,
"actPoints", actPoints,
); err != nil {
panic(err)
}

if err := plt.Save(5*vg.Inch, 5*vg.Inch, "classification-fit.png"); err != nil {
panic(err)
}
}

// FitClassification 分类曲线拟合函数
func FitClassification() {
clsData, err := os.Open("classification_data.csv")
if err != nil {
log.Fatal(err)
}

defer clsData.Close()
// 数据预处理
dataPrepare(&clsDF)
// 数据预处理完成
fmt.Println("数据预处理完成...")
fmt.Println(clsDF)

// 数据拟合
actPoints, expPoints, fa, fb := dataOptimize(&clsDF)
// 拟合完成，输出fa,fb
fmt.Println("Fa", fa, "Fb", fb)

// 数据绘图
dataPlot(actPoints, expPoints)
fmt.Println("绘制完成，图形地址: classification-fit.png")
}
``````

main.go

``````go
package main

// main does nothing; the fitting pipeline is run through the test
// TestFitClassification (go test -v -run=FitClass).
func main() {}
``````

main_test.go

``````go
package main

import "testing"

// TestFitClassification runs the classification curve fit end to end by
// calling FitClassification. It makes no assertions of its own; failures
// inside FitClassification surface as log.Fatal or panic.
func TestFitClassification(t *testing.T) {
FitClassification()
}
``````

classification_data.csv

``````csv
feature,document,machine,load_time,search_time,reduce_and_save

100,5000,4,19,130,67

100,5000,4,12,130,61

100,5000,4,13,127,61

100,5000,4,13,124,63

100,5000,4,13,129,59

100,5000,4,13,125,60

100,5000,4,13,123,63

100,5000,4,13,129,61

100,5000,4,12,127,61

100,5000,4,12,125,62

100,5000,4,13,128,59

100,5000,4,13,128,61

100,5000,4,12,130,60

100,5000,4,12,125,61

100,5000,4,13,127,60

100,5000,4,13,126,63

100,5000,4,13,127,64

100,5000,3,18,160,67

100,5000,3,13,166,59

100,5000,3,12,167,61

100,5000,3,12,168,60

100,5000,3,12,170,61

100,5000,3,12,154,63

100,5000,3,13,168,60

100,5000,3,12,167,60

100,5000,3,12,148,64

100,5000,3,12,167,65

100,5000,3,12,164,60

100,5000,3,12,150,59

100,5000,2,20,217,65

100,5000,2,13,205,63

100,5000,2,14,204,60

100,5000,2,13,205,55

100,5000,2,14,210,59

100,5000,2,13,201,59

100,5000,2,13,211,59

100,5000,2,13,217,59

100,5000,2,14,207,60

100,5000,2,14,209,59

100,5000,2,14,214,61

100,5000,2,13,210,61

100,5000,1,24,376,60

100,5000,1,21,393,58

100,5000,1,20,386,58

100,5000,1,22,384,59

100,5000,1,21,387,59

100,4000,4,18,112,70

100,4000,4,13,118,62

100,4000,4,12,114,63

100,4000,4,14,112,65

100,4000,4,12,113,62

100,4000,4,14,109,61

100,4000,4,13,118,63

100,4000,4,12,112,61

100,4000,4,12,110,61

100,4000,4,11,111,63

100,4000,4,13,112,67

100,4000,4,12,110,60

100,4000,4,12,113,60

100,3000,4,19,100,66

100,3000,4,13,99,64

100,3000,4,12,100,65

100,3000,4,13,103,61

100,3000,4,14,104,63

100,3000,4,14,99,63

100,2000,4,18,90,67

100,2000,4,13,86,65

100,2000,4,13,85,63

100,2000,4,14,87,62

100,2000,4,13,85,61

100,2000,4,13,86,64

100,2000,4,13,85,61

100,2000,4,13,89,58

100,2000,4,13,85,61

100,2000,4,12,85,60

100,2000,4,13,85,66

100,2000,4,12,86,59

100,2000,4,12,86,61

100,2000,4,12,82,60

100,2000,4,13,87,62

100,2000,4,12,83,65

100,2000,4,12,85,60

100,2000,4,13,87,60

100,2000,4,12,86,59

100,1000,4,19,75,63

100,1000,4,14,71,61

100,1000,4,13,75,59

100,1000,4,14,72,61

100,1000,4,13,72,59

100,1000,4,13,71,59

100,1000,4,13,72,59

100,1000,4,14,70,62

100,1000,4,12,72,58

100,1000,4,13,71,59

100,1000,4,13,70,62

100,1000,4,12,72,59

100,1000,4,20,71,58

100,1000,4,13,69,60

100,1000,4,12,73,60

100,1000,4,13,69,59

100,1000,4,13,71,60

100,1000,4,13,73,62

100,1000,4,12,71,59

100,1000,4,12,70,56

100,1000,4,13,70,58

100,1000,4,12,69,57

``````

`go test -v -run=FitClass`

0 回复

• 请尽量让自己的回复能够对别人有帮助
• 支持 Markdown 格式, **粗体**、~~删除线~~、``单行代码``
• 支持 @ 本站用户；支持表情（输入 : 提示），见 Emoji cheat sheet
• 图片支持拖拽、截图粘贴等方式上传

const.go

``````go
package main

var (
	// ColNames lists the raw CSV columns, in the order the row-wise helpers
	// index them (getTotal reads positions 3, 4 and 5).
	ColNames = []string{
		"feature",
		"document",
		"machine",
		"load_time",
		"search_time",
		"reduce_and_save",
	}

	// ResColNames lists the columns kept after preprocessing, in final order.
	ResColNames = []string{
		"feature",
		"document",
		"machine",
		"total",
	}
)
``````

fit_classification.go

``````go
package main

import (
"fmt"
"log"
"math"
"os"

"github.com/go-gota/gota/dataframe"
"github.com/go-gota/gota/series"
"gonum.org/v1/gonum/optimize"
"gonum.org/v1/plot"
"gonum.org/v1/plot/plotter"
"gonum.org/v1/plot/plotutil"
"gonum.org/v1/plot/vg"
)

// getTotal returns a one-element Float series holding the row's total run
// time: (load_time + search_time + reduce_and_save) / 60.
//
// The row must follow ColNames order, so indices 3, 4 and 5 are load_time,
// search_time and reduce_and_save. The division by 60 presumably converts
// seconds to minutes (NOTE(review): confirm the unit of the source data).
//
// Values are read through Element.Float so both Int and Float columns work;
// an int type assertion with an ignored ok flag would silently produce 0
// whenever the row type is Float.
func getTotal(s series.Series) series.Series {
	loadTime := s.Elem(3).Float()
	searchTime := s.Elem(4).Float()
	reduceAndSave := s.Elem(5).Float()

	total := (loadTime + searchTime + reduceAndSave) / 60
	return series.Floats(total)
}

// getDoc returns a one-element Float series holding the row's document
// count scaled by 2/1000.
//
// Index 1 is the "document" column in ResColNames order. Reading it through
// Element.Float accepts both Int and Float columns instead of silently
// yielding 0 when a float64 type assertion fails.
func getDoc(s series.Series) series.Series {
	document := s.Elem(1).Float()
	return series.Floats(2 * document / 1000)
}

// dataPrepare reshapes the raw CSV frame in place into ResColNames form.
//
// It appends a "total" column computed by getTotal and replaces "document"
// with the value computed by getDoc (document*2/1000). It leaves the columns
// ordered as ResColNames (feature, document, machine, total).
//
// gota reports failures through the DataFrame's Err field rather than a
// returned error, and SetNames returns its own error. Both are checked after
// every step. The function has no error result, so it panics with the
// failing step named.
func dataPrepare(clsDF *dataframe.DataFrame) {
	// must panics, naming the step, if gota recorded an error on df.
	must := func(step string, df dataframe.DataFrame) dataframe.DataFrame {
		if df.Err != nil {
			panic(fmt.Errorf("dataPrepare: %s: %w", step, df.Err))
		}
		return df
	}

	// Add total = (load_time + search_time + reduce_and_save) / 60.
	df := must("select raw columns", clsDF.Select(ColNames))
	totalSeries := must("compute total", df.Rapply(getTotal))
	if err := totalSeries.SetNames("total"); err != nil {
		panic(fmt.Errorf("dataPrepare: name total column: %w", err))
	}
	df = must("append total", df.CBind(totalSeries))

	// Replace document with document * 2 / 1000.
	df = must("select result columns", df.Select(ResColNames))
	newDocSeries := must("compute new_doc", df.Rapply(getDoc))
	if err := newDocSeries.SetNames("new_doc"); err != nil {
		panic(fmt.Errorf("dataPrepare: name new_doc column: %w", err))
	}
	df = must("append new_doc", df.CBind(newDocSeries))
	df = must("drop document", df.Drop([]string{"document"}))
	df = must("rename new_doc", df.Rename("document", "new_doc")) // Rename(newname, oldname)
	*clsDF = must("reorder columns", df.Select(ResColNames))
}

// dataOptimize 数据优化和拟合函数
func dataOptimize(clsDF *dataframe.DataFrame) (actPoints, expPoints plotter.XYs, fa, fb float64) {
// 开始数据拟合

// 实际观测点
actPoints = plotter.XYs{}
// N行数据产生N个点
for i := 0; i < clsDF.Nrow(); i++ {
document := clsDF.Elem(i, 1).Val().(float64)
machine := clsDF.Elem(i, 2).Val().(int)
val := clsDF.Elem(i, 3).Val().(float64)

actPoints = append(actPoints, plotter.XY{
X: float64(document) / float64(machine),
Y: val,
})
}

result, err := optimize.Minimize(optimize.Problem{
Func: func(x []float64) float64 {
if len(x) != 2 {
panic("illegal x")
}
a := x[0]
b := x[1]
var sum float64
for _, point := range actPoints {
y := a*point.X + b
sum += math.Abs(y - point.Y)
}
return sum
},
if err != nil {
panic(err)
}

// 最小二乘法拟合出来的k和b值
fa, fb = result.X[0], result.X[1]
expPoints = plotter.XYs{}
for i := 0; i < clsDF.Nrow(); i++ {
document := clsDF.Elem(i, 1).Val().(float64)
machine := clsDF.Elem(i, 2).Val().(int)
x := float64(document) / float64(machine)
expPoints = append(expPoints, plotter.XY{
X: x,
Y: fa*float64(x) + fb,
})
}

return
}

func dataPlot(actPoints, expPoints plotter.XYs) {
plt, err := plot.New()
if err != nil {
panic(err)
}
plt.Y.Min, plt.X.Min, plt.Y.Max, plt.X.Max = 0, 0, 10, 10

"expPoints", expPoints,
"actPoints", actPoints,
); err != nil {
panic(err)
}

if err := plt.Save(5*vg.Inch, 5*vg.Inch, "classification-fit.png"); err != nil {
panic(err)
}
}

// FitClassification 分类曲线拟合函数
func FitClassification() {
clsData, err := os.Open("classification_data.csv")
if err != nil {
log.Fatal(err)
}

defer clsData.Close()
// 数据预处理
dataPrepare(&clsDF)
// 数据预处理完成
fmt.Println("数据预处理完成...")
fmt.Println(clsDF)

// 数据拟合
actPoints, expPoints, fa, fb := dataOptimize(&clsDF)
// 拟合完成，输出fa,fb
fmt.Println("Fa", fa, "Fb", fb)

// 数据绘图
dataPlot(actPoints, expPoints)
fmt.Println("绘制完成，图形地址: classification-fit.png")
}
``````

main.go

``````go
package main

// main does nothing; the fitting pipeline is run through the test
// TestFitClassification (go test -v -run=FitClass).
func main() {}
``````

main_test.go

``````go
package main

import "testing"

// TestFitClassification runs the classification curve fit end to end by
// calling FitClassification. It makes no assertions of its own; failures
// inside FitClassification surface as log.Fatal or panic.
func TestFitClassification(t *testing.T) {
FitClassification()
}
``````

classification_data.csv

``````csv
feature,document,machine,load_time,search_time,reduce_and_save

100,5000,4,19,130,67

100,5000,4,12,130,61

100,5000,4,13,127,61

100,5000,4,13,124,63

100,5000,4,13,129,59

100,5000,4,13,125,60

100,5000,4,13,123,63

100,5000,4,13,129,61

100,5000,4,12,127,61

100,5000,4,12,125,62

100,5000,4,13,128,59

100,5000,4,13,128,61

100,5000,4,12,130,60

100,5000,4,12,125,61

100,5000,4,13,127,60

100,5000,4,13,126,63

100,5000,4,13,127,64

100,5000,3,18,160,67

100,5000,3,13,166,59

100,5000,3,12,167,61

100,5000,3,12,168,60

100,5000,3,12,170,61

100,5000,3,12,154,63

100,5000,3,13,168,60

100,5000,3,12,167,60

100,5000,3,12,148,64

100,5000,3,12,167,65

100,5000,3,12,164,60

100,5000,3,12,150,59

100,5000,2,20,217,65

100,5000,2,13,205,63

100,5000,2,14,204,60

100,5000,2,13,205,55

100,5000,2,14,210,59

100,5000,2,13,201,59

100,5000,2,13,211,59

100,5000,2,13,217,59

100,5000,2,14,207,60

100,5000,2,14,209,59

100,5000,2,14,214,61

100,5000,2,13,210,61

100,5000,1,24,376,60

100,5000,1,21,393,58

100,5000,1,20,386,58

100,5000,1,22,384,59

100,5000,1,21,387,59

100,4000,4,18,112,70

100,4000,4,13,118,62

100,4000,4,12,114,63

100,4000,4,14,112,65

100,4000,4,12,113,62

100,4000,4,14,109,61

100,4000,4,13,118,63

100,4000,4,12,112,61

100,4000,4,12,110,61

100,4000,4,11,111,63

100,4000,4,13,112,67

100,4000,4,12,110,60

100,4000,4,12,113,60

100,3000,4,19,100,66

100,3000,4,13,99,64

100,3000,4,12,100,65

100,3000,4,13,103,61

100,3000,4,14,104,63

100,3000,4,14,99,63

100,2000,4,18,90,67

100,2000,4,13,86,65

100,2000,4,13,85,63

100,2000,4,14,87,62

100,2000,4,13,85,61

100,2000,4,13,86,64

100,2000,4,13,85,61

100,2000,4,13,89,58

100,2000,4,13,85,61

100,2000,4,12,85,60

100,2000,4,13,85,66

100,2000,4,12,86,59

100,2000,4,12,86,61

100,2000,4,12,82,60

100,2000,4,13,87,62

100,2000,4,12,83,65

100,2000,4,12,85,60

100,2000,4,13,87,60

100,2000,4,12,86,59

100,1000,4,19,75,63

100,1000,4,14,71,61

100,1000,4,13,75,59

100,1000,4,14,72,61

100,1000,4,13,72,59

100,1000,4,13,71,59

100,1000,4,13,72,59

100,1000,4,14,70,62

100,1000,4,12,72,58

100,1000,4,13,71,59

100,1000,4,13,70,62

100,1000,4,12,72,59

100,1000,4,20,71,58

100,1000,4,13,69,60

100,1000,4,12,73,60

100,1000,4,13,69,59

100,1000,4,13,71,60

100,1000,4,13,73,62

100,1000,4,12,71,59

100,1000,4,12,70,56

100,1000,4,13,70,58

100,1000,4,12,69,57

``````

`go test -v -run=FitClass`