Skip to content

Commit 3327e65

Browse files
authored
Add IncrementBy (#21)
1 parent 2aed83f commit 3327e65

File tree

3 files changed

+62
-2
lines changed

3 files changed

+62
-2
lines changed

context.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package httprate
2+
3+
import "context"
4+
5+
var incrementKey = &struct{}{}
6+
7+
func WithIncrement(ctx context.Context, value int) context.Context {
8+
return context.WithValue(ctx, incrementKey, value)
9+
}
10+
11+
func getIncrement(ctx context.Context) int {
12+
if value, ok := ctx.Value(incrementKey).(int); ok {
13+
return value
14+
}
15+
return 1
16+
}

limiter.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
type LimitCounter interface {
1414
Config(requestLimit int, windowLength time.Duration)
1515
Increment(key string, currentWindow time.Time) error
16+
IncrementBy(key string, currentWindow time.Time, amount int) error
1617
Get(key string, currentWindow, previousWindow time.Time) (int, int, error)
1718
}
1819

@@ -119,7 +120,7 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {
119120
return
120121
}
121122

122-
err = l.limitCounter.Increment(key, currentWindow)
123+
err = l.limitCounter.IncrementBy(key, currentWindow, getIncrement(r.Context()))
123124
if err != nil {
124125
l.mu.Unlock()
125126
http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -152,6 +153,10 @@ func (c *localCounter) Config(requestLimit int, windowLength time.Duration) {
152153
}
153154

154155
func (c *localCounter) Increment(key string, currentWindow time.Time) error {
156+
return c.IncrementBy(key, currentWindow, 1)
157+
}
158+
159+
func (c *localCounter) IncrementBy(key string, currentWindow time.Time, amount int) error {
155160
c.evict()
156161

157162
c.mu.Lock()
@@ -164,7 +169,7 @@ func (c *localCounter) Increment(key string, currentWindow time.Time) error {
164169
v = &count{}
165170
c.counters[hkey] = v
166171
}
167-
v.value += 1
172+
v.value += amount
168173
v.updatedAt = time.Now()
169174

170175
return nil

limiter_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,45 @@ func TestLimit(t *testing.T) {
4949
}
5050
}
5151

52+
func TestWithIncrement(t *testing.T) {
53+
type test struct {
54+
name string
55+
requestsLimit int
56+
windowLength time.Duration
57+
respCodes []int
58+
}
59+
tests := []test{
60+
{
61+
name: "no-block",
62+
requestsLimit: 3,
63+
windowLength: 4 * time.Second,
64+
respCodes: []int{200, 200, 429},
65+
},
66+
{
67+
name: "block",
68+
requestsLimit: 3,
69+
windowLength: 2 * time.Second,
70+
respCodes: []int{200, 200, 429, 429},
71+
},
72+
}
73+
for i, tt := range tests {
74+
t.Run(tt.name, func(t *testing.T) {
75+
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
76+
router := httprate.LimitAll(tt.requestsLimit, tt.windowLength)(h)
77+
78+
for _, code := range tt.respCodes {
79+
req := httptest.NewRequest("GET", "/", nil)
80+
req = req.WithContext(httprate.WithIncrement(req.Context(), 2))
81+
recorder := httptest.NewRecorder()
82+
router.ServeHTTP(recorder, req)
83+
if respCode := recorder.Result().StatusCode; respCode != code {
84+
t.Errorf("resp.StatusCode(%v) = %v, want %v", i, respCode, code)
85+
}
86+
}
87+
})
88+
}
89+
}
90+
5291
func TestLimitHandler(t *testing.T) {
5392
type test struct {
5493
name string

0 commit comments

Comments
 (0)