tcp流量转发到不同的后端

jiangyd · 2022-09-21 18:01:22 · 2623 次点击 · 大约8小时之前 开始浏览    置顶
这是一个创建于 2022-09-21 18:01:22 的主题,其中的信息可能已经有所发展或是发生改变。

想把流量发个多个后端,每个后端都可以收到内容,但我写的,每次都只有一个后端收到内容,随机的

package main

import (
    "flag"
    "fmt"
    "io"
    "net"
    "strings"
)

var tcp_conn []net.Conn

func main() {
    var address *string
    var dst *string
    address = flag.String("listen_address", "127.0.0.1:8080", "listen address")
    dst = flag.String("target_address", "172.16.5.9:18061,172.16.5.9:18060", "目标地址,多个地址逗号分割")
    flag.Parse()
    listener, err := net.Listen("tcp", *address)
    if err != nil {
        fmt.Println(err)
    }
    ips := strings.Split(*dst, ",")
    if len(ips) == 0 {
        fmt.Println("目标地址不能为空")
        return
    }
    for _, ip := range ips {
        conn, err := net.Dial("tcp", ip)
        if err != nil {
            fmt.Println(err)
            continue
        }
        tcp_conn = append(tcp_conn, conn)
    }
    for {
        conn, err := listener.Accept()
        if err != nil {
            continue
        }
        go handler(conn)
    }

}

func handler(conn net.Conn) {
    // defer conn.Close()
    for _, dst_conn := range tcp_conn {
        go io.Copy(dst_conn, conn)
        go io.Copy(conn, dst_conn)
    }
}

该如何改呢,哪位大佬帮忙看看


有疑问加站长微信联系(非本文作者)

入群交流(和以上内容无关):加入Go大咖交流群,或添加微信:liuxiaoyan-s 备注:入群;或加QQ群:692541889

2623 次点击  
加入收藏 微博
11 回复  |  直到 2022-11-09 10:47:47
chengxuge
chengxuge · #1 · 3年之前

某个协程io.Copy之后相当于读完了流,另外的协程当然读不到数据了,你需要修改代码,每次读取[]byte,然后再每个连接Write([]byte)

jiangyd
jiangyd · #2 · 3年之前
chengxugechengxuge #1 回复

某个协程io.Copy之后相当于读完了流,另外的协程当然读不到数据了,你需要修改代码,每次读取[]byte,然后再每个连接Write([]byte)

代码如何写啊,不用io.copy了吗

zzustu
zzustu · #3 · 3年之前

将流量拷贝分发给多个后端可以理解,可以简单的这样写:

mw := io.MultiWriter(dest1, dest2, dest3)
io.Copy(mw, src)

但是,后端的返回数据都往一个conn里面写入,这个conn收到消息该怎么解析。

jiangyd
jiangyd · #4 · 3年之前
zzustuzzustu #3 回复

将流量拷贝分发给多个后端可以理解,可以简单的这样写: ```go mw := io.MultiWriter(dest1, dest2, dest3) io.Copy(mw, src) ``` 但是,后端的返回数据都往一个conn里面写入,这个conn收到消息该怎么解析。

我只取第一个回复的数据,这样试了下不行,会报错,

        i := make([]io.Writer, 3)
    for _, dst_conn := range tcp_conn {
        i = append(i, dst_conn)
    }
    mm := io.MultiWriter(i...)
    go io.Copy(mm, conn)
    go io.Copy(conn, tcp_conn[0])
zzustu
zzustu · #5 · 3年之前

make slice 时写错了

// 错误
i := make([]io.Writer, 3)

// 正确
i := make([]io.Writer, 0, 3)
zzustu
zzustu · #6 · 3年之前

楼主这个场景有点像流量镜像,就是把流量拷贝N多份,一份照常给业务服务器。其他的发送给流量审计服务器或者测试环境服务器。

需要有一些注意:

  1. 如果只读取第一个连接的数据回复,其他的不管,有可能会造其他连接的消息积压。 虽然你并不关心其他连接回复的消息,但是应该把回复的消息读出来,比如读出来丢弃掉:io.Copy(io.Discard, otherConn)

  2. 最好不要用 io.MultiWriter(),因为有一个写入出错就会导致 io.MultiWriter() 报错, 因为后端某个连接可以能会异常或断开,这样一个异常,会影响所有的 Writer 写入

jiangyd
jiangyd · #7 · 3年之前
zzustuzzustu #6 回复

楼主这个场景有点像流量镜像,就是把流量拷贝N多份,一份照常给业务服务器。其他的发送给流量审计服务器或者测试环境服务器。 需要有一些注意: 1. 如果只读取第一个连接的数据回复,其他的不管,`有可能`会造其他连接的消息积压。 虽然你并不关心其他连接回复的消息,但是应该把回复的消息读出来,比如读出来丢弃掉:`io.Copy(io.Discard, otherConn)` 2. 最好不要用 `io.MultiWriter()`,因为有一个写入出错就会导致 `io.MultiWriter()` 报错, 因为后端某个连接可以能会异常或断开,这样一个异常,会影响所有的 Writer 写入

是这样的场景,我们公司有模拟别家公司服务端的一个功能,客户端用的是人家的,接收到数据后,做解析,为了解析效果一致, 需要把客户端上报的数据,同时发给我司的服务端与别家公司的服务端,做比较找出差异。

        defer conn.Close()
    p := make([]byte, 0, 124)
    for {
        if len(p) == cap(p) {
            p = append(p, 0)[:len(p)]
        }
        n, err := conn.Read(p[len(p):cap(p)])
        p = p[:len(p)+n]
        if n < 124 && err == nil {
            break
        } else if err != nil && err == io.EOF {
            break
        } else {
            continue
        }

    }

    for _, dst_conn := range tcp_conn {
        t := bytes.NewReader(p)
        go io.Copy(dst_conn, t)
    }
    io.Copy(conn, tcp_conn[0])

我这样写,目前有个问题,第一次客户端发送数据,服务端都能收到,但是第二次发送 ,客户端就一致sending ,不知道是哪里阻塞了

zzustu
zzustu · #8 · 3年之前

凭感觉写的,未经优化未经测试,仅供参考

package main

import (
    "io"
    "net"
    "time"
)

func main() {
    backends := []string{"172.16.5.9:18060", "172.16.5.9:18061", "172.16.5.9:18062"}

    lis, err := net.Listen("tcp", ":8080")
    if err != nil {
        panic(err)
    }

    for {
        conn, err := lis.Accept()
        if err != nil {
            continue
        }

        go distribute(conn, backends)
    }
}

func distribute(conn net.Conn, backends []string) {
    defer conn.Close()

    backs := make([]net.Conn, 0, len(backends))
    for _, backend := range backends {
        back, err := net.DialTimeout("tcp", backend, 3*time.Second)
        if err != nil {
            continue
        }
        backs = append(backs, back)
    }
    if len(backs) == 0 {
        return
    }

    pipe := &pipeline{conn: conn, backs: backs}

    pipe.serve()
}

type pipeline struct {
    conn  net.Conn
    backs []net.Conn
}

func (p *pipeline) serve() {
    defer p.close()

    // 第一个连接为主连接,需要将数据响应给客户端
    master := p.backs[0]
    go p.readFrom(master)

    // 虽然不关心剩余的后端连接响应数据,但是应该将数据读出并丢弃
    discards := p.backs[1:]
    if len(discards) != 0 {
        go p.discard(discards)
    }

    buf := make([]byte, 1024)
    for {
        n, err := p.conn.Read(buf)
        if err != nil {
            break
        }

        // 读出客户端的数据,并分发给所有后端连接
        if hasErr := p.writeAll(buf[:n]); hasErr && len(p.backs) == 0 {
            break
        }
    }
}

func (p *pipeline) readFrom(r io.Reader) {
    _, _ = io.Copy(p.conn, r)
    _ = p.conn.Close()
}

func (*pipeline) discard(rds []net.Conn) {
    for _, rd := range rds {
        go io.Copy(io.Discard, rd)
    }
}

func (p *pipeline) writeAll(data []byte) bool {
    var hasErr bool
    for i, back := range p.backs {
        if _, err := back.Write(data); err == nil {
            continue
        }
        // 写入数据出现错误,说明后端连接异常,将这个连接剔除列表
        p.backs = append(p.backs[:i], p.backs[i+1:]...)
        hasErr = true
        _ = back.Close() // 保险起见,剔除后再主动执行 Close() 方法
    }

    return hasErr
}

func (p *pipeline) close() {
    _ = p.conn.Close()
    for _, back := range p.backs {
        _ = back.Close()
    }
}
zzustu
zzustu · #9 · 3年之前

code.png

jiangyd
jiangyd · #10 · 3年之前
zzustuzzustu #8 回复

凭感觉写的,未经优化未经测试,仅供参考 ```go package main import ( "io" "net" "time" ) func main() { backends := []string{"172.16.5.9:18060", "172.16.5.9:18061", "172.16.5.9:18062"} lis, err := net.Listen("tcp", ":8080") if err != nil { panic(err) } for { conn, err := lis.Accept() if err != nil { continue } go distribute(conn, backends) } } func distribute(conn net.Conn, backends []string) { defer conn.Close() backs := make([]net.Conn, 0, len(backends)) for _, backend := range backends { back, err := net.DialTimeout("tcp", backend, 3*time.Second) if err != nil { continue } backs = append(backs, back) } if len(backs) == 0 { return } pipe := &pipeline{conn: conn, backs: backs} pipe.serve() } type pipeline struct { conn net.Conn backs []net.Conn } func (p *pipeline) serve() { defer p.close() // 第一个连接为主连接,需要将数据响应给客户端 master := p.backs[0] go p.readFrom(master) // 虽然不关心剩余的后端连接响应数据,但是应该将数据读出并丢弃 discards := p.backs[1:] if len(discards) != 0 { go p.discard(discards) } buf := make([]byte, 1024) for { n, err := p.conn.Read(buf) if err != nil { break } // 读出客户端的数据,并分发给所有后端连接 if hasErr := p.writeAll(buf[:n]); hasErr && len(p.backs) == 0 { break } } } func (p *pipeline) readFrom(r io.Reader) { _, _ = io.Copy(p.conn, r) _ = p.conn.Close() } func (*pipeline) discard(rds []net.Conn) { for _, rd := range rds { go io.Copy(io.Discard, rd) } } func (p *pipeline) writeAll(data []byte) bool { var hasErr bool for i, back := range p.backs { if _, err := back.Write(data); err == nil { continue } // 写入数据出现错误,说明后端连接异常,将这个连接剔除列表 p.backs = append(p.backs[:i], p.backs[i+1:]...) hasErr = true _ = back.Close() // 保险起见,剔除后再主动执行 Close() 方法 } return hasErr } func (p *pipeline) close() { _ = p.conn.Close() for _, back := range p.backs { _ = back.Close() } } ```

真心感谢大佬,一把跑通,请求多次都是正常的

xzd20
xzd20 · #11 · 2年之前

请问如果实现轮询发送的话应该怎么做呢?就是第一次发给服务器A,接收它的返回;第二次发给服务器B,接收它的返回,以此类推,基于您的基础上改了改发现有问题

func (p *pipeline) serve() {
    defer p.close()

    //master := p.backs[0]
    //go p.readFrom(master)
    //
    //discards := p.backs[1:]
    //if len(discards) != 0 {
    //    go p.discard(discards)
    //}

    buf := make([]byte, 1024)
    bufRead := make([]byte, 1024)
    index := 0
    for {

        n, err := p.conn.Read(buf)
        if err != nil {
            return
        }
        if _, err := p.backs[index].Write(buf[:n]); err != nil {
            break
        }
        fmt.Println("index: ", index)

        n, err = p.backs[index].Read(bufRead)
        if err != nil {
            break
        }
        if _, err = p.conn.Write(bufRead[:n]); err != nil {
            break
        }

        index = (index + 1) % len(p.backs)

    }
}
添加一条新回复 (您需要 登录 后才能回复 没有账号 ?)
  • 请尽量让自己的回复能够对别人有帮助
  • 支持 Markdown 格式, **粗体**、~~删除线~~、`单行代码`
  • 支持 @ 本站用户;支持表情(输入 : 提示),见 Emoji cheat sheet
  • 图片支持拖拽、截图粘贴等方式上传