package main
import (
"fmt"
"net/http"
"net/http/httptest"
"time"
"golang.org/x/time/rate"
)
// LimiterOption carries the settings that option functions passed to
// LimiterWrap can adjust.
type LimiterOption struct {
// lm is the limiter whose Allow method LimiterWrap consults for each request.
lm *rate.Limiter
}
// WithLimiter returns an option that replaces the limiter stored in a
// LimiterOption. The new limiter admits one event per duration and allows
// bursts of up to count events.
//
// The limiter is constructed each time the option is applied, so reusing the
// same option value for several handlers still gives every handler its own
// limiter.
func WithLimiter(duration time.Duration, count int) func(o *LimiterOption) {
	apply := func(o *LimiterOption) {
		limit := rate.Every(duration)
		o.lm = rate.NewLimiter(limit, count)
	}
	return apply
}
// LimiterWrap wraps f with its own rate limiter, so every wrapped handler is
// throttled independently of the others.
//
// The limiter defaults to one request per 100ms with a burst of 10. Each opt is
// applied in order and may replace it (see WithLimiter). A request that arrives
// when the limiter has no token available is rejected with
// 429 Too Many Requests and f is not called.
func LimiterWrap(f http.HandlerFunc, opts ...func(o *LimiterOption)) http.HandlerFunc {
	o := &LimiterOption{
		lm: rate.NewLimiter(rate.Every(100*time.Millisecond), 10),
	}
	for _, opt := range opts {
		opt(o)
	}
	return func(w http.ResponseWriter, r *http.Request) {
		if !o.lm.Allow() {
			// Being throttled is a client-side condition, not a server fault:
			// answer 429 rather than 500 so clients and monitoring can tell
			// rate limiting apart from real errors.
			w.WriteHeader(http.StatusTooManyRequests)
			return
		}
		f(w, r)
	}
}

// xHandle is the demo handler: it replies with a fixed greeting.
func xHandle(w http.ResponseWriter, r *http.Request) {
	w.Write([]byte("Hello\n"))
}

func main() {
	// Rate-limit a single endpoint on its own: one request per 100ms, burst 1.
	http.HandleFunc("/", LimiterWrap(xHandle, WithLimiter(100*time.Millisecond, 1)))
	// To put every endpoint behind one shared limiter, wrap the mux instead:
	// http.ListenAndServe(":8080", LimiterWrap(http.DefaultServeMux.ServeHTTP))

	// Send 10 in-process requests about 20ms apart and print each status code.
	for i := 0; i < 10; i++ {
		w := httptest.NewRecorder()
		req, err := http.NewRequest(http.MethodGet, "/", nil)
		if err != nil {
			fmt.Println("building request:", err)
			return
		}
		http.DefaultServeMux.ServeHTTP(w, req)
		fmt.Println(i, w.Code)
		time.Sleep(20 * time.Millisecond)
	}
	// Typical output (timing-dependent: a token is refilled every 100ms, i.e.
	// roughly every fifth request):
	// 0 200
	// 1 429
	// 2 429
	// 3 429
	// 4 429
	// 5 200
	// 6 429
	// 7 429
	// 8 429
	// 9 429
}