// package middleware provides HTTP middleware (CORS, logging, rate limiting). package middleware import ( "log" "net/http" "runtime/debug" "strings" "sync" "time" ) // CORS adds permissive CORS headers for API endpoints. // In production, restrict AllowOrigins to your actual domains. func CORS(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-CI-Token") if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) return } next.ServeHTTP(w, r) }) } // Logging logs each request with method, path, status code, and duration. func Logging(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() ww := &statusWriter{ResponseWriter: w, status: http.StatusOK} next.ServeHTTP(ww, r) log.Printf("%s %s %d %s %s", r.Method, r.RequestURI, ww.status, time.Since(start), r.RemoteAddr, ) }) } // Recovery catches panics in downstream handlers and returns 500. // Logs the stack trace for debugging. func Recovery(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() { if rec := recover(); rec != nil { log.Printf("PANIC: %v\n%s", rec, debug.Stack()) http.Error(w, `{"error":"Internal Server Error"}`, http.StatusInternalServerError) } }() next.ServeHTTP(w, r) }) } // statusWriter wraps http.ResponseWriter to capture the status code. type statusWriter struct { http.ResponseWriter status int } func (w *statusWriter) WriteHeader(status int) { w.status = status w.ResponseWriter.WriteHeader(status) } // RateLimiter implements a simple per-IP token bucket rate limiter. 
// Clients are keyed by clientIP, which trusts the X-Forwarded-For and
// X-Real-IP request headers. Unless a trusted reverse proxy sets or
// overwrites those headers, a client can forge them to evade the limit.
// Bucket state lives in this process only, so limits are not shared between
// instances; this is adequate for development and single-instance use.
//
// Call Stop to end the background cleanup goroutine once the limiter is no
// longer needed.
type RateLimiter struct {
	mu       sync.Mutex
	clients  map[string]*bucket
	rate     int           // tokens credited per whole elapsed interval
	interval time.Duration // refill period
	burst    int           // max bucket size

	stop     chan struct{} // closed by Stop to end cleanup
	stopOnce sync.Once
}

// bucket is the token state for a single client key.
type bucket struct {
	tokens int
	// last is the start of the refill period that has not been credited yet.
	last time.Time
}

// NewRateLimiter creates a rate limiter allowing `rate` requests per `interval`,
// with a maximum burst of `burst`. Tokens are credited in whole intervals: a
// client earns `rate` tokens each time a full interval elapses. A burst <= 0
// rejects every request.
//
// NewRateLimiter panics if rate or interval is not positive (as
// time.NewTicker does for a non-positive period). It starts a cleanup
// goroutine that runs until Stop is called.
func NewRateLimiter(rate int, interval time.Duration, burst int) *RateLimiter {
	if rate <= 0 || interval <= 0 {
		panic("middleware: NewRateLimiter requires positive rate and interval")
	}
	rl := &RateLimiter{
		clients:  make(map[string]*bucket),
		rate:     rate,
		interval: interval,
		burst:    burst,
		stop:     make(chan struct{}),
	}
	// Periodically clean up stale entries.
	go rl.cleanup()
	return rl
}

// Stop ends the background cleanup goroutine. It is safe to call more than
// once. The limiter keeps enforcing limits afterwards but no longer evicts
// idle clients.
func (rl *RateLimiter) Stop() {
	rl.stopOnce.Do(func() { close(rl.stop) })
}

// Limit returns an HTTP middleware that rate-limits requests by client IP,
// answering 429 with a JSON error body when the client has no tokens left.
func (rl *RateLimiter) Limit(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if rl.allow(clientIP(r)) {
			next.ServeHTTP(w, r)
			return
		}
		h := w.Header()
		// NOTE: fixed hint; it is not derived from rl.interval.
		h.Set("Retry-After", "60")
		// Written by hand rather than via http.Error, which would label this
		// JSON body as text/plain.
		h.Set("Content-Type", "application/json")
		h.Set("X-Content-Type-Options", "nosniff")
		w.WriteHeader(http.StatusTooManyRequests)
		_, _ = w.Write([]byte(`{"error":"Rate limit exceeded"}` + "\n")) // best effort
	})
}

// allow reports whether the client identified by ip may make a request now,
// consuming one token if so. Unknown clients start with a full bucket.
func (rl *RateLimiter) allow(ip string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	b, ok := rl.clients[ip]
	if !ok {
		b = &bucket{tokens: rl.burst, last: now}
		rl.clients[ip] = b
	}
	rl.refill(b, now)
	if b.tokens <= 0 {
		return false
	}
	b.tokens--
	return true
}

// refill credits b with rl.rate tokens for every whole interval elapsed since
// b.last, capped at rl.burst. b.last advances by exactly the credited
// intervals, so the partial interval in progress is not thrown away.
// The caller must hold rl.mu or own b exclusively.
func (rl *RateLimiter) refill(b *bucket, now time.Time) {
	periods := now.Sub(b.last) / rl.interval
	if periods <= 0 {
		return // no whole interval yet (or the clock stepped backwards)
	}
	b.last = b.last.Add(periods * rl.interval)
	missing := int64(rl.burst) - int64(b.tokens)
	if missing <= 0 {
		return
	}
	// For positive integers, periods*rate >= missing  <=>
	// rate > (missing-1)/periods. The division form cannot overflow.
	if int64(rl.rate) > (missing-1)/int64(periods) {
		b.tokens = rl.burst
		return
	}
	// Here periods*rate < missing, so the product fits in an int.
	b.tokens += int(periods) * rl.rate
}

// cleanup evicts stale buckets every 5 minutes until Stop is called.
func (rl *RateLimiter) cleanup() {
	ticker := time.NewTicker(5 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-rl.stop:
			return
		case now := <-ticker.C:
			rl.evict(now)
		}
	}
}

// evict removes buckets whose refill clock (b.last) is more than 10 minutes
// old AND that would be full by now. Dropping a full bucket is equivalent to
// keeping it, because a newly seen client also starts at burst. So eviction
// never hands a throttled client a fresh allowance, even when interval is
// longer than the 10-minute idle threshold.
func (rl *RateLimiter) evict(now time.Time) {
	rl.mu.Lock()
	defer rl.mu.Unlock()
	for ip, b := range rl.clients {
		if now.Sub(b.last) <= 10*time.Minute {
			continue
		}
		probe := *b // refill a copy; the real bucket is left untouched
		rl.refill(&probe, now)
		if probe.tokens >= rl.burst {
			delete(rl.clients, ip)
		}
	}
}

// clientIP returns the key used to identify the requesting client. It
// returns the first non-empty X-Forwarded-For entry, else X-Real-IP, else
// the host part of r.RemoteAddr.
//
// Both headers are client-controlled unless a trusted reverse proxy sets or
// overwrites them, so the result must not be treated as authenticated.
func clientIP(r *http.Request) string {
	// Check X-Forwarded-For first (if behind a proxy).
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}
	// Fall back to RemoteAddr, stripping the port. RemoteAddr is "host:port"
	// for IPv4 and "[host]:port" for IPv6. Cutting at the first ':' would
	// turn "[::1]:80" into "[".
	addr := r.RemoteAddr
	if strings.HasPrefix(addr, "[") {
		if end := strings.IndexByte(addr, ']'); end > 0 {
			return addr[1:end]
		}
		return addr
	}
	// Exactly one colon means "host:port". Zero colons, or several (a bare
	// IPv6 literal), mean there is no port to strip.
	if i := strings.IndexByte(addr, ':'); i >= 0 && i == strings.LastIndexByte(addr, ':') {
		return addr[:i]
	}
	return addr
}

// min returns the smaller of a and b. Within this package it shadows the
// Go 1.21 builtin of the same name. It is kept because other files in the
// package may call it (not visible here — verify before removing).
func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}