1
1
package httprate
2
2
3
3
import (
4
- "context"
5
4
"fmt"
6
5
"math"
7
6
"net/http"
@@ -68,28 +67,8 @@ func (l *rateLimiter) Counter() LimitCounter {
68
67
return l .limitCounter
69
68
}
70
69
71
- func (l * rateLimiter ) Status (ctx context.Context , key string ) (bool , float64 , error ) {
72
- t := time .Now ().UTC ()
73
- currentWindow := t .Truncate (l .windowLength )
74
- previousWindow := currentWindow .Add (- l .windowLength )
75
-
76
- currCount , prevCount , err := l .limitCounter .Get (key , currentWindow , previousWindow )
77
- if err != nil {
78
- return false , 0 , err
79
- }
80
-
81
- diff := t .Sub (currentWindow )
82
- rate := float64 (prevCount )* (float64 (l .windowLength )- float64 (diff ))/ float64 (l .windowLength ) + float64 (currCount )
83
-
84
- limit := l .requestLimit
85
- if val := getRequestLimit (ctx ); val > 0 {
86
- limit = val
87
- }
88
-
89
- if rate > float64 (limit ) {
90
- return false , rate , nil
91
- }
92
- return true , rate , nil
70
+ func (l * rateLimiter ) Status (key string ) (bool , float64 , error ) {
71
+ return l .calculateRate (key , l .requestLimit )
93
72
}
94
73
95
74
func (l * rateLimiter ) Handler (next http.Handler ) http.Handler {
@@ -112,7 +91,7 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {
112
91
w .Header ().Set ("X-RateLimit-Reset" , fmt .Sprintf ("%d" , currentWindow .Add (l .windowLength ).Unix ()))
113
92
114
93
l .mu .Lock ()
115
- _ , rate , err := l .Status ( ctx , key )
94
+ _ , rate , err := l .calculateRate ( key , limit )
116
95
if err != nil {
117
96
l .mu .Unlock ()
118
97
http .Error (w , err .Error (), http .StatusPreconditionRequired )
@@ -143,6 +122,25 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {
143
122
})
144
123
}
145
124
125
+ func (l * rateLimiter ) calculateRate (key string , requestLimit int ) (bool , float64 , error ) {
126
+ t := time .Now ().UTC ()
127
+ currentWindow := t .Truncate (l .windowLength )
128
+ previousWindow := currentWindow .Add (- l .windowLength )
129
+
130
+ currCount , prevCount , err := l .limitCounter .Get (key , currentWindow , previousWindow )
131
+ if err != nil {
132
+ return false , 0 , err
133
+ }
134
+
135
+ diff := t .Sub (currentWindow )
136
+ rate := float64 (prevCount )* (float64 (l .windowLength )- float64 (diff ))/ float64 (l .windowLength ) + float64 (currCount )
137
+ if rate > float64 (requestLimit ) {
138
+ return false , rate , nil
139
+ }
140
+
141
+ return true , rate , nil
142
+ }
143
+
146
144
type localCounter struct {
147
145
counters map [uint64 ]* count
148
146
windowLength time.Duration
0 commit comments