Add CORS middleware using rs/cors package for improved handling

This commit is contained in:
2025-04-28 20:16:04 -05:00
parent 6abd03ea96
commit b0a8ec88b3
3 changed files with 15 additions and 75 deletions

View File

@ -2,8 +2,8 @@ package middleware
import (
"net/http"
"strconv"
"strings"
"github.com/rs/cors"
)
// CORSConfig defines the configuration options for CORS middleware
@ -42,80 +42,16 @@ func DefaultCORSConfig() CORSConfig {
// CORS middleware handles Cross-Origin Resource Sharing
func CORS(config CORSConfig) Middleware {
// Precompute values for better performance
allowedOriginsMap := make(map[string]bool)
for _, origin := range config.AllowedOrigins {
allowedOriginsMap[origin] = true
}
allowAllOrigins := allowedOriginsMap["*"]
allowedMethodsValue := strings.Join(config.AllowedMethods, ", ")
allowedHeadersValue := strings.Join(config.AllowedHeaders, ", ")
exposedHeadersValue := strings.Join(config.ExposedHeaders, ", ")
// Convert MaxAge to a string once
var maxAgeValue string
if config.MaxAge > 0 {
maxAgeValue = strconv.Itoa(config.MaxAge)
}
c := cors.New(cors.Options{
AllowedOrigins: config.AllowedOrigins,
AllowedMethods: config.AllowedMethods,
AllowedHeaders: config.AllowedHeaders,
ExposedHeaders: config.ExposedHeaders,
AllowCredentials: config.AllowCredentials,
MaxAge: config.MaxAge,
})
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
// Skip if no Origin header is present (same-origin request)
if origin == "" {
next.ServeHTTP(w, r)
return
}
// Check if the origin is allowed
var allowedOrigin string
if allowAllOrigins {
allowedOrigin = origin
} else if allowedOriginsMap[origin] {
allowedOrigin = origin
}
// Preflight OPTIONS request
if r.Method == http.MethodOptions {
// Set allowed methods
w.Header().Set("Access-Control-Allow-Methods", allowedMethodsValue)
// Set allowed headers
if len(allowedHeadersValue) > 0 {
w.Header().Set("Access-Control-Allow-Headers", allowedHeadersValue)
}
// Set max age for preflight request caching
if maxAgeValue != "" {
w.Header().Set("Access-Control-Max-Age", maxAgeValue)
}
}
// Set allowed origin
if allowedOrigin != "" {
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
}
// Set exposed headers
if len(exposedHeadersValue) > 0 {
w.Header().Set("Access-Control-Expose-Headers", exposedHeadersValue)
}
// Set allow credentials
if config.AllowCredentials {
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
// If it's a preflight request, we're done
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return
}
// Process the actual request
next.ServeHTTP(w, r)
})
return c.Handler(next)
}
}