- Add Recovery middleware (catches panics, returns 500, logs stack trace) - Add RateLimiter to middleware chain (30 req/min, burst 60 per IP) - Fix CI token comparison with subtle.ConstantTimeCompare (timing attack) - Middleware chain: Recovery → Logging → RateLimit → CORS → mux Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
175 lines
4.4 KiB
Go
175 lines
4.4 KiB
Go
// Package middleware provides HTTP middleware: CORS headers, request logging, panic recovery, and per-IP rate limiting.
|
|
package middleware
|
|
|
|
import (
	"log"
	"net"
	"net/http"
	"runtime/debug"
	"strconv"
	"strings"
	"sync"
	"time"
)
|
|
|
|
// CORS sets permissive cross-origin headers on every response: any origin
// ("*"), the methods GET, POST, PUT, DELETE and OPTIONS, and the request
// headers Content-Type, Authorization and X-CI-Token.
//
// OPTIONS requests are answered immediately with 204 No Content and never
// reach next; all other methods are passed through unchanged.
//
// The allowed origin is hard-coded to "*"; restrict it to your real domains
// before production use.
func CORS(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Set("Access-Control-Allow-Origin", "*")
		h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-CI-Token")

		switch r.Method {
		case http.MethodOptions:
			w.WriteHeader(http.StatusNoContent)
		default:
			next.ServeHTTP(w, r)
		}
	})
}
|
|
|
|
// Logging logs each request with method, path, status code, and duration.
|
|
func Logging(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
start := time.Now()
|
|
ww := &statusWriter{ResponseWriter: w, status: http.StatusOK}
|
|
next.ServeHTTP(ww, r)
|
|
log.Printf("%s %s %d %s %s",
|
|
r.Method, r.RequestURI, ww.status,
|
|
time.Since(start), r.RemoteAddr,
|
|
)
|
|
})
|
|
}
|
|
|
|
// Recovery catches panics raised by downstream handlers. It logs the panic
// value with a stack trace and responds 500 with a JSON error body.
//
// A panic carrying http.ErrAbortHandler is re-raised untouched. net/http uses
// that sentinel to abort a response on purpose, and it handles the panic
// itself without logging a stack trace.
//
// The 500 status only takes effect if the handler had not already written a
// response header before panicking.
func Recovery(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		defer func() {
			rec := recover()
			if rec == nil {
				return
			}
			if rec == http.ErrAbortHandler {
				panic(rec)
			}
			log.Printf("PANIC: %v\n%s", rec, debug.Stack())

			// http.Error would label this JSON body text/plain, so write it by hand.
			h := w.Header()
			h.Del("Content-Length") // a length set by the handler no longer applies
			h.Set("Content-Type", "application/json")
			h.Set("X-Content-Type-Options", "nosniff")
			w.WriteHeader(http.StatusInternalServerError)
			_, _ = w.Write([]byte(`{"error":"Internal Server Error"}` + "\n")) // client may be gone; nothing to do
		}()
		next.ServeHTTP(w, r)
	})
}
|
|
|
|
// statusWriter wraps an http.ResponseWriter and records the status code that
// was actually sent, so middleware can report it after the handler returns.
// Creators should preset status to http.StatusOK. net/http sends 200 when a
// handler neither writes nor calls WriteHeader, and then no method here runs.
type statusWriter struct {
	http.ResponseWriter
	status int
	// wroteHeader is set once a final (non-1xx) status has been committed.
	// net/http ignores any later WriteHeader call, so status must not change.
	wroteHeader bool
}

// WriteHeader records status unless a final status was already committed,
// then forwards the call to the wrapped writer. Informational 1xx codes,
// except 101 Switching Protocols, are recorded only provisionally, because a
// final status still follows them.
func (w *statusWriter) WriteHeader(status int) {
	if !w.wroteHeader {
		w.status = status
		w.wroteHeader = status < 100 || status >= 200 || status == http.StatusSwitchingProtocols
	}
	w.ResponseWriter.WriteHeader(status)
}

// Write forwards b to the wrapped writer. A Write that happens before any
// final WriteHeader implicitly commits 200 OK, so 200 is what gets recorded.
func (w *statusWriter) Write(b []byte) (int, error) {
	if !w.wroteHeader {
		w.status = http.StatusOK
		w.wroteHeader = true
	}
	return w.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped writer. http.ResponseController (Go 1.20+) uses
// it to reach optional interfaces such as http.Flusher that statusWriter
// itself does not implement.
func (w *statusWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}
|
|
|
|
// RateLimiter implements a simple per-IP token bucket rate limiter.
|
|
// Not suitable for production behind a proxy (use a real rate limiter then),
|
|
// but sufficient for development and single-instance deployments.
|
|
type RateLimiter struct {
|
|
mu sync.Mutex
|
|
clients map[string]*bucket
|
|
rate int // tokens per interval
|
|
interval time.Duration
|
|
burst int // max bucket size
|
|
}
|
|
|
|
type bucket struct {
|
|
tokens int
|
|
last time.Time
|
|
}
|
|
|
|
// NewRateLimiter creates a rate limiter allowing `rate` requests per `interval`,
|
|
// with a maximum burst of `burst`.
|
|
func NewRateLimiter(rate int, interval time.Duration, burst int) *RateLimiter {
|
|
rl := &RateLimiter{
|
|
clients: make(map[string]*bucket),
|
|
rate: rate,
|
|
interval: interval,
|
|
burst: burst,
|
|
}
|
|
// Periodically clean up stale entries.
|
|
go rl.cleanup()
|
|
return rl
|
|
}
|
|
|
|
// Limit returns an HTTP middleware that rate-limits requests by client IP.
|
|
func (rl *RateLimiter) Limit(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ip := clientIP(r)
|
|
if !rl.allow(ip) {
|
|
w.Header().Set("Retry-After", "60")
|
|
http.Error(w, `{"error":"Rate limit exceeded"}`, http.StatusTooManyRequests)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func (rl *RateLimiter) allow(ip string) bool {
|
|
rl.mu.Lock()
|
|
defer rl.mu.Unlock()
|
|
|
|
b, ok := rl.clients[ip]
|
|
if !ok {
|
|
rl.clients[ip] = &bucket{tokens: rl.burst - 1, last: time.Now()}
|
|
return true
|
|
}
|
|
|
|
// Refill tokens based on elapsed time.
|
|
elapsed := time.Since(b.last)
|
|
refill := int(elapsed / rl.interval * time.Duration(rl.rate))
|
|
if refill > 0 {
|
|
b.tokens = min(b.tokens+refill, rl.burst)
|
|
b.last = time.Now()
|
|
}
|
|
|
|
if b.tokens > 0 {
|
|
b.tokens--
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (rl *RateLimiter) cleanup() {
|
|
ticker := time.NewTicker(5 * time.Minute)
|
|
for range ticker.C {
|
|
rl.mu.Lock()
|
|
now := time.Now()
|
|
for ip, b := range rl.clients {
|
|
if now.Sub(b.last) > 10*time.Minute {
|
|
delete(rl.clients, ip)
|
|
}
|
|
}
|
|
rl.mu.Unlock()
|
|
}
|
|
}
|
|
|
|
// clientIP returns the address that identifies the client for rate limiting.
//
// The direct peer is taken from r.RemoteAddr with its port stripped.
// net.SplitHostPort also handles bracketed IPv6 such as "[::1]:8080"; a value
// without a port is used as-is.
//
// X-Forwarded-For and X-Real-IP are set by the client unless a proxy rewrites
// them. They are therefore consulted only when the peer is a loopback or
// private-network address, i.e. plausibly a reverse proxy in front of this
// server. Trusting them from anyone would let a client dodge the limiter by
// sending a fresh X-Forwarded-For on every request.
//
// When they are trusted, the X-Forwarded-For chain (all header lines, joined)
// is read right to left and loopback/private hops are skipped. Each proxy
// appends the address it saw, while entries further left may be forged. If
// every hop is internal, the leftmost one is returned.
func clientIP(r *http.Request) string {
	peer := r.RemoteAddr
	if host, _, err := net.SplitHostPort(peer); err == nil {
		peer = host
	}

	// trusted reports whether addr parses as a loopback or private IP.
	trusted := func(addr string) bool {
		ip := net.ParseIP(addr)
		return ip != nil && (ip.IsLoopback() || ip.IsPrivate())
	}
	if !trusted(peer) {
		return peer
	}

	if xff := r.Header.Values("X-Forwarded-For"); len(xff) > 0 {
		hops := strings.Split(strings.Join(xff, ","), ",")
		leftmost := ""
		for i := len(hops) - 1; i >= 0; i-- {
			hop := strings.TrimSpace(hops[i])
			if hop == "" {
				continue
			}
			if !trusted(hop) {
				return hop
			}
			leftmost = hop
		}
		if leftmost != "" {
			return leftmost
		}
	}

	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}
	return peer
}
|
|
|
|
// min returns the smaller of a and b (a when they are equal).
// On Go 1.21+ this package-level function shadows the built-in min.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|