本程序直接基于 net 包处理 TCP 连接，实现了基本的 GET、POST 路由，不依赖 net/http 包。
程序仅支持 GET、POST 两种请求方式；只解析 Content-Type 为 application/json 的 POST 参数，返回的数据也仅支持 JSON 格式。
router.go:注册路由、启动服务
package http_server
import (
"fmt"
"net"
)
type (
	// handlerFunc is the signature of a route handler.
	handlerFunc func(*Request)

	// methodTree holds every route registered for one HTTP method.
	methodTree struct {
		method string
		nodes  []Node
	}

	// Node is a single route: an exact request path and its handler.
	Node struct {
		path   string
		handle handlerFunc
	}

	// Router dispatches requests by HTTP method and exact path. Create one
	// with Default, register routes with GET/POST, then call Run.
	Router struct {
		Trees []methodTree
	}
)

// MaxRequestSize is the size in bytes of the per-connection read buffer,
// i.e. the largest single request that is accepted: 40 KiB.
const MaxRequestSize = 40 << 10
// Default returns an empty Router with no routes registered.
func Default() *Router {
	router := new(Router)
	return router
}
// Run listens for TCP (IPv4) connections on addr and serves each accepted
// connection on its own goroutine. It returns only when the listener cannot
// be created; Accept failures are printed and the loop keeps going.
func (r *Router) Run(addr string) error {
	ln, listenErr := net.Listen("tcp4", addr)
	if listenErr != nil {
		return fmt.Errorf("listen error:%v", listenErr)
	}
	for {
		conn, acceptErr := ln.Accept()
		if acceptErr == nil {
			go r.handle(conn)
			continue
		}
		fmt.Println(acceptErr)
	}
}
// handle serves one client connection.
//
// This goroutine reads the byte stream and splits it into requests. A second
// goroutine (parseRequest) routes and answers them in arrival order. When
// reading stops, handle does three things in order:
//   - it waits until every request already handed over has been answered;
//   - if reading failed (malformed or oversized request), it answers 400;
//   - it closes the connection.
// Reading stops when the peer closes the connection or sends a malformed or
// oversized request. Waiting first guarantees that no response is written
// to a closed connection or interleaved with the 400 response.
func (r *Router) handle(conn net.Conn) {
	defer conn.Close()

	accept := make(chan readerData)
	done := make(chan struct{})
	go func() {
		defer close(done)
		r.parseRequest(accept, conn)
	}()

	reader := NewReader(conn, MaxRequestSize)
	err := reader.read(accept)

	// The sending side closes the channel; then wait for the consumer to
	// finish writing its responses before touching the connection again.
	close(accept)
	<-done

	if err != nil {
		fmt.Println(err)
		// Reading failed: respond 400 Bad Request.
		newResponse(conn).errWrite(400)
	}
}
// parseRequest consumes parsed requests from accept until the channel is
// closed. For each one it builds a Request bound to conn and dispatches it
// to the matching route. Requests are answered one at a time, in order.
func (r *Router) parseRequest(accept chan readerData, conn net.Conn) {
	for data := range accept {
		r.serveOne(data, conn)
	}
}

// serveOne builds, routes and answers a single request. An unrecovered
// panic in any goroutine terminates the whole program. A panic while
// parsing the request or inside a route handler is therefore recovered here
// and answered with 500 Internal Server Error, so other connections keep
// being served.
func (r *Router) serveOne(data readerData, conn net.Conn) {
	request := newRequest()
	request.response = newResponse(conn)
	defer func() {
		if p := recover(); p != nil {
			fmt.Println("panic while handling request:", p)
			request.response.errWrite(500)
		}
	}()
	request.parse(data)
	r.handleHTTPRequest(request)
}
// handleHTTPRequest finds the route whose method and path exactly match the
// request. If one is found, it runs that route's handler and writes the
// response; otherwise it answers 404.
func (r *Router) handleHTTPRequest(request *Request) {
	for i := range r.Trees {
		tree := &r.Trees[i]
		if tree.method == request.Method {
			for j := range tree.nodes {
				if tree.nodes[j].path == request.Path {
					tree.nodes[j].handle(request)
					request.response.write()
					return
				}
			}
		}
	}
	// No route matched: 404 Not Found.
	request.response.errWrite(404)
}
// POST registers handle for POST requests whose path equals path exactly
// (the query string is not part of the match).
func (r *Router) POST(path string, handle handlerFunc) {
	const method = "POST"
	r.addRoute(method, path, handle)
}
// GET registers handle for GET requests whose path equals path exactly
// (the query string is not part of the match).
func (r *Router) GET(path string, handle handlerFunc) {
	const method = "GET"
	r.addRoute(method, path, handle)
}
// addRoute appends a route to the tree for method, creating that tree on
// first use. If the same path is registered twice, the earlier registration
// is found first during dispatch.
func (r *Router) addRoute(method string, path string, handle handlerFunc) {
	node := Node{path: path, handle: handle}
	for i := range r.Trees {
		if r.Trees[i].method == method {
			r.Trees[i].nodes = append(r.Trees[i].nodes, node)
			return
		}
	}
	r.Trees = append(r.Trees, methodTree{method: method, nodes: []Node{node}})
}
reader.go:读取并解析HTTP请求行、请求头、请求体
package http_server
import (
"bytes"
"fmt"
"net"
"strconv"
"strings"
)
// Reader splits the byte stream of one connection into HTTP requests.
// Received bytes are kept in a fixed-size buffer, so a single request
// (request line, headers and body together) must fit into it entirely.
type Reader struct {
	conn net.Conn
	// readerData holds the most recently parsed request.
	readerData
	buff    []byte // receive buffer; unconsumed bytes are buff[start:end]
	buffLen int    // len(buff): the largest request that can be accepted
	start   int    // offset of the first byte not yet consumed
	end     int    // offset one past the last byte received
}

// readerData is one parsed request.
type readerData struct {
	Line   map[string]string // request line: keys "method", "url", "version"
	Header map[string]string // header fields: names upper-cased, values lower-cased
	Body   string            // request body (read for POST requests only)
}

// NewReader returns a Reader for conn whose buffer holds buffLen bytes.
func NewReader(conn net.Conn, buffLen int) *Reader {
	return &Reader{
		conn: conn,
		readerData: readerData{
			Line:   make(map[string]string),
			Header: make(map[string]string),
		},
		buffLen: buffLen,
		buff:    make([]byte, buffLen),
	}
}

// parseLine splits a request line such as "GET /path?x=1 HTTP/1.1"
// (without its CRLF) into its method, url and version.
func (reader *Reader) parseLine(requestLine string) (map[string]string, error) {
	arr := strings.Split(requestLine, " ")
	if len(arr) != 3 {
		return nil, fmt.Errorf("bad request line")
	}
	return map[string]string{
		"method":  arr[0],
		"url":     arr[1],
		"version": arr[2],
	}, nil
}

// parseHeader parses CRLF-separated "Name: value" fields. Each value is
// taken after the first colon. Names are upper-cased. Values are trimmed of
// surrounding spaces and tabs, then lower-cased. Lines without a colon are
// ignored.
func (reader *Reader) parseHeader(raw string) map[string]string {
	header := make(map[string]string)
	for _, field := range strings.Split(raw, "\r\n") {
		colon := strings.Index(field, ":")
		if colon == -1 {
			continue
		}
		header[strings.ToUpper(field[:colon])] = strings.ToLower(strings.Trim(field[colon+1:], " \t"))
	}
	return header
}

// parseBody extracts the body announced by the Content-Length header from
// rest, the bytes that follow the header section. While the body has not
// been fully received it returns isOk=false with a nil error.
func (reader *Reader) parseBody(header map[string]string, rest []byte) (body string, isOk bool, err error) {
	contentLenStr, ok := header["CONTENT-LENGTH"]
	if !ok {
		return "", false, fmt.Errorf("bad request:no content-length")
	}
	contentLen, err := strconv.Atoi(contentLenStr)
	if err != nil || contentLen < 0 {
		return "", false, fmt.Errorf("parse content-length error:%s", contentLenStr)
	}
	if contentLen > reader.buffLen {
		// Such a body can never fit into the buffer; fail now rather than
		// after the buffer has filled up.
		return "", false, fmt.Errorf("request is too large: content-length %d exceeds %d bytes", contentLen, reader.buffLen)
	}
	if contentLen > len(rest) {
		return "", false, nil
	}
	return string(rest[:contentLen]), true, nil
}

// parseNext tries to parse one complete request from buff[start:end].
//
// On success it does three things and returns true:
//   - stores the request in reader.readerData, using freshly allocated maps
//     so a value already sent on the channel is never mutated again;
//   - advances start past the request;
//   - leaves any following bytes (e.g. pipelined requests) in place.
//
// It returns false with a nil error when more bytes are needed. A malformed
// request returns an error.
func (reader *Reader) parseNext() (isOk bool, err error) {
	data := reader.buff[reader.start:reader.end]

	lineEnd := bytes.Index(data, []byte("\r\n"))
	if lineEnd == -1 {
		return false, nil // request line not complete yet
	}
	line, err := reader.parseLine(string(data[:lineEnd]))
	if err != nil {
		return false, fmt.Errorf("parse request line error:%w", err)
	}

	// The header section ends with an empty line. Searching from lineEnd
	// (not lineEnd+2) also finds it when the request has no header fields.
	blank := bytes.Index(data[lineEnd:], []byte("\r\n\r\n"))
	if blank == -1 {
		return false, nil // headers not complete yet
	}
	headerEnd := lineEnd + blank
	rawHeader := ""
	if headerEnd > lineEnd {
		rawHeader = string(data[lineEnd+2 : headerEnd])
	}
	header := reader.parseHeader(rawHeader)
	consumed := headerEnd + 4

	var body string
	if strings.EqualFold(line["method"], "POST") {
		var complete bool
		body, complete, err = reader.parseBody(header, data[consumed:])
		if err != nil {
			return false, fmt.Errorf("parse request body error:%w", err)
		}
		if !complete {
			return false, nil
		}
		consumed += len(body)
	}

	reader.readerData = readerData{Line: line, Header: header, Body: body}
	reader.start += consumed
	return true, nil
}

// drain sends every complete request currently buffered on accept, in
// order. It then moves the remaining partial request to the front of the
// buffer.
func (reader *Reader) drain(accept chan readerData) error {
	for {
		isOk, err := reader.parseNext()
		if err != nil {
			return err
		}
		if !isOk {
			break
		}
		accept <- reader.readerData
	}
	reader.move()
	return nil
}

// read reads requests from the connection and sends each complete request
// on accept, in arrival order. When one read delivers several requests
// (pipelining), all of them are sent before the next read.
//
// It returns nil when the connection is closed or a read fails. It returns
// an error when a request is malformed or does not fit into the buffer.
func (reader *Reader) read(accept chan readerData) error {
	for {
		if reader.end == reader.buffLen {
			// The buffer is full but still holds no complete request.
			return fmt.Errorf("request is too large: exceeds %d bytes", reader.buffLen)
		}
		// Append after the bytes already buffered; reading into the start
		// of buff would overwrite a partially received request.
		n, readErr := reader.conn.Read(reader.buff[reader.end:])
		reader.end += n
		// Process any bytes returned together with an error before
		// stopping, as the io.Reader contract requires.
		if err := reader.drain(accept); err != nil {
			return err
		}
		if readErr != nil {
			// Connection closed (or broken): nothing more to read.
			return nil
		}
	}
}

// move shifts the unconsumed bytes buff[start:end] to the front of the
// buffer so the next read can append after them.
func (reader *Reader) move() {
	if reader.start == 0 {
		return
	}
	copy(reader.buff, reader.buff[reader.start:reader.end])
	reader.end -= reader.start
	reader.start = 0
}
request.go:解析请求头、请求参数等
package http_server
import (
"encoding/json"
"strings"
)
type (
	// H is a shorthand map type for JSON object payloads.
	H map[string]interface{}

	// Request is one parsed HTTP request together with the response being
	// built for it.
	Request struct {
		Path    string            // request URL without its query string
		Method  string            // method from the request line
		headers map[string]string // request headers keyed by upper-cased name
		queries map[string]string // URL query parameters
		posts   map[string]string // fields of a JSON request body
		*response
	}
)
// newRequest returns a Request whose header, query and POST maps are
// allocated and empty; its response is attached by the caller.
func newRequest() *Request {
	request := new(Request)
	request.headers = make(map[string]string)
	request.queries = make(map[string]string)
	request.posts = make(map[string]string)
	return request
}
// parse fills the request from the raw data produced by the reader.
//
// Path is the request URL up to the first "?"; the rest is the query string.
// The query string is split on "&", then each pair on its first "=". A pair
// without "=" maps to "", and empty pairs are skipped. Values are stored as
// they appear in the URL, without percent-decoding.
//
// For non-GET requests whose Content-Type media type is application/json,
// the body is decoded as a JSON object. String fields are stored unquoted.
// Any other value (number, bool, object, array) is stored as its raw JSON
// text. A body that is not a JSON object yields no POST parameters.
func (request *Request) parse(readerData readerData) {
	request.Method = readerData.Line["method"]
	request.headers = readerData.Header

	target := readerData.Line["url"]
	request.Path = target
	var queries string
	if index := strings.Index(target, "?"); index != -1 {
		request.Path = target[:index]
		queries = target[index+1:]
	}
	// The query string is part of the URL, so it is parsed for every method.
	for _, pair := range strings.Split(queries, "&") {
		if pair == "" {
			continue
		}
		// Untrusted input: a pair without "=" must not be indexed blindly.
		key, value := pair, ""
		if eq := strings.Index(pair, "="); eq != -1 {
			key, value = pair[:eq], pair[eq+1:]
		}
		request.queries[key] = value
	}

	if request.Method == "GET" {
		return
	}
	contentType, isExist := request.headers["CONTENT-TYPE"]
	if !isExist {
		return
	}
	mediaType := contentType
	if semi := strings.Index(contentType, ";"); semi != -1 {
		mediaType = contentType[:semi]
	}
	if !strings.EqualFold(strings.TrimSpace(mediaType), "application/json") {
		return
	}
	var fields map[string]json.RawMessage
	if err := json.Unmarshal([]byte(readerData.Body), &fields); err != nil {
		return // best effort: a malformed body simply yields no POST parameters
	}
	for key, raw := range fields {
		var s string
		if json.Unmarshal(raw, &s) == nil {
			request.posts[key] = s
		} else {
			request.posts[key] = string(raw)
		}
	}
}
// Query returns the URL query parameter name, or "" when it is absent.
func (request *Request) Query(name string) string {
	return request.DefaultQuery(name, "")
}
// Post returns the JSON body field name, or "" when it is absent.
func (request *Request) Post(name string) string {
	return request.DefaultPost(name, "")
}
// DefaultQuery returns the URL query parameter name, or def when it is
// absent. A parameter that is present with an empty value returns "".
func (request *Request) DefaultQuery(name, def string) string {
	if val, ok := request.queries[name]; ok {
		return val
	}
	return def
}
// DefaultPost returns the JSON body field name, or def when it is absent.
// A field that is present with an empty value returns "".
func (request *Request) DefaultPost(name, def string) string {
	if val, ok := request.posts[name]; ok {
		return val
	}
	return def
}
// GetHeader returns the value of request header name, or "" when it is
// absent. The name is matched case-insensitively because header names are
// stored upper-cased. The value is returned exactly as the request parser
// stored it.
func (request *Request) GetHeader(name string) string {
	// Headers live in request.headers, not in the POST parameters.
	return request.headers[strings.ToUpper(name)]
}
// Json sets the response body to the JSON encoding of obj and the status
// code to code. If obj cannot be encoded (e.g. it contains a channel, a
// function or a NaN float), the response becomes a 500 with an empty body.
// This avoids sending code with a missing or stale body.
func (request *Request) Json(code int, obj interface{}) {
	ret, err := json.Marshal(obj)
	if err != nil {
		request.response.status = 500
		request.response.body = nil
		request.response.bodyLen = 0
		return
	}
	request.response.bodyLen = len(ret)
	request.response.body = ret
	request.response.status = code
}
// Header sets response header name to val unless that header has already
// been set. Names are case-insensitive and stored lower-cased. So
// Header("X-Id", "1") followed by Header("x-id", "2") keeps "1".
func (request *Request) Header(name string, val string) {
	// Check and store under the same normalized key; checking the raw name
	// let case variants slip past the "set once" rule.
	key := strings.ToLower(name)
	if _, isExist := request.response.headers[key]; !isExist {
		request.response.headers[key] = val
	}
}
response.go:构造HTTP响应
package http_server
import (
"fmt"
"net"
"strconv"
)
// response accumulates the status, headers and body of one HTTP response
// and serializes it onto conn.
type response struct {
	status  int               // HTTP status code
	body    []byte            // response body
	bodyLen int               // len(body), kept in sync by the setters
	headers map[string]string // header fields, names lower-case
	buff    []byte            // serialized response being assembled
	conn    net.Conn          // connection the response is written to
}

// newResponse returns an empty response that will be written to conn.
func newResponse(conn net.Conn) *response {
	return &response{
		conn:    conn,
		headers: make(map[string]string),
	}
}

// writeLine appends the status line, e.g. "HTTP/1.1 404 Not Found".
func (response *response) writeLine() {
	line := fmt.Sprintf("HTTP/1.1 %d %s\r\n", response.status, response.statusText())
	response.buff = append(response.buff, line...)
}

// statusText returns the reason phrase for the response's status code.
// net/http is deliberately not used, so the common codes are listed here.
// Unknown codes get a generic phrase.
func (response *response) statusText() string {
	switch response.status {
	case 100:
		return "Continue"
	case 101:
		return "Switching Protocols"
	case 200:
		return "OK"
	case 201:
		return "Created"
	case 202:
		return "Accepted"
	case 204:
		return "No Content"
	case 301:
		return "Moved Permanently"
	case 302:
		return "Found"
	case 304:
		return "Not Modified"
	case 400:
		return "Bad Request"
	case 401:
		return "Unauthorized"
	case 403:
		return "Forbidden"
	case 404:
		return "Not Found"
	case 405:
		return "Method Not Allowed"
	case 413:
		return "Payload Too Large"
	case 500:
		return "Internal Server Error"
	case 501:
		return "Not Implemented"
	case 502:
		return "Bad Gateway"
	case 503:
		return "Service Unavailable"
	default:
		return "Unknown"
	}
}

// writeHeader appends the header fields and the blank line ending them.
// "server" and "content-length" are always set. "content-type" defaults to
// application/json but a value set by the handler is kept.
func (response *response) writeHeader() {
	response.headers["server"] = "^_^"
	if _, ok := response.headers["content-type"]; !ok {
		response.headers["content-type"] = "application/json"
	}
	// Derived from the body actually sent so the framing can never disagree
	// with the payload.
	response.headers["content-length"] = strconv.Itoa(len(response.body))
	for k, v := range response.headers {
		response.buff = append(response.buff, fmt.Sprintf("%s: %v\r\n", k, v)...)
	}
	response.buff = append(response.buff, "\r\n"...)
}

// write serializes the whole response and sends it in a single Write call.
// The buffer is rebuilt from scratch each time, so calling write again
// does not resend a previous response. Write errors are ignored because
// this method has no way to report them.
func (response *response) write() {
	response.buff = response.buff[:0]
	response.writeLine()
	response.writeHeader()
	response.buff = append(response.buff, response.body...)
	response.conn.Write(response.buff)
}

// errWrite sends an error response with the given status and a fixed
// plain-text body.
func (response *response) errWrite(status int) {
	response.status = status
	response.body = []byte("Request Error")
	response.bodyLen = len(response.body)
	// The body is plain text, so do not advertise it as JSON.
	response.headers["content-type"] = "text/plain; charset=utf-8"
	response.write()
}
项目已放在 GitHub 上，欢迎 star～
有疑问加站长微信联系(非本文作者)