tcp流量转发到不同的后端

jiangyd · 2022-09-21 18:01:22 · 2596 次点击

凭感觉写的,未经优化未经测试,仅供参考

package main

import (
    "io"
    "net"
    "time"
)

func main() {
    backends := []string{"172.16.5.9:18060", "172.16.5.9:18061", "172.16.5.9:18062"}

    lis, err := net.Listen("tcp", ":8080")
    if err != nil {
        panic(err)
    }

    for {
        conn, err := lis.Accept()
        if err != nil {
            continue
        }

        go distribute(conn, backends)
    }
}

func distribute(conn net.Conn, backends []string) {
    defer conn.Close()

    backs := make([]net.Conn, 0, len(backends))
    for _, backend := range backends {
        back, err := net.DialTimeout("tcp", backend, 3*time.Second)
        if err != nil {
            continue
        }
        backs = append(backs, back)
    }
    if len(backs) == 0 {
        return
    }

    pipe := &pipeline{conn: conn, backs: backs}

    pipe.serve()
}

type pipeline struct {
    conn  net.Conn
    backs []net.Conn
}

func (p *pipeline) serve() {
    defer p.close()

    // 第一个连接为主连接,需要将数据响应给客户端
    master := p.backs[0]
    go p.readFrom(master)

    // 虽然不关心剩余的后端连接响应数据,但是应该将数据读出并丢弃
    discards := p.backs[1:]
    if len(discards) != 0 {
        go p.discard(discards)
    }

    buf := make([]byte, 1024)
    for {
        n, err := p.conn.Read(buf)
        if err != nil {
            break
        }

        // 读出客户端的数据,并分发给所有后端连接
        if hasErr := p.writeAll(buf[:n]); hasErr && len(p.backs) == 0 {
            break
        }
    }
}

func (p *pipeline) readFrom(r io.Reader) {
    _, _ = io.Copy(p.conn, r)
    _ = p.conn.Close()
}

func (*pipeline) discard(rds []net.Conn) {
    for _, rd := range rds {
        go io.Copy(io.Discard, rd)
    }
}

func (p *pipeline) writeAll(data []byte) bool {
    var hasErr bool
    for i, back := range p.backs {
        if _, err := back.Write(data); err == nil {
            continue
        }
        // 写入数据出现错误,说明后端连接异常,将这个连接剔除列表
        p.backs = append(p.backs[:i], p.backs[i+1:]...)
        hasErr = true
        _ = back.Close() // 保险起见,剔除后再主动执行 Close() 方法
    }

    return hasErr
}

func (p *pipeline) close() {
    _ = p.conn.Close()
    for _, back := range p.backs {
        _ = back.Close()
    }
}
#8
更多评论

某个协程io.Copy之后相当于读完了流,另外的协程当然读不到数据了,你需要修改代码,每次读取[]byte,然后再每个连接Write([]byte)

#1

代码如何写啊,不用io.copy了吗

#2