Skip to content

Commit a5842b5

Browse files
committed
improve handler lock
1 parent 408ed3d commit a5842b5

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

limiter.go

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,34 +62,31 @@ type rateLimiter struct {
6262
mu sync.Mutex
6363
}
6464

65-
func (r *rateLimiter) Counter() LimitCounter {
66-
return r.limitCounter
65+
func (l *rateLimiter) Counter() LimitCounter {
66+
return l.limitCounter
6767
}
6868

69-
func (r *rateLimiter) Status(key string) (bool, float64, error) {
69+
func (l *rateLimiter) Status(key string) (bool, float64, error) {
7070
t := time.Now().UTC()
71-
currentWindow := t.Truncate(r.windowLength)
72-
previousWindow := currentWindow.Add(-r.windowLength)
71+
currentWindow := t.Truncate(l.windowLength)
72+
previousWindow := currentWindow.Add(-l.windowLength)
7373

74-
currCount, prevCount, err := r.limitCounter.Get(key, currentWindow, previousWindow)
74+
currCount, prevCount, err := l.limitCounter.Get(key, currentWindow, previousWindow)
7575
if err != nil {
7676
return false, 0, err
7777
}
7878

7979
diff := t.Sub(currentWindow)
80-
rate := float64(prevCount)*(float64(r.windowLength)-float64(diff))/float64(r.windowLength) + float64(currCount)
80+
rate := float64(prevCount)*(float64(l.windowLength)-float64(diff))/float64(l.windowLength) + float64(currCount)
8181

82-
if rate > float64(r.requestLimit) {
82+
if rate > float64(l.requestLimit) {
8383
return false, rate, nil
8484
}
8585
return true, rate, nil
8686
}
8787

8888
func (l *rateLimiter) Handler(next http.Handler) http.Handler {
8989
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
90-
l.mu.Lock()
91-
defer l.mu.Unlock()
92-
9390
key, err := l.keyFn(r)
9491
if err != nil {
9592
http.Error(w, err.Error(), http.StatusPreconditionRequired)
@@ -102,8 +99,10 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {
10299
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", 0))
103100
w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", currentWindow.Add(l.windowLength).Unix()))
104101

102+
l.mu.Lock()
105103
_, rate, err := l.Status(key)
106104
if err != nil {
105+
l.mu.Unlock()
107106
http.Error(w, err.Error(), http.StatusPreconditionRequired)
108107
return
109108
}
@@ -114,16 +113,19 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {
114113
}
115114

116115
if nrate >= l.requestLimit {
116+
l.mu.Unlock()
117117
w.Header().Set("Retry-After", fmt.Sprintf("%d", int(l.windowLength.Seconds()))) // RFC 6585
118118
l.onRequestLimit(w, r)
119119
return
120120
}
121121

122122
err = l.limitCounter.Increment(key, currentWindow)
123123
if err != nil {
124+
l.mu.Unlock()
124125
http.Error(w, err.Error(), http.StatusInternalServerError)
125126
return
126127
}
128+
l.mu.Unlock()
127129

128130
next.ServeHTTP(w, r)
129131
})

0 commit comments

Comments
 (0)