230 lines
6.0 KiB
Go
230 lines
6.0 KiB
Go
package certmagic
|
|
|
|
import (
|
|
"context"
|
|
"log"
|
|
"runtime"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// NewRateLimiter returns a rate limiter that allows up to maxEvents
|
|
// in a sliding window of size window. If maxEvents and window are
|
|
// both 0, or if maxEvents is non-zero and window is 0, rate limiting
|
|
// is disabled. This function panics if maxEvents is less than 0 or
|
|
// if maxEvents is 0 and window is non-zero, which is considered to be
|
|
// an invalid configuration, as it would never allow events.
|
|
func NewRateLimiter(maxEvents int, window time.Duration) *RingBufferRateLimiter {
|
|
if maxEvents < 0 {
|
|
panic("maxEvents cannot be less than zero")
|
|
}
|
|
if maxEvents == 0 && window != 0 {
|
|
panic("invalid configuration: maxEvents = 0 and window != 0 would not allow any events")
|
|
}
|
|
rbrl := &RingBufferRateLimiter{
|
|
window: window,
|
|
ring: make([]time.Time, maxEvents),
|
|
started: make(chan struct{}),
|
|
stopped: make(chan struct{}),
|
|
ticket: make(chan struct{}),
|
|
}
|
|
go rbrl.loop()
|
|
<-rbrl.started // make sure loop is ready to receive before we return
|
|
return rbrl
|
|
}
|
|
|
|
// RingBufferRateLimiter enforces a rate limit of at most a fixed
// number of events within one sliding window of a given duration,
// keeping the event timestamps in a ring. The zero value is not
// usable; obtain one from NewRateLimiter.
type RingBufferRateLimiter struct {
	window time.Duration // duration of the sliding window
	ring   []time.Time   // event timestamps; maxEvents == len(ring)
	cursor int           // index into ring; always points to the oldest timestamp
	mu     sync.Mutex    // protects ring, cursor, and window

	// Signalling channels; none of them carries data.
	//   started: one value is sent on it once the scheduling
	//            goroutine is running; the constructor receives it
	//   stopped: closed by Stop to shut the goroutine down
	//   ticket:  one value is sent per permitted event; Allow and
	//            Wait receive from it
	started, stopped, ticket chan struct{}
}
|
|
|
|
// Stop cleans up r's scheduling goroutine.
|
|
func (r *RingBufferRateLimiter) Stop() {
|
|
close(r.stopped)
|
|
}
|
|
|
|
func (r *RingBufferRateLimiter) loop() {
|
|
defer func() {
|
|
if err := recover(); err != nil {
|
|
buf := make([]byte, stackTraceBufferSize)
|
|
buf = buf[:runtime.Stack(buf, false)]
|
|
log.Printf("panic: ring buffer rate limiter: %v\n%s", err, buf)
|
|
}
|
|
}()
|
|
|
|
for {
|
|
// if we've been stopped, return
|
|
select {
|
|
case <-r.stopped:
|
|
return
|
|
default:
|
|
}
|
|
|
|
if len(r.ring) == 0 {
|
|
if r.window == 0 {
|
|
// rate limiting is disabled; always allow immediately
|
|
r.permit()
|
|
continue
|
|
}
|
|
panic("invalid configuration: maxEvents = 0 and window != 0 does not allow any events")
|
|
}
|
|
|
|
// wait until next slot is available or until we've been stopped
|
|
r.mu.Lock()
|
|
then := r.ring[r.cursor].Add(r.window)
|
|
r.mu.Unlock()
|
|
waitDuration := time.Until(then)
|
|
waitTimer := time.NewTimer(waitDuration)
|
|
select {
|
|
case <-waitTimer.C:
|
|
r.permit()
|
|
case <-r.stopped:
|
|
waitTimer.Stop()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Allow returns true if the event is allowed to
|
|
// happen right now. It does not wait. If the event
|
|
// is allowed, a ticket is claimed.
|
|
func (r *RingBufferRateLimiter) Allow() bool {
|
|
select {
|
|
case <-r.ticket:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// Wait blocks until the event is allowed to occur. It returns an
|
|
// error if the context is cancelled.
|
|
func (r *RingBufferRateLimiter) Wait(ctx context.Context) error {
|
|
select {
|
|
case <-ctx.Done():
|
|
return context.Canceled
|
|
case <-r.ticket:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// MaxEvents returns the maximum number of events that
|
|
// are allowed within the sliding window.
|
|
func (r *RingBufferRateLimiter) MaxEvents() int {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
return len(r.ring)
|
|
}
|
|
|
|
// SetMaxEvents changes the maximum number of events that are
|
|
// allowed in the sliding window. If the new limit is lower,
|
|
// the oldest events will be forgotten. If the new limit is
|
|
// higher, the window will suddenly have capacity for new
|
|
// reservations. It panics if maxEvents is 0 and window size
|
|
// is not zero.
|
|
func (r *RingBufferRateLimiter) SetMaxEvents(maxEvents int) {
|
|
newRing := make([]time.Time, maxEvents)
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
|
|
if r.window != 0 && maxEvents == 0 {
|
|
panic("invalid configuration: maxEvents = 0 and window != 0 would not allow any events")
|
|
}
|
|
|
|
// only make the change if the new limit is different
|
|
if maxEvents == len(r.ring) {
|
|
return
|
|
}
|
|
|
|
// the new ring may be smaller; fast-forward to the
|
|
// oldest timestamp that will be kept in the new
|
|
// ring so the oldest ones are forgotten and the
|
|
// newest ones will be remembered
|
|
sizeDiff := len(r.ring) - maxEvents
|
|
for i := 0; i < sizeDiff; i++ {
|
|
r.advance()
|
|
}
|
|
|
|
if len(r.ring) > 0 {
|
|
// copy timestamps into the new ring until we
|
|
// have either copied all of them or have reached
|
|
// the capacity of the new ring
|
|
startCursor := r.cursor
|
|
for i := 0; i < len(newRing); i++ {
|
|
newRing[i] = r.ring[r.cursor]
|
|
r.advance()
|
|
if r.cursor == startCursor {
|
|
// new ring is larger than old one;
|
|
// "we've come full circle"
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
r.ring = newRing
|
|
r.cursor = 0
|
|
}
|
|
|
|
// Window returns the size of the sliding window.
|
|
func (r *RingBufferRateLimiter) Window() time.Duration {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
return r.window
|
|
}
|
|
|
|
// SetWindow changes r's sliding window duration to window.
|
|
// Goroutines that are already blocked on a call to Wait()
|
|
// will not be affected. It panics if window is non-zero
|
|
// but the max event limit is 0.
|
|
func (r *RingBufferRateLimiter) SetWindow(window time.Duration) {
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
if window != 0 && len(r.ring) == 0 {
|
|
panic("invalid configuration: maxEvents = 0 and window != 0 would not allow any events")
|
|
}
|
|
r.window = window
|
|
}
|
|
|
|
// permit allows one event through the throttle. This method
|
|
// blocks until a goroutine is waiting for a ticket or until
|
|
// the rate limiter is stopped.
|
|
func (r *RingBufferRateLimiter) permit() {
|
|
for {
|
|
select {
|
|
case r.started <- struct{}{}:
|
|
// notify parent goroutine that we've started; should
|
|
// only happen once, before constructor returns
|
|
continue
|
|
case <-r.stopped:
|
|
return
|
|
case r.ticket <- struct{}{}:
|
|
r.mu.Lock()
|
|
defer r.mu.Unlock()
|
|
if len(r.ring) > 0 {
|
|
r.ring[r.cursor] = time.Now()
|
|
r.advance()
|
|
}
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// advance moves the cursor to the next position.
|
|
// It is NOT safe for concurrent use, so it must
|
|
// be called inside a lock on r.mu.
|
|
func (r *RingBufferRateLimiter) advance() {
|
|
r.cursor++
|
|
if r.cursor >= len(r.ring) {
|
|
r.cursor = 0
|
|
}
|
|
}
|