教你如何搭建自己的go-gin框架(三 动态路由设计)

18211167516 · 2020-08-10 14:55:15 · 1676 次点击 · 预计阅读时间 8 分钟 · 大约8小时之前 开始浏览    
这是一个创建于 2020-08-10 14:55:15 的文章,其中的信息可能已经有所发展或是发生改变。

引言

简单说下本章的重点:用 Trie(前缀树)实现动态路由,支持 `:name` 形式的参数匹配和 `*filepath` 形式的通配匹配,并把匹配到的参数保存到 Context 中。

1、Trie 前缀树实现

package ebb

import (
    //"fmt"
    "strings"
)

// node is one vertex of the routing trie. Each node represents a single
// path segment; a node with a non-empty pattern terminates a registered route.
type node struct {
	pattern  string  // full route pattern if a route ends here, e.g. /p/:lang; empty otherwise
	part     string  // the path segment this node represents, e.g. :lang
	children []*node // child nodes, e.g. [doc, tutorial, intro]
	isWild   bool    // true (wildcard/fuzzy match) when part starts with ':' or '*'
}

// matchChild returns the child whose segment is exactly part, or nil.
// It is used while inserting routes, so wildcard children must NOT count as
// a match here: otherwise a static route registered after a wildcard sibling
// (e.g. /hello/b/c after /hello/:name, or /p/new after /p/:id) would be
// grafted under the wildcard node and could overwrite its pattern.
func (n *node) matchChild(part string) *node {
	for _, child := range n.children {
		if child.part == part {
			return child
		}
	}
	return nil
}

// matchChildren returns every child that can match part during lookup.
// Children whose segment equals part come first and wildcard children
// after, so a static route takes precedence over a parameterised sibling.
func (n *node) matchChildren(part string) []*node {
	exact := make([]*node, 0, len(n.children))
	var wild []*node
	for _, child := range n.children {
		switch {
		case child.part == part:
			exact = append(exact, child)
		case child.isWild:
			wild = append(wild, child)
		}
	}
	return append(exact, wild...)
}

// parsePattern splits a route pattern (or request path) on "/" into its
// non-empty segments. Parsing stops after the first segment starting with
// '*', because a catch-all wildcard consumes the rest of the path:
// "/p/*name/*" yields ["p", "*name"]. The result is never nil.
func parsePattern(pattern string) []string {
	vs := strings.Split(pattern, "/")
	parts := make([]string, 0)
	for _, item := range vs {
		if item == "" {
			continue
		}
		parts = append(parts, item)
		if item[0] == '*' {
			break
		}
	}
	return parts
}

// insert registers pattern, already split into parts, below n, creating
// nodes as needed. height is the index of the segment handled at this
// level; the node reached after consuming every part records the pattern.
func (n *node) insert(pattern string, parts []string, height int) {
	if len(parts) == height {
		n.pattern = pattern
		return
	}

	part := parts[height]
	child := n.matchChild(part)
	if child == nil {
		child = &node{part: part, isWild: part[0] == ':' || part[0] == '*'}
		n.children = append(n.children, child)
	}
	child.insert(pattern, parts, height+1)
}

// search looks for a registered route matching parts (the segments of a
// request path) and returns its terminal node, or nil if none matches.
// A '*' node matches all remaining segments. The search backtracks: if one
// candidate child leads to a dead end, the next candidate is tried.
func (n *node) search(parts []string, height int) *node {
	if len(parts) == height || strings.HasPrefix(n.part, "*") {
		// Reaching the end of the path only counts if a route ends here.
		if n.pattern == "" {
			return nil
		}
		return n
	}

	part := parts[height]
	for _, child := range n.matchChildren(part) {
		if result := child.search(parts, height+1); result != nil {
			return result
		}
	}
	return nil
}

2、优化路由类:按请求方法(GET、POST 等)各维护一棵 Trie,并在匹配成功后解析动态参数

package ebb

import (
    "strings"
)

// router dispatches requests to handlers. It keeps one routing trie per HTTP
// method plus a handler table keyed by "METHOD-pattern".
type router struct {
	roots    map[string]*node       // method -> root of that method's trie
	handlers map[string]HandlerFunc // "METHOD-pattern" -> handler
}

// newRouter returns a router whose maps are allocated and ready for use.
func newRouter() *router {
	r := &router{}
	r.roots = make(map[string]*node)
	r.handlers = make(map[string]HandlerFunc)
	return r
}

// addRoute registers handler for method and pattern, creating the method's
// trie root on first use.
func (r *router) addRoute(method string, pattern string, handler HandlerFunc) {
	root, exists := r.roots[method]
	if !exists {
		root = &node{}
		r.roots[method] = root
	}
	root.insert(pattern, parsePattern(pattern), 0)
	r.handlers[method+"-"+pattern] = handler
}

// getRoute matches path against the trie for method. On success it returns
// the matched node together with the extracted parameters: a ":name" segment
// maps to the corresponding path segment, and a "*name" segment maps to the
// rest of the path joined with "/". It returns nil, nil when the method has
// no routes or nothing matches.
func (r *router) getRoute(method string, path string) (*node, map[string]interface{}) {
	root, found := r.roots[method]
	if !found {
		return nil, nil
	}

	segments := parsePattern(path)
	matched := root.search(segments, 0)
	if matched == nil {
		return nil, nil
	}

	params := make(map[string]interface{})
	for i, part := range parsePattern(matched.pattern) {
		switch {
		case part[0] == ':':
			params[part[1:]] = segments[i]
		case part[0] == '*' && len(part) > 1:
			// A catch-all is always the last part; it swallows the remainder.
			params[part[1:]] = strings.Join(segments[i:], "/")
			return matched, params
		}
	}
	return matched, params
}

3、优化核心引擎 ebb 的 handleHTTPRequest:匹配成功时把参数存入 Context 并调用对应处理函数,否则返回 404

// handleHTTPRequest resolves the route for c's method and path. On a match
// it stores the extracted route parameters on c and invokes the handler
// registered under "METHOD-pattern"; otherwise it replies with a plain-text 404.
func (engine *Engine) handleHTTPRequest(c *Context) {
	n, params := engine.router.getRoute(c.Method, c.Path)
	if n == nil {
		c.String(http.StatusNotFound, "404 NOT FOUND: %s\n", c.Path)
		return
	}

	c.Params = params
	engine.router.handlers[c.Method+"-"+n.pattern](c)
}

4、优化 Context:存储动态路由参数,并提供 Param 方法按名称获取

package ebb


import (
    "net/http"
    "fmt"
    "encoding/json"
)

// H is a shorthand for building JSON objects, e.g. c.JSON(200, H{"name": "x"}).
type H map[string]interface{}

// Context carries the state of a single HTTP request: the response writer,
// the incoming request, and request metadata extracted once up front.
type Context struct {
	// Response writer and incoming request.
	Writer  http.ResponseWriter
	Request *http.Request
	// Request info, copied from Request when the Context is created.
	Method string
	Path   string
	// Params holds dynamic route parameters (":name" / "*filepath"),
	// keyed by name without the leading ':' or '*'.
	Params map[string]interface{}
}

// newContext wraps w and r, caching the request method and URL path and
// starting with an empty parameter map.
func newContext(w http.ResponseWriter, r *http.Request) *Context {
	return &Context{
		Writer:  w,
		Request: r,
		Path:    r.URL.Path,
		Method:  r.Method,
		Params:  make(map[string]interface{}),
	}
}

// Param returns the route parameter named key, or "" if it is absent or
// not a string.
func (c *Context) Param(key string) string {
	value, _ := c.Params[key].(string)
	return value
}

// PostForm returns the first value for key as reported by
// Request.FormValue (urlencoded/multipart form body or URL query).
func (c *Context) PostForm(key string) string {
	return c.Request.FormValue(key)
}

// Query returns the first value for key from the URL query string.
func (c *Context) Query(key string) string {
	return c.Request.URL.Query().Get(key)
}

// Status writes the response status code via WriteHeader.
func (c *Context) Status(code int) {
	c.Writer.WriteHeader(code)
}

// SetHeader sets a response header; an empty value deletes the header.
func (c *Context) SetHeader(key string, value string) {
	if value == "" {
		c.Writer.Header().Del(key)
		return
	}
	c.Writer.Header().Set(key, value)
}

// GetHeader returns the first value of the request header key.
func (c *Context) GetHeader(key string) string {
	return c.Request.Header.Get(key)
}

// Write writes data to the response body; the write error is not reported.
func (c *Context) Write(data []byte) {
	c.Writer.Write(data)
}

// String formats message with v (fmt.Sprintf semantics) and writes it as
// text/plain with the given status code.
func (c *Context) String(code int, message string, v ...interface{}) {
	c.SetHeader("Content-Type", "text/plain")
	c.Status(code)
	c.Write([]byte(fmt.Sprintf(message, v...)))
}

// JSON marshals obj and writes it as application/json with the given status
// code. Marshalling happens before any header is written, so if obj cannot
// be encoded a plain-text 500 error is sent instead (and code is not used).
func (c *Context) JSON(code int, obj interface{}) {
	data, err := json.Marshal(obj)
	if err != nil {
		http.Error(c.Writer, err.Error(), http.StatusInternalServerError)
		return
	}
	c.SetHeader("Content-Type", "application/json")
	c.Status(code)
	c.Write(data)
}

// HTML writes html verbatim as text/html with the given status code; the
// content is not escaped.
func (c *Context) HTML(code int, html string) {
	c.SetHeader("Content-Type", "text/html")
	c.Status(code)
	c.Write([]byte(html))
}

5、编写单元测试(下面包含两个文件:请求辅助函数 PerformRequest,以及路由相关的测试用例)

package ebb

import (
	"io"
	"net/http/httptest"
	"testing"
)


// PerformRequest builds a request with the given method, URL and body,
// marks it as JSON, runs it through e and returns the recorded response.
func PerformRequest(method string, url string, body io.Reader, e *Engine) *httptest.ResponseRecorder {
	req := httptest.NewRequest(method, url, body)
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, req)
	return rec
}
package ebb


import (
    "testing"
    "bytes"
    "encoding/json"
    "github.com/stretchr/testify/assert"
)

// TestGetRoute checks that a concrete path matches the parameterised route
// /hello/:name and that the ":name" segment is captured into params.
func TestGetRoute(t *testing.T) {
	r := newRouter()
	r.addRoute("GET", "/", nil)
	r.addRoute("GET", "/hello/:name", nil)
	r.addRoute("GET", "/hello/b/c", nil)
	// The path must live under /hello for /hello/:name to be able to match.
	n, params := r.getRoute("GET", "/hello/baibai")

	if assert.NotNil(t, n, "404 not found ") {
		assert.Equal(t, "/hello/:name", n.pattern, "should match /hello/:name")
		assert.Equal(t, "baibai", params["name"], "name should be equal to 'baibai'")
	}
}
// TestParsePattern checks segment splitting and that parsing stops after
// the first '*' segment. testify's Equal takes (expected, actual).
func TestParsePattern(t *testing.T) {
	assert.Equal(t, []string{"p", ":name"}, parsePattern("/p/:name"), "not parsePattern :name")
	assert.Equal(t, []string{"p", "*"}, parsePattern("/p/*"), "not parsePattern *")
	assert.Equal(t, []string{"p", "*name"}, parsePattern("/p/*name/*"), "parsePattern not truncation")
}
// TestRouter sends a POST through the engine to a catch-all route and checks
// that the handler can read "name" via PostForm. The value comes from the URL
// query (?name=1233): the request body is JSON, not form-encoded, so
// Request.FormValue does not see the body's "name" field.
func TestRouter(t *testing.T) {
	r := New()

	r.POST("/login/*filepath", func(c *Context) {
		c.JSON(200, H{
			"name": c.PostForm("name"),
		})
	})

	param := `{"name":"56789","state":3}`

	w := PerformRequest("POST", "/login/123213?name=1233", bytes.NewBufferString(param), r)

	assert.Equal(t, 200, w.Code, "unexpected status code")

	s := struct {
		Name string `json:"name"`
	}{}
	// Only inspect the field if the response actually decoded.
	if assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &s), "response is not valid JSON") {
		assert.Equal(t, "1233", s.Name, "PostForm error")
	}
}

6、小结:本章主要优化了路由,使其支持动态参数。

接下来我们会陆续实现路由分组、中间件等功能。


有疑问加站长微信联系(非本文作者)

入群交流(和以上内容无关):加入Go大咖交流群,或添加微信:liuxiaoyan-s 备注:入群;或加QQ群:692541889

1676 次点击  
加入收藏 微博
暂无回复
添加一条新回复 (您需要 登录 后才能回复 没有账号 ?)
  • 请尽量让自己的回复能够对别人有帮助
  • 支持 Markdown 格式, **粗体**、~~删除线~~、`单行代码`
  • 支持 @ 本站用户;支持表情(输入 : 提示),见 Emoji cheat sheet
  • 图片支持拖拽、截图粘贴等方式上传