使用golang做最小二乘法的线性拟合

作者:FredricZhu · 105 次点击

const.go

package main

// ColNames lists the raw CSV columns that dataPrepare keeps before it
// computes the per-row "total" column.
var ColNames = []string{
	"feature",
	"document",
	"machine",
	"load_time",
	"search_time",
	"reduce_and_save",
}

// ResColNames is the column layout dataPrepare leaves behind: the three
// key columns plus the computed "total".
var ResColNames = []string{"feature", "document", "machine", "total"}

fit_classification.go

package main

import (
    "fmt"
    "log"
    "math"
    "os"

    "github.com/go-gota/gota/dataframe"
    "github.com/go-gota/gota/series"
    "gonum.org/v1/gonum/optimize"
    "gonum.org/v1/plot"
    "gonum.org/v1/plot/plotter"
    "gonum.org/v1/plot/plotutil"
    "gonum.org/v1/plot/vg"
)

// getTotal computes the per-row "total" value used as the fit target.
//
// It is applied row by row through DataFrame.Rapply on a frame whose
// columns are ColNames. Positions 3, 4 and 5 are therefore load_time,
// search_time and reduce_and_save. The three values are summed and divided
// by 60.
// NOTE(review): the /60 suggests the raw timings are seconds and the
// result is minutes — confirm against the data source.
//
// Values are read with Element.Float instead of a type assertion to int.
// gota's Rapply casts each row to a single series type, so a row that
// contains any float column would make an `.(int)` assertion fail. With the
// ok result discarded, that failure silently yields 0. Float accepts both
// int and float rows.
func getTotal(s series.Series) series.Series {
	loadTime := s.Elem(3).Float()
	searchTime := s.Elem(4).Float()
	rAsTime := s.Elem(5).Float()

	return series.Floats((loadTime + searchTime + rAsTime) / 60)
}

// getDoc rescales the document value of one row to document * 2 / 1000.
//
// It is applied through DataFrame.Rapply on a frame laid out as ResColNames,
// so position 1 is the document column. Element.Float is used instead of a
// `.(float64)` assertion so the result is still correct when gota casts the
// row to an int series. An unchecked assertion would silently give 0 there.
func getDoc(s series.Series) series.Series {
	document := s.Elem(1).Float()
	return series.Floats(2 * document / 1000)
}

// dataPrepare reshapes the raw CSV frame in place into the ResColNames
// layout (feature, document, machine, total):
//
//  1. keep only the ColNames columns;
//  2. append a "total" column computed row-wise by getTotal;
//  3. keep only ResColNames;
//  4. replace document with getDoc's rescaled value (document*2/1000),
//     keeping the column name "document" and the ResColNames order.
//
// gota does not return errors from most DataFrame operations; it records
// them on DataFrame.Err. SetNames does return an error. Both are checked,
// so a missing column or malformed input panics here instead of surfacing
// later as an empty or meaningless fit.
func dataPrepare(clsDF *dataframe.DataFrame) {
	check := func(err error) {
		if err != nil {
			panic(err)
		}
	}

	// Compute the total column.
	*clsDF = clsDF.Select(ColNames)
	totalSeries := clsDF.Rapply(getTotal)
	check(totalSeries.SetNames("total"))
	*clsDF = clsDF.CBind(totalSeries)

	// Rescale document (*2/1000) via a temporary "new_doc" column.
	*clsDF = clsDF.Select(ResColNames)
	newDocSeries := clsDF.Rapply(getDoc)
	check(newDocSeries.SetNames("new_doc"))
	*clsDF = clsDF.CBind(newDocSeries)
	*clsDF = clsDF.Drop([]string{"document"})
	// gota's Rename signature is (newname, oldname): new_doc -> document.
	*clsDF = clsDF.Rename("document", "new_doc")
	*clsDF = clsDF.Select(ResColNames)

	check(clsDF.Err)
}

// dataOptimize fits the line total ≈ a*x + b by least squares. For each row,
// x = document / machine and the observed value is the total column.
//
// clsDF is expected in the layout produced by dataPrepare. Column 1 is the
// float64 document value, column 2 the int machine count and column 3 the
// float64 total. Any other element type makes the type assertions panic.
//
// Returns:
//   - actPoints: the observed (x, total) points, one per row;
//   - expPoints: the fitted line evaluated at each observed x;
//   - fa, fb:    the fitted slope and intercept.
//
// The minimization uses gonum's Nelder–Mead, starting from (1, 1). It panics
// if the optimizer reports an error.
func dataOptimize(clsDF *dataframe.DataFrame) (actPoints, expPoints plotter.XYs, fa, fb float64) {
	n := clsDF.Nrow()

	// Observed points: one per row.
	actPoints = make(plotter.XYs, 0, n)
	for i := 0; i < n; i++ {
		document := clsDF.Elem(i, 1).Val().(float64)
		machine := clsDF.Elem(i, 2).Val().(int)
		total := clsDF.Elem(i, 3).Val().(float64)

		actPoints = append(actPoints, plotter.XY{
			X: document / float64(machine),
			Y: total,
		})
	}

	result, err := optimize.Minimize(optimize.Problem{
		// Sum of SQUARED residuals of a*x + b. Summing absolute residuals
		// would give a least-absolute-deviations fit, not least squares.
		Func: func(x []float64) float64 {
			if len(x) != 2 {
				panic("illegal x")
			}
			a, b := x[0], x[1]
			var sum float64
			for _, p := range actPoints {
				sum += math.Pow(a*p.X+b-p.Y, 2)
			}
			return sum
		},
	}, []float64{1, 1}, &optimize.Settings{}, &optimize.NelderMead{})
	if err != nil {
		panic(err)
	}

	// Least-squares slope (k) and intercept (b).
	fa, fb = result.X[0], result.X[1]

	// Evaluate the fitted line at every observed x; the x values are the
	// ones already computed for actPoints.
	expPoints = make(plotter.XYs, len(actPoints))
	for i, p := range actPoints {
		expPoints[i] = plotter.XY{X: p.X, Y: fa*p.X + fb}
	}

	return actPoints, expPoints, fa, fb
}

// dataPlot draws the fitted points (expPoints) and the observed points
// (actPoints) as line-and-point series on a 5in x 5in chart. Both axes are
// fixed to [0, 10]. The chart is written to classification-fit.png in the
// working directory. Any plotting or saving error panics.
func dataPlot(actPoints, expPoints plotter.XYs) {
	p, err := plot.New()
	if err != nil {
		panic(err)
	}

	p.X.Min, p.X.Max = 0, 10
	p.Y.Min, p.Y.Max = 0, 10

	err = plotutil.AddLinePoints(p, "expPoints", expPoints, "actPoints", actPoints)
	if err != nil {
		panic(err)
	}

	side := 5 * vg.Inch
	if err = p.Save(side, side, "classification-fit.png"); err != nil {
		panic(err)
	}
}

// FitClassification runs the whole pipeline:
//
//  1. read classification_data.csv from the working directory;
//  2. reshape it with dataPrepare and print the prepared frame;
//  3. fit the line with dataOptimize and print the coefficients;
//  4. plot the result to classification-fit.png with dataPlot.
//
// Failing to open or parse the CSV is fatal (log.Fatal). Later stages panic
// on error.
func FitClassification() {
	clsData, err := os.Open("classification_data.csv")
	if err != nil {
		log.Fatal(err)
	}
	defer clsData.Close()

	clsDF := dataframe.ReadCSV(clsData)
	// ReadCSV reports parse failures on DataFrame.Err instead of returning
	// them. Stop here rather than fitting an invalid frame.
	if clsDF.Err != nil {
		log.Fatal(clsDF.Err)
	}

	// Data preparation.
	dataPrepare(&clsDF)
	fmt.Println("数据预处理完成...")
	fmt.Println(clsDF)

	// Fit, then report the slope (fa) and intercept (fb).
	actPoints, expPoints, fa, fb := dataOptimize(&clsDF)
	fmt.Println("Fa", fa, "Fb", fb)

	// Plot.
	dataPlot(actPoints, expPoints)
	fmt.Println("绘制完成,图形地址: classification-fit.png")
}

main.go

package main

// main is intentionally empty. The fitting pipeline is run from
// TestFitClassification (go test -v -run=FitClass).
func main() {}

main_test.go

package main

import "testing"

// TestFitClassification runs the classification curve-fitting pipeline
// end to end by calling FitClassification.
// NOTE(review): FitClassification reads classification_data.csv from the
// working directory and calls log.Fatal if it cannot be opened, which exits
// the whole test binary rather than failing only this test.
func TestFitClassification(t *testing.T) {
    FitClassification()
}

运行数据

feature,document,machine,load_time,search_time,reduce_and_save

100,5000,4,19,130,67

100,5000,4,12,130,61

100,5000,4,13,127,61

100,5000,4,13,124,63

100,5000,4,13,129,59

100,5000,4,13,125,60

100,5000,4,13,123,63

100,5000,4,13,129,61

100,5000,4,12,127,61

100,5000,4,12,125,62

100,5000,4,13,128,59

100,5000,4,13,128,61

100,5000,4,12,130,60

100,5000,4,12,125,61

100,5000,4,13,127,60

100,5000,4,13,126,63

100,5000,4,13,127,64

100,5000,3,18,160,67

100,5000,3,13,166,59

100,5000,3,12,167,61

100,5000,3,12,168,60

100,5000,3,12,170,61

100,5000,3,12,154,63

100,5000,3,13,168,60

100,5000,3,12,167,60

100,5000,3,12,148,64

100,5000,3,12,167,65

100,5000,3,12,164,60

100,5000,3,12,150,59

100,5000,2,20,217,65

100,5000,2,13,205,63

100,5000,2,14,204,60

100,5000,2,13,205,55

100,5000,2,14,210,59

100,5000,2,13,201,59

100,5000,2,13,211,59

100,5000,2,13,217,59

100,5000,2,14,207,60

100,5000,2,14,209,59

100,5000,2,14,214,61

100,5000,2,13,210,61

100,5000,1,24,376,60

100,5000,1,21,393,58

100,5000,1,20,386,58

100,5000,1,22,384,59

100,5000,1,21,387,59

100,4000,4,18,112,70

100,4000,4,13,118,62

100,4000,4,12,114,63

100,4000,4,14,112,65

100,4000,4,12,113,62

100,4000,4,14,109,61

100,4000,4,13,118,63

100,4000,4,12,112,61

100,4000,4,12,110,61

100,4000,4,11,111,63

100,4000,4,13,112,67

100,4000,4,12,110,60

100,4000,4,12,113,60

100,3000,4,19,100,66

100,3000,4,13,99,64

100,3000,4,12,100,65

100,3000,4,13,103,61

100,3000,4,14,104,63

100,3000,4,14,99,63

100,2000,4,18,90,67

100,2000,4,13,86,65

100,2000,4,13,85,63

100,2000,4,14,87,62

100,2000,4,13,85,61

100,2000,4,13,86,64

100,2000,4,13,85,61

100,2000,4,13,89,58

100,2000,4,13,85,61

100,2000,4,12,85,60

100,2000,4,13,85,66

100,2000,4,12,86,59

100,2000,4,12,86,61

100,2000,4,12,82,60

100,2000,4,13,87,62

100,2000,4,12,83,65

100,2000,4,12,85,60

100,2000,4,13,87,60

100,2000,4,12,86,59

100,1000,4,19,75,63

100,1000,4,14,71,61

100,1000,4,13,75,59

100,1000,4,14,72,61

100,1000,4,13,72,59

100,1000,4,13,71,59

100,1000,4,13,72,59

100,1000,4,14,70,62

100,1000,4,12,72,58

100,1000,4,13,71,59

100,1000,4,13,70,62

100,1000,4,12,72,59

100,1000,4,20,71,58

100,1000,4,13,69,60

100,1000,4,12,73,60

100,1000,4,13,69,59

100,1000,4,13,71,60

100,1000,4,13,73,62

100,1000,4,12,71,59

100,1000,4,12,70,56

100,1000,4,13,70,58

100,1000,4,12,69,57

运行方法
go test -v -run=FitClass
最终输出的拟合系数与 scipy 的拟合结果基本一致。

程序输出


(原文此处为程序输出截图,图片未能保留)

输出的拟合图像如下


(原文此处为拟合图像 classification-fit.png,图片未能保留)

有疑问加站长微信联系(非本文作者)

本文来自:简书

感谢作者:FredricZhu

查看原文:使用golang做最小二乘法的线性拟合

入群交流(和以上内容无关):加入Go大咖交流群,或添加微信:liuxiaoyan-s 备注:入群;或加QQ群:1006366459

105 次点击  
加入收藏 微博
暂无回复
添加一条新回复 (您需要 登录 后才能回复 没有账号 ?)
  • 请尽量让自己的回复能够对别人有帮助
  • 支持 Markdown 格式, **粗体**、~~删除线~~、`单行代码`
  • 支持 @ 本站用户;支持表情(输入 : 提示),见 Emoji cheat sheet
  • 图片支持拖拽、截图粘贴等方式上传