go语言websocket源码

haitgo · · 2826 次点击 · · 开始浏览    
这是一个创建于 的文章,其中的信息可能已经有所发展或是发生改变。

package websocket    
import (    
"bufio"    
"bytes"    
"crypto/sha1"    
"encoding/base64"    
"encoding/binary"    
"errors"    
"io"    
"net"    
"net/http"    
"strings"    
)    
var (    
ErrUpgrade = errors.New("Can \"Upgrade\" only to \"WebSocket\"")    
ErrConnection = errors.New("\"Connection\" must be \"Upgrade\"")    
ErrCrossOrigin = errors.New("Cross origin websockets not allowed")    
ErrSecVersion = errors.New("HTTP/1.1 Upgrade Required\r\nSec-WebSocket-Version: 13\r\n\r\n")    
ErrSecKey = errors.New("\"Sec-WebSocket-Key\" must not be nil")    
ErrHijacker = errors.New("Not implement http.Hijacker")    
)    
var (    
ErrReservedBits = errors.New("Reserved_bits show using undefined extensions")    
ErrFrameOverload = errors.New("Control frame payload overload")    
ErrFrameFragmented = errors.New("Control frame must not be fragmented")    
ErrInvalidOpcode = errors.New("Invalid frame opcode")    
)    
var (    
crlf = []byte("\r\n")    
challengeKey = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")    
)    
//referer https://github.com/Skycrab/skynet_websocket/blob/master/websocket.lua    
type WsHandler interface {    
CheckOrigin(origin, host string) bool    
OnOpen(ws *Websocket)    
OnMessage(ws *Websocket, message []byte)    
OnClose(ws *Websocket, code uint16, reason []byte)    
OnPong(ws *Websocket, data []byte)    
}    
type WsDefaultHandler struct {    
checkOriginOr bool // 是否校验origin, default true    
}    
func (wd WsDefaultHandler) CheckOrigin(origin, host string) bool {    
return true    
}    
func (wd WsDefaultHandler) OnOpen(ws *Websocket) {    
}    
func (wd WsDefaultHandler) OnMessage(ws *Websocket, message []byte) {    
}    
func (wd WsDefaultHandler) OnClose(ws *Websocket, code uint16, reason []byte) {    
}    
func (wd WsDefaultHandler) OnPong(ws *Websocket, data []byte) {    
}    
type Websocket struct {    
conn net.Conn    
rw *bufio.ReadWriter    
handler WsHandler    
clientTerminated bool    
serverTerminated bool    
maskOutgoing bool    
}    
type Option struct {    
Handler WsHandler // 处理器, default WsDefaultHandler    
MaskOutgoing bool //发送frame是否mask, default false    
}    
func challengeResponse(key, protocol string) []byte {    
sha := sha1.New()    
sha.Write([]byte(key))    
sha.Write(challengeKey)    
accept := base64.StdEncoding.EncodeToString(sha.Sum(nil))    
buf := bytes.NewBufferString("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ")    
buf.WriteString(accept)    
buf.Write(crlf)    
if protocol != "" {    
buf.WriteString("Sec-WebSocket-Protocol: ")    
buf.WriteString(protocol)    
buf.Write(crlf)    
}    
buf.Write(crlf)    
return buf.Bytes()    
}    
func acceptConnection(r *http.Request, h WsHandler) (challenge []byte, err error) {    
//Upgrade header should be present and should be equal to WebSocket    
if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" {    
return nil, ErrUpgrade    
}    
//Connection header should be upgrade. Some proxy servers/load balancers    
// might mess with it.    
if !strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade") {    
return nil, ErrConnection    
}    
// Handle WebSocket Origin naming convention differences    
// The difference between version 8 and 13 is that in 8 the    
// client sends a "Sec-Websocket-Origin" header and in 13 it's    
// simply "Origin".    
if r.Header.Get("Sec-Websocket-Version") != "13" {    
return nil, ErrSecVersion    
}    
origin := r.Header.Get("Origin")    
if origin == "" {    
origin = r.Header.Get("Sec-Websocket-Origin")    
}    
if origin != "" && !h.CheckOrigin(origin, r.Header.Get("Host")) {    
return nil, ErrCrossOrigin    
}    
key := r.Header.Get("Sec-Websocket-Key")    
if key == "" {    
return nil, ErrSecKey    
}    
protocol := r.Header.Get("Sec-Websocket-Protocol")    
if protocol != "" {    
idx := strings.IndexByte(protocol, ',')    
if idx != -1 {    
protocol = protocol[:idx]    
}    
}    
return challengeResponse(key, protocol), nil    
}    
func websocketMask(mask []byte, data []byte) {    
for i := range data {    
data[i] ^= mask[i%4]    
}    
}    
func New(w http.ResponseWriter, r *http.Request, opt *Option) (*Websocket, error) {    
var h WsHandler    
var maskOutgoing bool    
if opt == nil {    
h = WsDefaultHandler{true}    
maskOutgoing = false    
} else {    
h = opt.Handler    
maskOutgoing = opt.MaskOutgoing    
}    
challenge, err := acceptConnection(r, h)    
if err != nil {    
var code int    
if err == ErrCrossOrigin {    
code = 403    
} else {    
code = 400    
}    
w.WriteHeader(code)    
w.Write([]byte(err.Error()))    
return nil, err    
}    
hj, ok := w.(http.Hijacker)    
if !ok {    
return nil, ErrHijacker    
}    
conn, rw, err := hj.Hijack()    
ws := new(Websocket)    
ws.conn = conn    
ws.rw = rw    
ws.handler = h    
ws.maskOutgoing = maskOutgoing    
if _, err := ws.conn.Write(challenge); err != nil {    
ws.conn.Close()    
return nil, err    
}    
ws.handler.OnOpen(ws)    
return ws, nil    
}    
func (ws *Websocket) read(buf []byte) error {    
_, err := io.ReadFull(ws.rw, buf)    
return err    
}    
func (ws *Websocket) SendFrame(fin bool, opcode byte, data []byte) error {    
//max frame header may 14 length    
buf := make([]byte, 0, len(data)+14)    
var finBit, maskBit byte    
if fin {    
finBit = 0x80    
} else {    
finBit = 0    
}    
buf = append(buf, finBit|opcode)    
length := len(data)    
if ws.maskOutgoing {    
maskBit = 0x80    
} else {    
maskBit = 0    
}    
if length < 126 {    
buf = append(buf, byte(length)|maskBit)    
} else if length < 0xFFFF {    
buf = append(buf, 126|maskBit, 0, 0)    
binary.BigEndian.PutUint16(buf[len(buf)-2:], uint16(length))    
} else {    
buf = append(buf, 127|maskBit, 0, 0, 0, 0, 0, 0, 0, 0)    
binary.BigEndian.PutUint64(buf[len(buf)-8:], uint64(length))    
}    
if ws.maskOutgoing {    
}    
buf = append(buf, data...)    
ws.rw.Write(buf)    
return ws.rw.Flush()    
}    
func (ws *Websocket) SendText(data []byte) error {    
return ws.SendFrame(true, 0x1, data)    
}    
func (ws *Websocket) SendBinary(data []byte) error {    
return ws.SendFrame(true, 0x2, data)    
}    
func (ws *Websocket) SendPing(data []byte) error {    
return ws.SendFrame(true, 0x9, data)    
}    
func (ws *Websocket) SendPong(data []byte) error {    
return ws.SendFrame(true, 0xA, data)    
}    
func (ws *Websocket) Close(code uint16, reason []byte) {    
if !ws.serverTerminated {    
data := make([]byte, 0, len(reason)+2)    
if code == 0 && reason != nil {    
code = 1000    
}    
if code != 0 {    
data = append(data, 0, 0)    
binary.BigEndian.PutUint16(data, code)    
}    
if reason != nil {    
data = append(data, reason...)    
}    
ws.SendFrame(true, 0x8, data)    
ws.serverTerminated = true    
}    
if ws.clientTerminated {    
ws.conn.Close()    
}    
}    
func (ws *Websocket) RecvFrame() (final bool, message []byte, err error) { //text 数据报文    
buf := make([]byte, 8, 8)    
err = ws.read(buf[:2])    
if err != nil {    
return    
}    
header, payload := buf[0], buf[1]    
final = header&0x80 != 0    
reservedBits := header&0x70 != 0    
frameOpcode := header & 0xf    
frameOpcodeIsControl := frameOpcode&0x8 != 0    
if reservedBits {    
// client is using as-yet-undefined extensions    
err = ErrReservedBits    
return    
}    
maskFrame := payload&0x80 != 0    
payloadlen := uint64(payload & 0x7f)    
if frameOpcodeIsControl && payloadlen >= 126 {    
err = ErrFrameOverload    
return    
}    
if frameOpcodeIsControl && !final {    
err = ErrFrameFragmented    
return    
}    
//解析frame长度    
var frameLength uint64    
if payloadlen < 126 {    
frameLength = payloadlen    
} else if payloadlen == 126 {    
err = ws.read(buf[:2])    
if err != nil {    
return    
}    
frameLength = uint64(binary.BigEndian.Uint16(buf[:2]))    
} else { //payloadlen == 127    
err = ws.read(buf[:8])    
if err != nil {    
return    
}    
frameLength = binary.BigEndian.Uint64(buf[:8])    
}    
frameMask := make([]byte, 4, 4)    
if maskFrame {    
err = ws.read(frameMask)    
if err != nil {    
return    
}    
}    
// fmt.Println("final_frame:", final, "frame_opcode:", frameOpcode, "mask_frame:", maskFrame, "frame_length:", frameLength)    
message = make([]byte, frameLength, frameLength)    
if frameLength > 0 {    
err = ws.read(message)    
if err != nil {    
return    
}    
}    
if maskFrame && frameLength > 0 {    
websocketMask(frameMask, message)    
}    
if !final {    
return    
} else {    
switch frameOpcode {    
case 0x1: //text    
case 0x2: //binary    
case 0x8: // close    
var code uint16    
var reason []byte    
if frameLength >= 2 {    
code = binary.BigEndian.Uint16(message[:2])    
}    
if frameLength > 2 {    
reason = message[2:]    
}    
message = nil    
ws.clientTerminated = true    
ws.Close(0, nil)    
ws.handler.OnClose(ws, code, reason)    
case 0x9: //ping    
message = nil    
ws.SendPong(nil)    
case 0xA:    
ws.handler.OnPong(ws, message)    
message = nil    
default:    
err = ErrInvalidOpcode    
}    
return    
}    
}    
func (ws *Websocket) Recv() ([]byte, error) {    
data := make([]byte, 0, 8)    
for {    
final, message, err := ws.RecvFrame()    
if final {    
data = append(data, message...)    
break    
} else {    
data = append(data, message...)    
}    
if err != nil {    
return data, err    
}    
}    
if len(data) > 0 {    
ws.handler.OnMessage(ws, data)    
}    
return data, nil    
}    
func (ws *Websocket) Start() {    
for {    
_, err := ws.Recv()    
if err != nil {    
ws.conn.Close()    
}    
}    
}


有疑问加站长微信联系(非本文作者)

本文来自:开源中国博客

感谢作者:haitgo

查看原文:go语言websocket源码

入群交流(和以上内容无关):加入Go大咖交流群,或添加微信:liuxiaoyan-s 备注:入群;或加QQ群:692541889

2826 次点击  
加入收藏 微博
暂无回复
添加一条新回复 (您需要 登录 后才能回复 没有账号 ?)
  • 请尽量让自己的回复能够对别人有帮助
  • 支持 Markdown 格式, **粗体**、~~删除线~~、`单行代码`
  • 支持 @ 本站用户;支持表情(输入 : 提示),见 Emoji cheat sheet
  • 图片支持拖拽、截图粘贴等方式上传