package certmagic import ( "context" "log" "runtime" "sync" "time" ) // NewRateLimiter returns a rate limiter that allows up to maxEvents // in a sliding window of size window. If maxEvents and window are // both 0, or if maxEvents is non-zero and window is 0, rate limiting // is disabled. This function panics if maxEvents is less than 0 or // if maxEvents is 0 and window is non-zero, which is considered to be // an invalid configuration, as it would never allow events. func NewRateLimiter(maxEvents int, window time.Duration) *RingBufferRateLimiter { if maxEvents < 0 { panic("maxEvents cannot be less than zero") } if maxEvents == 0 && window != 0 { panic("invalid configuration: maxEvents = 0 and window != 0 would not allow any events") } rbrl := &RingBufferRateLimiter{ window: window, ring: make([]time.Time, maxEvents), started: make(chan struct{}), stopped: make(chan struct{}), ticket: make(chan struct{}), } go rbrl.loop() <-rbrl.started // make sure loop is ready to receive before we return return rbrl } // RingBufferRateLimiter uses a ring to enforce rate limits // consisting of a maximum number of events within a single // sliding window of a given duration. An empty value is // not valid; use NewRateLimiter to get one. type RingBufferRateLimiter struct { window time.Duration ring []time.Time // maxEvents == len(ring) cursor int // always points to the oldest timestamp mu sync.Mutex // protects ring, cursor, and window started chan struct{} stopped chan struct{} ticket chan struct{} } // Stop cleans up r's scheduling goroutine. func (r *RingBufferRateLimiter) Stop() { close(r.stopped) } func (r *RingBufferRateLimiter) loop() { defer func() { if err := recover(); err != nil { buf := make([]byte, stackTraceBufferSize) buf = buf[:runtime.Stack(buf, false)] log.Printf("panic: ring buffer rate limiter: %v\n%s", err, buf) } }() for { // if we've been stopped, return select { case <-r.stopped: return default: } if len(r.ring) == 0 { if r.window == 0 { // rate limiting is disabled; always allow immediately r.permit() continue } panic("invalid configuration: maxEvents = 0 and window != 0 does not allow any events") } // wait until next slot is available or until we've been stopped r.mu.Lock() then := r.ring[r.cursor].Add(r.window) r.mu.Unlock() waitDuration := time.Until(then) waitTimer := time.NewTimer(waitDuration) select { case <-waitTimer.C: r.permit() case <-r.stopped: waitTimer.Stop() return } } } // Allow returns true if the event is allowed to // happen right now. It does not wait. If the event // is allowed, a ticket is claimed. func (r *RingBufferRateLimiter) Allow() bool { select { case <-r.ticket: return true default: return false } } // Wait blocks until the event is allowed to occur. It returns an // error if the context is cancelled. func (r *RingBufferRateLimiter) Wait(ctx context.Context) error { select { case <-ctx.Done(): return context.Canceled case <-r.ticket: return nil } } // MaxEvents returns the maximum number of events that // are allowed within the sliding window. func (r *RingBufferRateLimiter) MaxEvents() int { r.mu.Lock() defer r.mu.Unlock() return len(r.ring) } // SetMaxEvents changes the maximum number of events that are // allowed in the sliding window. If the new limit is lower, // the oldest events will be forgotten. If the new limit is // higher, the window will suddenly have capacity for new // reservations. It panics if maxEvents is 0 and window size // is not zero. func (r *RingBufferRateLimiter) SetMaxEvents(maxEvents int) { newRing := make([]time.Time, maxEvents) r.mu.Lock() defer r.mu.Unlock() if r.window != 0 && maxEvents == 0 { panic("invalid configuration: maxEvents = 0 and window != 0 would not allow any events") } // only make the change if the new limit is different if maxEvents == len(r.ring) { return } // the new ring may be smaller; fast-forward to the // oldest timestamp that will be kept in the new // ring so the oldest ones are forgotten and the // newest ones will be remembered sizeDiff := len(r.ring) - maxEvents for i := 0; i < sizeDiff; i++ { r.advance() } if len(r.ring) > 0 { // copy timestamps into the new ring until we // have either copied all of them or have reached // the capacity of the new ring startCursor := r.cursor for i := 0; i < len(newRing); i++ { newRing[i] = r.ring[r.cursor] r.advance() if r.cursor == startCursor { // new ring is larger than old one; // "we've come full circle" break } } } r.ring = newRing r.cursor = 0 } // Window returns the size of the sliding window. func (r *RingBufferRateLimiter) Window() time.Duration { r.mu.Lock() defer r.mu.Unlock() return r.window } // SetWindow changes r's sliding window duration to window. // Goroutines that are already blocked on a call to Wait() // will not be affected. It panics if window is non-zero // but the max event limit is 0. func (r *RingBufferRateLimiter) SetWindow(window time.Duration) { r.mu.Lock() defer r.mu.Unlock() if window != 0 && len(r.ring) == 0 { panic("invalid configuration: maxEvents = 0 and window != 0 would not allow any events") } r.window = window } // permit allows one event through the throttle. This method // blocks until a goroutine is waiting for a ticket or until // the rate limiter is stopped. func (r *RingBufferRateLimiter) permit() { for { select { case r.started <- struct{}{}: // notify parent goroutine that we've started; should // only happen once, before constructor returns continue case <-r.stopped: return case r.ticket <- struct{}{}: r.mu.Lock() defer r.mu.Unlock() if len(r.ring) > 0 { r.ring[r.cursor] = time.Now() r.advance() } return } } } // advance moves the cursor to the next position. // It is NOT safe for concurrent use, so it must // be called inside a lock on r.mu. func (r *RingBufferRateLimiter) advance() { r.cursor++ if r.cursor >= len(r.ring) { r.cursor = 0 } }