package middleware import ( "net/http" "strconv" "strings" ) // CORSConfig defines the configuration options for CORS middleware type CORSConfig struct { // AllowedOrigins defines the origins that are allowed to access the resource AllowedOrigins []string // AllowedMethods defines the HTTP methods that are allowed AllowedMethods []string // AllowedHeaders defines the headers that are allowed in requests AllowedHeaders []string // ExposedHeaders defines the headers that can be read by the browser ExposedHeaders []string // AllowCredentials defines whether cookies, HTTP authentication and client SSL // certificates can be transmitted in cross-origin requests AllowCredentials bool // MaxAge defines how long (in seconds) the results of a preflight request can be cached MaxAge int } // DefaultCORSConfig returns a default configuration for CORS middleware func DefaultCORSConfig() CORSConfig { return CORSConfig{ AllowedOrigins: []string{}, AllowedMethods: []string{"GET"}, AllowedHeaders: []string{}, ExposedHeaders: []string{}, AllowCredentials: false, MaxAge: 0, } } // CORS middleware handles Cross-Origin Resource Sharing func CORS(config CORSConfig) Middleware { // Precompute values for better performance allowedOriginsMap := make(map[string]bool) for _, origin := range config.AllowedOrigins { allowedOriginsMap[origin] = true } allowAllOrigins := allowedOriginsMap["*"] allowedMethodsValue := strings.Join(config.AllowedMethods, ", ") allowedHeadersValue := strings.Join(config.AllowedHeaders, ", ") exposedHeadersValue := strings.Join(config.ExposedHeaders, ", ") // Convert MaxAge to a string once var maxAgeValue string if config.MaxAge > 0 { maxAgeValue = strconv.Itoa(config.MaxAge) } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") // Skip if no Origin header is present (same-origin request) if origin == "" { next.ServeHTTP(w, r) return } // Check if the origin is allowed var allowedOrigin string if allowAllOrigins { allowedOrigin = origin } else if allowedOriginsMap[origin] { allowedOrigin = origin } // Preflight OPTIONS request if r.Method == http.MethodOptions { // Set allowed methods w.Header().Set("Access-Control-Allow-Methods", allowedMethodsValue) // Set allowed headers if len(allowedHeadersValue) > 0 { w.Header().Set("Access-Control-Allow-Headers", allowedHeadersValue) } // Set max age for preflight request caching if maxAgeValue != "" { w.Header().Set("Access-Control-Max-Age", maxAgeValue) } } // Set allowed origin if allowedOrigin != "" { w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) } // Set exposed headers if len(exposedHeadersValue) > 0 { w.Header().Set("Access-Control-Expose-Headers", exposedHeadersValue) } // Set allow credentials if config.AllowCredentials { w.Header().Set("Access-Control-Allow-Credentials", "true") } // If it's a preflight request, we're done if r.Method == http.MethodOptions { w.WriteHeader(http.StatusNoContent) return } // Process the actual request next.ServeHTTP(w, r) }) } }