Add CORS middleware using rs/cors package for improved handling
This commit is contained in:
@ -2,8 +2,8 @@ package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/cors"
|
||||
)
|
||||
|
||||
// CORSConfig defines the configuration options for CORS middleware
|
||||
@ -42,80 +42,16 @@ func DefaultCORSConfig() CORSConfig {
|
||||
|
||||
// CORS middleware handles Cross-Origin Resource Sharing
|
||||
func CORS(config CORSConfig) Middleware {
|
||||
// Precompute values for better performance
|
||||
allowedOriginsMap := make(map[string]bool)
|
||||
for _, origin := range config.AllowedOrigins {
|
||||
allowedOriginsMap[origin] = true
|
||||
}
|
||||
|
||||
allowAllOrigins := allowedOriginsMap["*"]
|
||||
allowedMethodsValue := strings.Join(config.AllowedMethods, ", ")
|
||||
allowedHeadersValue := strings.Join(config.AllowedHeaders, ", ")
|
||||
exposedHeadersValue := strings.Join(config.ExposedHeaders, ", ")
|
||||
|
||||
// Convert MaxAge to a string once
|
||||
var maxAgeValue string
|
||||
if config.MaxAge > 0 {
|
||||
maxAgeValue = strconv.Itoa(config.MaxAge)
|
||||
}
|
||||
c := cors.New(cors.Options{
|
||||
AllowedOrigins: config.AllowedOrigins,
|
||||
AllowedMethods: config.AllowedMethods,
|
||||
AllowedHeaders: config.AllowedHeaders,
|
||||
ExposedHeaders: config.ExposedHeaders,
|
||||
AllowCredentials: config.AllowCredentials,
|
||||
MaxAge: config.MaxAge,
|
||||
})
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
origin := r.Header.Get("Origin")
|
||||
|
||||
// Skip if no Origin header is present (same-origin request)
|
||||
if origin == "" {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the origin is allowed
|
||||
var allowedOrigin string
|
||||
if allowAllOrigins {
|
||||
allowedOrigin = origin
|
||||
} else if allowedOriginsMap[origin] {
|
||||
allowedOrigin = origin
|
||||
}
|
||||
|
||||
// Preflight OPTIONS request
|
||||
if r.Method == http.MethodOptions {
|
||||
// Set allowed methods
|
||||
w.Header().Set("Access-Control-Allow-Methods", allowedMethodsValue)
|
||||
|
||||
// Set allowed headers
|
||||
if len(allowedHeadersValue) > 0 {
|
||||
w.Header().Set("Access-Control-Allow-Headers", allowedHeadersValue)
|
||||
}
|
||||
|
||||
// Set max age for preflight request caching
|
||||
if maxAgeValue != "" {
|
||||
w.Header().Set("Access-Control-Max-Age", maxAgeValue)
|
||||
}
|
||||
}
|
||||
|
||||
// Set allowed origin
|
||||
if allowedOrigin != "" {
|
||||
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
|
||||
}
|
||||
|
||||
// Set exposed headers
|
||||
if len(exposedHeadersValue) > 0 {
|
||||
w.Header().Set("Access-Control-Expose-Headers", exposedHeadersValue)
|
||||
}
|
||||
|
||||
// Set allow credentials
|
||||
if config.AllowCredentials {
|
||||
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
|
||||
// If it's a preflight request, we're done
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
// Process the actual request
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
return c.Handler(next)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user