Skip to content

Commit f8c004e

Browse files
committed
do not break Status api
1 parent 74791f6 commit f8c004e

File tree

1 file changed

+22
-24
lines changed

1 file changed

+22
-24
lines changed

limiter.go

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package httprate
22

33
import (
4-
"context"
54
"fmt"
65
"math"
76
"net/http"
@@ -68,28 +67,8 @@ func (l *rateLimiter) Counter() LimitCounter {
6867
return l.limitCounter
6968
}
7069

71-
func (l *rateLimiter) Status(ctx context.Context, key string) (bool, float64, error) {
72-
t := time.Now().UTC()
73-
currentWindow := t.Truncate(l.windowLength)
74-
previousWindow := currentWindow.Add(-l.windowLength)
75-
76-
currCount, prevCount, err := l.limitCounter.Get(key, currentWindow, previousWindow)
77-
if err != nil {
78-
return false, 0, err
79-
}
80-
81-
diff := t.Sub(currentWindow)
82-
rate := float64(prevCount)*(float64(l.windowLength)-float64(diff))/float64(l.windowLength) + float64(currCount)
83-
84-
limit := l.requestLimit
85-
if val := getRequestLimit(ctx); val > 0 {
86-
limit = val
87-
}
88-
89-
if rate > float64(limit) {
90-
return false, rate, nil
91-
}
92-
return true, rate, nil
70+
func (l *rateLimiter) Status(key string) (bool, float64, error) {
71+
return l.calculateRate(key, l.requestLimit)
9372
}
9473

9574
func (l *rateLimiter) Handler(next http.Handler) http.Handler {
@@ -112,7 +91,7 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {
11291
w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", currentWindow.Add(l.windowLength).Unix()))
11392

11493
l.mu.Lock()
115-
_, rate, err := l.Status(ctx, key)
94+
_, rate, err := l.calculateRate(key, limit)
11695
if err != nil {
11796
l.mu.Unlock()
11897
http.Error(w, err.Error(), http.StatusPreconditionRequired)
@@ -143,6 +122,25 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {
143122
})
144123
}
145124

125+
func (l *rateLimiter) calculateRate(key string, requestLimit int) (bool, float64, error) {
126+
t := time.Now().UTC()
127+
currentWindow := t.Truncate(l.windowLength)
128+
previousWindow := currentWindow.Add(-l.windowLength)
129+
130+
currCount, prevCount, err := l.limitCounter.Get(key, currentWindow, previousWindow)
131+
if err != nil {
132+
return false, 0, err
133+
}
134+
135+
diff := t.Sub(currentWindow)
136+
rate := float64(prevCount)*(float64(l.windowLength)-float64(diff))/float64(l.windowLength) + float64(currCount)
137+
if rate > float64(requestLimit) {
138+
return false, rate, nil
139+
}
140+
141+
return true, rate, nil
142+
}
143+
146144
type localCounter struct {
147145
counters map[uint64]*count
148146
windowLength time.Duration

0 commit comments

Comments
 (0)