diff --git a/go.mod b/go.mod index a7fc4b0..2034a7d 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,8 @@ require ( github.com/spf13/viper v1.19.0 ) +require github.com/rs/cors v1.11.1 // indirect + require ( github.com/coreos/go-oidc/v3 v3.12.0 github.com/fsnotify/fsnotify v1.7.0 // indirect diff --git a/go.sum b/go.sum index aa8bec5..da4a4d9 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA= +github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go index 33c4228..b0ebac9 100644 --- a/internal/middleware/cors.go +++ b/internal/middleware/cors.go @@ -2,8 +2,8 @@ package middleware import ( "net/http" - "strconv" - "strings" + + "github.com/rs/cors" ) // CORSConfig defines the configuration options for CORS middleware @@ -42,80 +42,16 @@ func DefaultCORSConfig() CORSConfig { // CORS middleware handles Cross-Origin Resource Sharing func CORS(config CORSConfig) Middleware { - // Precompute values for better performance - allowedOriginsMap := make(map[string]bool) - for _, origin := range config.AllowedOrigins { - allowedOriginsMap[origin] = true - } - - allowAllOrigins := allowedOriginsMap["*"] - allowedMethodsValue := strings.Join(config.AllowedMethods, ", ") - allowedHeadersValue := strings.Join(config.AllowedHeaders, ", ") - exposedHeadersValue := strings.Join(config.ExposedHeaders, ", ") - - // Convert MaxAge to a string once - var maxAgeValue string - if config.MaxAge > 0 { - maxAgeValue = strconv.Itoa(config.MaxAge) - } + c := cors.New(cors.Options{ + AllowedOrigins: config.AllowedOrigins, + AllowedMethods: config.AllowedMethods, + AllowedHeaders: config.AllowedHeaders, + ExposedHeaders: config.ExposedHeaders, + AllowCredentials: config.AllowCredentials, + MaxAge: config.MaxAge, + }) return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - origin := r.Header.Get("Origin") - - // Skip if no Origin header is present (same-origin request) - if origin == "" { - next.ServeHTTP(w, r) - return - } - - // Check if the origin is allowed - var allowedOrigin string - if allowAllOrigins { - allowedOrigin = origin - } else if allowedOriginsMap[origin] { - allowedOrigin = origin - } - - // Preflight OPTIONS request - if r.Method == http.MethodOptions { - // Set allowed methods - w.Header().Set("Access-Control-Allow-Methods", allowedMethodsValue) - - // Set allowed headers - if len(allowedHeadersValue) > 0 { - w.Header().Set("Access-Control-Allow-Headers", allowedHeadersValue) - } - - // Set max age for preflight request caching - if maxAgeValue != "" { - w.Header().Set("Access-Control-Max-Age", maxAgeValue) - } - } - - // Set allowed origin - if allowedOrigin != "" { - w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) - } - - // Set exposed headers - if len(exposedHeadersValue) > 0 { - w.Header().Set("Access-Control-Expose-Headers", exposedHeadersValue) - } - - // Set allow credentials - if config.AllowCredentials { - w.Header().Set("Access-Control-Allow-Credentials", "true") - } - - // If it's a preflight request, we're done - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusNoContent) - return - } - - // Process the actual request - next.ServeHTTP(w, r) - }) + return c.Handler(next) } }