CORS middleware. And fix middleware types for logging and securityheaders.

This commit is contained in:
Christian Galo 2025-04-28 02:47:22 -05:00
parent e8f22496e1
commit c2265330dd
4 changed files with 195 additions and 70 deletions

View File

@ -51,12 +51,12 @@ var startCmd = &cobra.Command{
// Create middleware stack
stack := middleware.CreateStack(
middleware.Logging, // Log requests with structured logging
middleware.Logging(), // Log requests with structured logging
middleware.Recovery(), // Catch all panics
middleware.Timeout(32*time.Second), // Set request timeout
middleware.RequestID(), // Generate a unique request ID
middleware.MaxBodySize(1024*1024), // 1MB size limit
middleware.SecureHeaders, // Set secure headers
middleware.SecureHeaders(), // Set secure headers
authConfig.Middleware(), // OIDC authentication middleware
)

121
internal/middleware/cors.go Normal file
View File

@ -0,0 +1,121 @@
package middleware
import (
"net/http"
"strconv"
"strings"
)
// CORSConfig defines the configuration options for CORS middleware
type CORSConfig struct {
// AllowedOrigins defines the origins that are allowed to access the resource
AllowedOrigins []string
// AllowedMethods defines the HTTP methods that are allowed
AllowedMethods []string
// AllowedHeaders defines the headers that are allowed in requests
AllowedHeaders []string
// ExposedHeaders defines the headers that can be read by the browser
ExposedHeaders []string
// AllowCredentials defines whether cookies, HTTP authentication and client SSL
// certificates can be transmitted in cross-origin requests
AllowCredentials bool
// MaxAge defines how long (in seconds) the results of a preflight request can be cached
MaxAge int
}
// DefaultCORSConfig returns a default configuration for CORS middleware
func DefaultCORSConfig() CORSConfig {
return CORSConfig{
AllowedOrigins: []string{},
AllowedMethods: []string{"GET"},
AllowedHeaders: []string{},
ExposedHeaders: []string{},
AllowCredentials: false,
MaxAge: 0,
}
}
// CORS middleware handles Cross-Origin Resource Sharing
func CORS(config CORSConfig) Middleware {
// Precompute values for better performance
allowedOriginsMap := make(map[string]bool)
for _, origin := range config.AllowedOrigins {
allowedOriginsMap[origin] = true
}
allowAllOrigins := allowedOriginsMap["*"]
allowedMethodsValue := strings.Join(config.AllowedMethods, ", ")
allowedHeadersValue := strings.Join(config.AllowedHeaders, ", ")
exposedHeadersValue := strings.Join(config.ExposedHeaders, ", ")
// Convert MaxAge to a string once
var maxAgeValue string
if config.MaxAge > 0 {
maxAgeValue = strconv.Itoa(config.MaxAge)
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
// Skip if no Origin header is present (same-origin request)
if origin == "" {
next.ServeHTTP(w, r)
return
}
// Check if the origin is allowed
var allowedOrigin string
if allowAllOrigins {
allowedOrigin = origin
} else if allowedOriginsMap[origin] {
allowedOrigin = origin
}
// Preflight OPTIONS request
if r.Method == http.MethodOptions {
// Set allowed methods
w.Header().Set("Access-Control-Allow-Methods", allowedMethodsValue)
// Set allowed headers
if len(allowedHeadersValue) > 0 {
w.Header().Set("Access-Control-Allow-Headers", allowedHeadersValue)
}
// Set max age for preflight request caching
if maxAgeValue != "" {
w.Header().Set("Access-Control-Max-Age", maxAgeValue)
}
}
// Set allowed origin
if allowedOrigin != "" {
w.Header().Set("Access-Control-Allow-Origin", allowedOrigin)
}
// Set exposed headers
if len(exposedHeadersValue) > 0 {
w.Header().Set("Access-Control-Expose-Headers", exposedHeadersValue)
}
// Set allow credentials
if config.AllowCredentials {
w.Header().Set("Access-Control-Allow-Credentials", "true")
}
// If it's a preflight request, we're done
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return
}
// Process the actual request
next.ServeHTTP(w, r)
})
}
}

View File

@ -19,7 +19,8 @@ func (w *wrappedWriter) WriteHeader(statusCode int) {
}
// Logging is a middleware function that logs requests with structured logging
func Logging(next http.Handler) http.Handler {
func Logging() Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
@ -63,3 +64,4 @@ func Logging(next http.Handler) http.Handler {
)
})
}
}

View File

@ -7,7 +7,8 @@ import (
)
// SecurityHeaders adds security and cache-control headers to all responses
func SecureHeaders(next http.Handler) http.Handler {
func SecureHeaders() Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Caching headers
w.Header().Set("Cache-Control", "no-store")
@ -51,6 +52,7 @@ func SecureHeaders(next http.Handler) http.Handler {
next.ServeHTTP(w, r)
})
}
}
// middleware/csrf.go
func CSRFMiddleware(store sessions.Store) Middleware {