## 引言
> 简单说下本章的重点
* 1、修改路由存储数据结构(由字典变成前缀树)
* 2、路由支持动态参数
* 3、优化Context,将动态参数的键值对存储起来
* 4、编写单元测试、执行案例测试
* 5、[代码地址 https://github.com/18211167516/go-Ebb/tree/master/day3-router](https://github.com/18211167516/go-Ebb/tree/master/day3-router)
## 1、Trie 前缀树实现
```golang
package ebb
import (
//"fmt"
"strings"
)
// node is a single segment of the routing trie.
//
// Only a node at which a registered route ends carries a non-empty pattern;
// interior nodes leave it empty, which is how search distinguishes a mere
// prefix of a route from a registered route.
type node struct {
	pattern  string  // full route registered at this node, e.g. /p/:lang; empty for interior nodes
	part     string  // the path segment this node represents, e.g. :lang
	children []*node // child segments, e.g. [doc, tutorial, intro]
	isWild   bool    // true for wildcard (non-exact) segments, i.e. part starts with ':' or '*'
}
// matchChild returns the first child whose segment is exactly part, or nil.
//
// It is used while inserting routes, so it must compare segments literally.
// Accepting a wildcard child here would make a later static route (e.g.
// /hello/world) descend into an existing /hello/:name node and overwrite its
// pattern, silently unregistering /hello/:name. Wildcard matching is a lookup
// concern handled by matchChildren.
func (n *node) matchChild(part string) *node {
	for _, child := range n.children {
		if child.part == part {
			return child
		}
	}
	return nil
}
// matchChildren returns every child that can match the request segment part;
// it is used during lookup. Children whose segment equals part exactly come
// first, followed by wildcard children (':' or '*'), each group kept in
// registration order. search tries candidates in the returned order, so a
// static route such as /hello/world takes priority over /hello/:name no
// matter which one was registered first.
func (n *node) matchChildren(part string) []*node {
	exact := make([]*node, 0, len(n.children))
	var wild []*node
	for _, child := range n.children {
		switch {
		case child.part == part:
			exact = append(exact, child)
		case child.isWild:
			wild = append(wild, child)
		}
	}
	return append(exact, wild...)
}
// parsePattern splits a route pattern on '/' into its non-empty segments.
// Splitting stops right after the first segment that begins with '*': a
// catch-all wildcard consumes the rest of the path, so anything after it is
// dropped. The result is never nil.
func parsePattern(pattern string) []string {
	segments := make([]string, 0)
	for _, seg := range strings.Split(pattern, "/") {
		if seg == "" {
			continue
		}
		segments = append(segments, seg)
		if strings.HasPrefix(seg, "*") {
			break
		}
	}
	return segments
}
// insert registers pattern in the subtree rooted at n. parts is the segment
// list produced by parsePattern(pattern), and height is the index of the
// first segment to place below n (0 when called on a root). Missing nodes
// are created along the way; the node reached after the last segment records
// the full pattern.
func (n *node) insert(pattern string, parts []string, height int) {
	cur := n
	for ; height != len(parts); height++ {
		seg := parts[height]
		next := cur.matchChild(seg)
		if next == nil {
			next = &node{part: seg, isWild: seg[0] == ':' || seg[0] == '*'}
			cur.children = append(cur.children, next)
		}
		cur = next
	}
	cur.pattern = pattern
}
// search returns the node that ends a registered route matching parts,
// starting at segment index height, or nil if none does. A node whose
// segment starts with '*' matches all remaining segments. Candidates are
// explored depth-first in the order matchChildren returns them, and the
// first node carrying a non-empty pattern wins.
func (n *node) search(parts []string, height int) *node {
	if height == len(parts) || strings.HasPrefix(n.part, "*") {
		if n.pattern != "" {
			return n
		}
		return nil
	}
	for _, child := range n.matchChildren(parts[height]) {
		if found := child.search(parts, height+1); found != nil {
			return found
		}
	}
	return nil
}
```
## 2、优化路由类
```golang
package ebb
import (
"strings"
)
// router keeps one routing trie per HTTP method plus the handler table.
// Handlers are stored under the key method + "-" + pattern.
type router struct {
	roots    map[string]*node       // HTTP method -> root node of that method's trie
	handlers map[string]HandlerFunc // "METHOD-pattern" -> handler
}
// newRouter returns a router whose trie and handler maps are allocated and
// ready for addRoute.
func newRouter() *router {
	r := new(router)
	r.roots = make(map[string]*node)
	r.handlers = make(map[string]HandlerFunc)
	return r
}
// addRoute registers handler for method and pattern. The pattern is inserted
// into the method's trie, and the root is created the first time the method
// is seen. The handler is stored under method + "-" + pattern.
func (r *router) addRoute(method string, pattern string, handler HandlerFunc) {
	root, exists := r.roots[method]
	if !exists {
		root = &node{}
		r.roots[method] = root
	}
	root.insert(pattern, parsePattern(pattern), 0)
	r.handlers[method+"-"+pattern] = handler
}
// getRoute resolves the request path pattern for method. It returns the
// matched trie node and the extracted route parameters:
//   - a ":name" segment maps name to the corresponding path segment;
//   - a named "*name" wildcard maps name to the remaining segments joined by "/".
//
// Both results are nil when the method has no routes or nothing matches.
func (r *router) getRoute(method string, pattern string) (*node, map[string]interface{}) {
	root, ok := r.roots[method]
	if !ok {
		return nil, nil
	}
	searchParts := parsePattern(pattern)
	n := root.search(searchParts, 0)
	if n == nil {
		return nil, nil
	}
	params := make(map[string]interface{})
	for i, part := range parsePattern(n.pattern) {
		switch {
		case part[0] == ':':
			params[part[1:]] = searchParts[i]
		case part[0] == '*' && len(part) > 1:
			// A named catch-all swallows every remaining segment.
			params[part[1:]] = strings.Join(searchParts[i:], "/")
			return n, params
		}
	}
	return n, params
}
```
## 3、优化核心ebb的handleHTTPRequest
```golang
// handleHTTPRequest dispatches c to the handler registered for its method and
// path. On a match, the extracted route parameters are stored in c.Params and
// the handler stored under method + "-" + matched pattern is invoked.
// Otherwise, a plain-text 404 response is written.
func (engine *Engine) handleHTTPRequest(c *Context) {
	n, params := engine.router.getRoute(c.Method, c.Path)
	if n == nil {
		c.String(http.StatusNotFound, "404 NOT FOUND: %s\n", c.Path)
		return
	}
	c.Params = params // make route parameters available to the handler
	engine.router.handlers[c.Method+"-"+n.pattern](c)
}
```
## 4、优化Context将动态参数存储以及获取
```golang
package ebb
import (
"net/http"
"fmt"
"encoding/json"
)
// H is a shorthand map type for building JSON objects,
// e.g. c.JSON(200, H{"name": value}).
type H map[string]interface{}

// Context carries the per-request state passed to every handler.
type Context struct {
	// Underlying response writer and request.
	Writer  http.ResponseWriter
	Request *http.Request

	// Request fields copied out for convenience.
	Method string
	Path   string
	// Params holds the dynamic route parameters (":name" / "*name"
	// segments) resolved for this request.
	Params map[string]interface{}
}
// newContext wraps w and r in a fresh Context. It copies the request method
// and URL path, and starts with an empty, non-nil Params map.
func newContext(w http.ResponseWriter, r *http.Request) *Context {
	c := new(Context)
	c.Writer, c.Request = w, r
	c.Path = r.URL.Path
	c.Method = r.Method
	c.Params = make(map[string]interface{})
	return c
}
// Param returns the dynamic route parameter key (e.g. "name" for a route
// declared as /hello/:name). It returns "" when the parameter is absent or
// its stored value is not a string.
func (c *Context) Param(key string) string {
	if v, ok := c.Params[key].(string); ok {
		return v
	}
	return ""
}
// PostForm returns the first form value for key using Request.FormValue.
// FormValue also consults the URL query string, so query parameters are
// returned here as well.
func (c *Context) PostForm(key string) string {
	req := c.Request
	return req.FormValue(key)
}
// Query returns the first value of URL query parameter key, or "" if it is absent.
func (c *Context) Query(key string) string {
	values := c.Request.URL.Query()
	return values.Get(key)
}
// Status sends the response status code. Per net/http semantics, only the
// first WriteHeader on a response takes effect, so call it before writing
// the body.
func (c *Context) Status(code int) {
	w := c.Writer
	w.WriteHeader(code)
}
// SetHeader sets response header key to value, replacing any existing values.
// An empty value removes the header instead.
func (c *Context) SetHeader(key string, value string) {
	header := c.Writer.Header()
	if value != "" {
		header.Set(key, value)
		return
	}
	header.Del(key)
}
// GetHeader returns the first value of request header key, or "" if it is not set.
func (c *Context) GetHeader(key string) string {
	header := c.Request.Header
	return header.Get(key)
}
// Write appends data to the response body. The byte count and error from the
// underlying ResponseWriter are discarded, because this method has no result
// through which to report them.
func (c *Context) Write(data []byte) {
	_, _ = c.Writer.Write(data)
}
// String writes a text/plain response with status code. The body is
// fmt.Sprintf(message, v...).
func (c *Context) String(code int, message string, v ...interface{}) {
	c.SetHeader("Content-Type", "text/plain")
	c.Status(code)
	body := fmt.Sprintf(message, v...)
	c.Write([]byte(body))
}
// JSON encodes obj with encoding/json and writes it as an application/json
// response with status code.
//
// obj is marshalled before anything is sent. If marshalling fails, nothing
// has been committed yet, so http.Error can still deliver a proper 500
// plain-text response, and the function returns without writing a JSON
// body. Writing the header first would make http.Error's status and headers
// ineffective.
func (c *Context) JSON(code int, obj interface{}) {
	data, err := json.Marshal(obj)
	if err != nil {
		http.Error(c.Writer, err.Error(), http.StatusInternalServerError)
		return
	}
	c.SetHeader("Content-Type", "application/json")
	c.Status(code)
	c.Write(data)
}
// HTML writes html as a text/html response with status code.
func (c *Context) HTML(code int, html string) {
	c.SetHeader("Content-Type", "text/html")
	c.Status(code)
	body := []byte(html)
	c.Write(body)
}
```
## 5、编写单元测试
```golang
package ebb
import (
	"io"
	"net/http/httptest"
	"testing"
)
// PerformRequest sends a synthetic request with the given method, URL and
// body through e, and returns the recorder holding e's response. The request
// always carries the header Content-Type: application/json.
func PerformRequest(method string, url string, body io.Reader, e *Engine) *httptest.ResponseRecorder {
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(method, url, body)
	req.Header.Set("Content-Type", "application/json")
	e.ServeHTTP(rec, req)
	return rec
}
```
```golang
package ebb
import (
"testing"
"bytes"
"encoding/json"
"github.com/stretchr/testify/assert"
)
// TestGetRoute checks that request paths resolve to the registered pattern,
// that ":name" segments are extracted as params, and that unknown methods
// yield no match.
func TestGetRoute(t *testing.T) {
	r := newRouter()
	r.addRoute("GET", "/", nil)
	r.addRoute("GET", "/hello/:name", nil)
	r.addRoute("GET", "/hello/b/c", nil)

	// The request must share the /hello prefix to reach /hello/:name.
	n, params := r.getRoute("GET", "/hello/baibai")
	if assert.NotNil(t, n, "404 not found ") {
		assert.Equal(t, "/hello/:name", n.pattern, "should match /hello/:name")
		assert.Equal(t, "baibai", params["name"], "name should be equel to 'baibai'")
	}

	// Static routes and the root route resolve too.
	n, _ = r.getRoute("GET", "/hello/b/c")
	if assert.NotNil(t, n, "404 not found ") {
		assert.Equal(t, "/hello/b/c", n.pattern, "should match /hello/b/c")
	}
	n, _ = r.getRoute("GET", "/")
	if assert.NotNil(t, n, "404 not found ") {
		assert.Equal(t, "/", n.pattern, "should match /")
	}

	// A method with no registered routes matches nothing.
	n, params = r.getRoute("POST", "/hello/baibai")
	assert.Nil(t, n, "POST has no routes")
	assert.Nil(t, params, "no params without a match")
}
// TestParsePattern checks segment splitting, including truncation after the
// first '*' wildcard. Expected values are passed first, following testify's
// assert.Equal(t, expected, actual) convention, so failure output labels
// them correctly.
func TestParsePattern(t *testing.T) {
	cases := []struct {
		pattern string
		want    []string
		msg     string
	}{
		{"/p/:name", []string{"p", ":name"}, "not parsePattern :name"},
		{"/p/*", []string{"p", "*"}, "not parsePattern *"},
		{"/p/*name/*", []string{"p", "*name"}, "parsePattern not truncation"},
		{"/", []string{}, "root pattern has no segments"},
	}
	for _, tc := range cases {
		assert.Equal(t, tc.want, parsePattern(tc.pattern), tc.msg)
	}
}
// TestRouter sends a POST through the whole engine and checks that
// Context.PostForm can read "name". The request body is JSON, which
// net/http's form parsing does not decode, so the value comes from the
// query string.
func TestRouter(t *testing.T) {
	r := New()
	r.POST("/login/*filepath", func(c *Context) {
		c.JSON(200, H{
			"name": c.PostForm("name"),
		})
	})
	param := `{"name":"56789","state":3}`
	w := PerformRequest("POST", "/login/123213?name=1233", bytes.NewBufferString(param), r)
	assert.Equal(t, 200, w.Code, "unexpected status code")

	s := struct {
		Name string `json:"name"`
	}{}
	// Check the decode error: otherwise a malformed body leaves s.Name empty
	// and is misreported as a PostForm failure.
	if assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &s), "response is not valid JSON") {
		assert.Equal(t, "1233", s.Name, "PostForm error")
	}
}
```
## 6、在本章节我们主要优化了路由,并支持动态参数
> 接下来我们会陆续实现路由的分组、中间件等等
有疑问加站长微信联系(非本文作者)