用Go实现TCP连接的双向拷贝

陶文 · · 2410 次点击 · · 开始浏览    
这是一个创建于 的文章,其中的信息可能已经有所发展或是发生改变。

最简单的实现

每次来一个Server的连接,就新开一个Client的连接。用一个goroutine从server拷贝到client,再用另外一个goroutine从client拷贝到server。任何一方断开连接,双向都断开连接。

func main() {
	runtime.GOMAXPROCS(1)
	listener, err := net.Listen("tcp", "127.0.0.1:8848")
	if err != nil {
		panic(err)
	}
	for {
		conn, err := listener.Accept()
		if err != nil {
			panic(err)
		}
		go handle(conn.(*net.TCPConn))
	}
}

func handle(server *net.TCPConn) {
	defer server.Close()
	client, err := net.Dial("tcp", "127.0.0.1:8849")
	if err != nil {
		fmt.Print(err)
		return
	}
	defer client.Close()
	go func() {
		defer server.Close()
		defer client.Close()
		buf := make([]byte, 2048)
		io.CopyBuffer(server, client, buf)
	}()
	buf := make([]byte, 2048)
	io.CopyBuffer(client, server, buf)
}

一个值得注意的地方是io.Copy的默认buffer比较大,给一个小的buffer可以支持更多的并发连接。

这两个goroutine并序在一个退出之后,另外一个也退出。这个的实现是通过关闭server或者client的socket来实现的。因为socket被关闭了,io.CopyBuffer 就会退出。

Client端实现连接池

一个显而易见的问题是,每次Server的连接进来之后都需要临时去建立一个新的Client的端的连接。这样在代理的总耗时里就包括了一个tcp连接的握手时间。如果能够让Client端实现连接池复用已有连接的话,可以缩短端到端的延迟。

var pool = make(chan net.Conn, 100)

func borrow() (net.Conn, error) {
	select {
	case conn := <- pool:
		return conn, nil
	default:
		return net.Dial("tcp", "127.0.0.1:8849")
	}
}

func release(conn net.Conn) error {
	select {
	case pool <- conn:
		// returned to pool
		return nil
	default:
		// pool is overflow
		return conn.Close()
	}
}

func handle(server *net.TCPConn) {
	defer server.Close()
	client, err := borrow()
	if err != nil {
		fmt.Print(err)
		return
	}
	defer release(client)
	go func() {
		defer server.Close()
		defer release(client)
		buf := make([]byte, 2048)
		io.CopyBuffer(server, client, buf)
	}()
	buf := make([]byte, 2048)
	io.CopyBuffer(client, server, buf)
}

这个版本的实现是显而易见有问题的。因为连接在归还到池里的时候并不能保证是还保持连接的状态。另外一个更严重的问题是,因为client的连接不再被关闭了,当server端关闭连接时,从client向server做io.CopyBuffer的goroutine就无法退出了。

所以,有以下几个问题要解决:

  • 如何在一个goroutine时退出时另外一个goroutine也退出?
  • 怎么保证归还给pool的连接是有效的?
  • 怎么保持在pool中的连接仍然是一直有效的?

通过SetDeadline中断Goroutine

一个普遍的观点是Goroutine是无法被中断的。当一个Goroutine在做conn.Read时,这个协程就被阻塞在那里了。实际上并不是毫无办法的,我们可以通过conn.Close来中断Goroutine。但是在连接池的情况下,又无法Close链接。另外一种做法就是通过SetDeadline为一个过去的时间戳来中断当前正在进行的阻塞读或者阻塞写。

var pool = make(chan net.Conn, 100)

type client struct {
	conn net.Conn
	inUse *sync.WaitGroup
}

func borrow() (clt *client, err error) {
	var conn net.Conn
	select {
	case conn = <- pool:
	default:
		conn, err = net.Dial("tcp", "127.0.0.1:18849")
	}
	if err != nil {
		return nil, err
	}
	clt = &client{
		conn: conn,
		inUse: &sync.WaitGroup{},
	}
	return
}

func release(clt *client) error {
	clt.conn.SetDeadline(time.Now().Add(-time.Second))
	clt.inUse.Done()
	clt.inUse.Wait()
	select {
	case pool <- clt.conn:
		// returned to pool
		return nil
	default:
		// pool is overflow
		return clt.conn.Close()
	}
}

func handle(server *net.TCPConn) {
	defer server.Close()
	clt, err := borrow()
	if err != nil {
		fmt.Print(err)
		return
	}
	clt.inUse.Add(1)
	defer release(clt)
	go func() {
		clt.inUse.Add(1)
		defer server.Close()
		defer release(clt)
		buf := make([]byte, 2048)
		io.CopyBuffer(server, clt.conn, buf)
	}()
	buf := make([]byte, 2048)
	io.CopyBuffer(clt.conn, server, buf)
}

通过SetDeadline实现了goroutine的中断,然后通过sync.WaitGroup来保证这些使用方都退出了之后再归还给连接池。否则一个连接被复用的时候,之前的使用方可能还没有退出。

连接有效性

为了保证在归还给pool之前,连接仍然是有效的。连接在被读写的过程中如果发现了error,我们就要标记这个连接是有问题的,会释放之后直接close掉。但是SetDeadline必然会导致读取或者写入的时候出现一次timeout的错误,所以还需要把timeout排除掉。

var pool = make(chan net.Conn, 100)

type client struct {
	conn net.Conn
	inUse *sync.WaitGroup
	isValid int32
}

const maybeValid = 0
const isValid = 1
const isInvalid = 2

func (clt *client) Read(b []byte) (n int, err error) {
	n, err = clt.conn.Read(b)
	if err != nil {
		if !isTimeoutError(err) {
			atomic.StoreInt32(&clt.isValid, isInvalid)
		}
	} else {
		atomic.StoreInt32(&clt.isValid, isValid)
	}
	return
}

func (clt *client) Write(b []byte) (n int, err error) {
	n, err = clt.conn.Write(b)
	if err != nil {
		if !isTimeoutError(err) {
			atomic.StoreInt32(&clt.isValid, isInvalid)
		}
	} else {
		atomic.StoreInt32(&clt.isValid, isValid)
	}
	return
}

type timeoutErr interface {
	Timeout() bool
}

func isTimeoutError(err error) bool {
	timeoutErr, _ := err.(timeoutErr)
	if timeoutErr == nil {
		return false
	}
	return timeoutErr.Timeout()
}

func borrow() (clt *client, err error) {
	var conn net.Conn
	select {
	case conn = <- pool:
	default:
		conn, err = net.Dial("tcp", "127.0.0.1:18849")
	}
	if err != nil {
		return nil, err
	}
	clt = &client{
		conn: conn,
		inUse: &sync.WaitGroup{},
		isValid: maybeValid,
	}
	return
}

func release(clt *client) error {
	clt.conn.SetDeadline(time.Now().Add(-time.Second))
	clt.inUse.Done()
	clt.inUse.Wait()
	if clt.isValid == isValid {
		return clt.conn.Close()
	}
	select {
	case pool <- clt.conn:
		// returned to pool
		return nil
	default:
		// pool is overflow
		return clt.conn.Close()
	}
}

func handle(server *net.TCPConn) {
	defer server.Close()
	clt, err := borrow()
	if err != nil {
		fmt.Print(err)
		return
	}
	clt.inUse.Add(1)
	defer release(clt)
	go func() {
		clt.inUse.Add(1)
		defer server.Close()
		defer release(clt)
		buf := make([]byte, 2048)
		io.CopyBuffer(server, clt, buf)
	}()
	buf := make([]byte, 2048)
	io.CopyBuffer(clt, server, buf)
}

判断 error 是否是 timeout 需要类型强转来实现。

对于连接池里的conn是否仍然是有效的,如果用后台不断ping的方式来实现成本比较高。因为不同的协议要连接保持需要不同的ping的方式。一个最简单的办法就是下次用的时候试一下。如果连接不好用了,则改成新建一个连接,避免连续拿到无效的连接。通过这种方式把无效的连接给淘汰掉。

关于正确性

本文在杭州机场写成,完全不保证内容的正确性


有疑问加站长微信联系(非本文作者)

本文来自:知乎专栏

感谢作者:陶文

查看原文:用Go实现TCP连接的双向拷贝

入群交流(和以上内容无关):加入Go大咖交流群,或添加微信:liuxiaoyan-s 备注:入群;或加QQ群:692541889

2410 次点击  ∙  1 赞  
加入收藏 微博
暂无回复
添加一条新回复 (您需要 登录 后才能回复 没有账号 ?)
  • 请尽量让自己的回复能够对别人有帮助
  • 支持 Markdown 格式, **粗体**、~~删除线~~、`单行代码`
  • 支持 @ 本站用户;支持表情(输入 : 提示),见 Emoji cheat sheet
  • 图片支持拖拽、截图粘贴等方式上传