// Command main runs a small WebSocket push service. Clients connect on /ws
// identified by site_id and site_index_id; events arriving on the internal
// event channel are turned into Message values (by querying the project's
// model beans) and pushed to every connection registered for that site.
package main

// Packages that this file does not reference directly are kept as blank
// imports so the file compiles without dropping any existing dependency;
// remove them once it is confirmed nothing relies on their init side effects.
import (
	_ "encoding/json"
	"flag"
	"fmt"
	"log"
	"net/http"
	"sync"
	"time"

	"config"
	_ "framework/logger"
	"global"
	"models/function"
	"models/schema"

	_ "github.com/go-redis/redis"
	"github.com/gorilla/websocket"
	_ "github.com/labstack/echo"
)

// writeWait bounds a single push so that one stalled peer cannot block the
// only broadcaster goroutine forever.
const writeWait = 10 * time.Second

var (
	// mu guards clients and agentSlice, which are touched by every HTTP
	// handler goroutine as well as by handleMessages.
	mu sync.Mutex

	// clients is the set of currently registered connections.
	clients = make(map[*websocket.Conn]bool)

	// broadcast carries fully built messages to handleMessages.
	broadcast = make(chan Message)

	// upgrader accepts every Origin instead of using the default check.
	// NOTE(review): a production deployment probably wants the default
	// same-origin check (or an explicit allow-list) here.
	upgrader = websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}

	// chananel is the event channel: each value asks for a push to one site.
	chananel = make(chan schema.Listening)

	// configFile is the path of the configuration file (database settings
	// among others); it only takes effect after flag.Parse runs in main.
	configFile *string = flag.String("config", "./bin/etc/conf.yaml", "agency config file")

	// agentSlice stores one single-entry map per connection, keyed by
	// site_id+site_index_id, recording which socket belongs to which site.
	agentSlice []map[string]*websocket.Conn
)

// Message is the JSON payload pushed to WebSocket clients.
type Message struct {
	Message     interface{} `json:"message"`
	SiteId      string      `json:"site_id"`
	SiteIndexId string      `json:"site_index_id"`
	Count       int64       `json:"count"`
}

// hu is a test-only endpoint (can be deleted once the real event producer is
// wired in). It prints the submitted ids and then enqueues a fixed event.
func hu(w http.ResponseWriter, r *http.Request) {
	siteid := r.FormValue("site_id")
	siteIndexId := r.FormValue("site_index_id")
	fmt.Println(siteIndexId, siteid)
	// Positional literal kept as-is: the field order of schema.Listening is
	// declared in another package and is not visible from this file.
	s := schema.Listening{"zym", "b", 1}
	chananel <- s
}

// main loads the configuration, initialises MySQL, registers the HTTP routes,
// starts the event and push goroutines and then serves until ListenAndServe
// fails.
func main() {
	// Without this call the -config flag is never read and the default path
	// is always used.
	flag.Parse()

	cfg, err := config.ParseConfigFile(*configFile)
	if err != nil {
		log.Fatalf("parse config file error:%v\n", err.Error())
	}

	// Initialise the database connections.
	err = global.InitMysql(cfg.Mysqls)
	if err != nil {
		global.GlobalLogger.Error("InitDb error:%v\n", err.Error())
		return
	}

	http.HandleFunc("/o", hu)
	http.HandleFunc("/ws", handleConnections)

	go handleEvents()
	go handleMessages()

	err = http.ListenAndServe(cfg.Wesocketport, nil)
	if err != nil {
		log.Fatal(err.Error())
	}
}

// handleConnections upgrades the request to a WebSocket, registers the socket
// under site_id+site_index_id (the pair that identifies a client) and keeps
// the connection open until the peer goes away.
//
// Requests missing either parameter are rejected with 403 and are not
// upgraded.
func handleConnections(w http.ResponseWriter, r *http.Request) {
	siteId := r.FormValue("site_id")
	siteIndexId := r.FormValue("site_index_id")
	if siteId == "" || siteIndexId == "" {
		http.Error(w, "site_id and site_index_id must not empty", http.StatusForbidden)
		return
	}

	ws, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		global.GlobalLogger.Error("error:%s", err.Error())
		return
	}
	defer ws.Close()

	register(agentKey(siteId, siteIndexId), ws)
	defer unregister(ws)

	// The connection must be read so that close/ping/pong control frames are
	// processed; anything the client sends is discarded. The first read error
	// (normally the peer closing) ends the handler and unregisters the socket.
	for {
		if _, _, err := ws.NextReader(); err != nil {
			return
		}
	}
}

// handleEvents consumes events from chananel, builds the corresponding
// Message and forwards it to broadcast. It runs independently of any client
// connection, so events are processed even when nobody is connected and a
// failed query only drops that one event.
func handleEvents() {
	for s := range chananel {
		if msg, ok := buildMessage(s); ok {
			broadcast <- msg
		}
	}
}

// buildMessage queries the bean selected by s.Types and wraps the result in a
// Message addressed to the event's site. It logs the error and reports false
// when the query fails.
func buildMessage(s schema.Listening) (Message, bool) {
	failed := func(reason string) (Message, bool) {
		global.GlobalLogger.Error("error:%s", reason)
		return Message{}, false
	}

	// TODO: the fetched data may still need post-processing before the push.
	switch s.Types {
	case 1:
		// Latest unconfirmed company deposits.
		info, count, err := new(function.MemberCompanyIncomeBean).GetNotConfirm(s.SiteId, s.SiteIndexId)
		if err != nil {
			return failed(err.Error())
		}
		return Message{SiteIndexId: s.SiteIndexId, SiteId: s.SiteId, Message: info, Count: count}, true
	case 2:
		// Latest online deposits.
		info, count, err := new(function.OnlineEntryRecordBean).GetNotConfirm(s.SiteId, s.SiteIndexId)
		if err != nil {
			return failed(err.Error())
		}
		return Message{SiteIndexId: s.SiteIndexId, SiteId: s.SiteId, Message: info, Count: count}, true
	default:
		// Latest unconfirmed withdrawal records.
		info, count, err := new(function.MakeMoneyBean).GetOperateRecord(s.SiteId, s.SiteIndexId)
		if err != nil {
			return failed(err.Error())
		}
		return Message{SiteIndexId: s.SiteIndexId, SiteId: s.SiteId, Count: count, Message: info}, true
	}
}

// handleMessages delivers each broadcast message to the connections
// registered for the message's site. It is the only goroutine that writes to
// the sockets. A connection whose write fails is closed and unregistered.
func handleMessages() {
	for msg := range broadcast {
		for _, client := range recipients(agentKey(msg.SiteId, msg.SiteIndexId)) {
			err := client.SetWriteDeadline(time.Now().Add(writeWait))
			if err == nil {
				err = client.WriteJSON(msg)
			}
			if err != nil {
				global.GlobalLogger.Error("error:%s", err.Error())
				client.Close()
				unregister(client)
			}
		}
	}
}

// agentKey builds the registry key for a site: the two ids concatenated.
// NOTE(review): plain concatenation cannot tell ("ab","c") from ("a","bc");
// it is kept because that is the key format this file has always used.
func agentKey(siteId, siteIndexId string) string {
	return siteId + siteIndexId
}

// register records ws as a live connection for key.
func register(key string, ws *websocket.Conn) {
	mu.Lock()
	defer mu.Unlock()
	agentSlice = append(agentSlice, map[string]*websocket.Conn{key: ws})
	clients[ws] = true
}

// unregister removes ws from both registries. Calling it more than once for
// the same connection is harmless.
func unregister(ws *websocket.Conn) {
	mu.Lock()
	defer mu.Unlock()

	delete(clients, ws)

	kept := agentSlice[:0]
	for _, agent := range agentSlice {
		stale := false
		for _, conn := range agent {
			if conn == ws {
				stale = true
				break
			}
		}
		if !stale {
			kept = append(kept, agent)
		}
	}
	// Clear the unused tail so dropped entries do not stay reachable through
	// the backing array.
	for i := len(kept); i < len(agentSlice); i++ {
		agentSlice[i] = nil
	}
	agentSlice = kept
}

// recipients returns a snapshot of the live connections registered under key,
// so the caller can write to them without holding mu.
func recipients(key string) []*websocket.Conn {
	mu.Lock()
	defer mu.Unlock()

	var conns []*websocket.Conn
	for _, agent := range agentSlice {
		if conn, ok := agent[key]; ok && clients[conn] {
			conns = append(conns, conn)
		}
	}
	return conns
}
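// Footer of the web page this listing was copied from (not part of the program): "For questions, contact the site administrator on WeChat (not the author of this article)."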