diff --git a/cmd/start.go b/cmd/start.go index 244fb59..bda0249 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -51,12 +51,12 @@ var startCmd = &cobra.Command{ // Create middleware stack stack := middleware.CreateStack( - middleware.Logging, // Log requests with structured logging + middleware.Logging(), // Log requests with structured logging middleware.Recovery(), // Catch all panics middleware.Timeout(32*time.Second), // Set request timeout middleware.RequestID(), // Generate a unique request ID middleware.MaxBodySize(1024*1024), // 1MB size limit - middleware.SecureHeaders, // Set secure headers + middleware.SecureHeaders(), // Set secure headers authConfig.Middleware(), // OIDC authentication middleware ) diff --git a/internal/middleware/cors.go b/internal/middleware/cors.go new file mode 100644 index 0000000..33c4228 --- /dev/null +++ b/internal/middleware/cors.go @@ -0,0 +1,121 @@ +package middleware + +import ( + "net/http" + "strconv" + "strings" +) + +// CORSConfig defines the configuration options for CORS middleware +type CORSConfig struct { + // AllowedOrigins defines the origins that are allowed to access the resource + AllowedOrigins []string + + // AllowedMethods defines the HTTP methods that are allowed + AllowedMethods []string + + // AllowedHeaders defines the headers that are allowed in requests + AllowedHeaders []string + + // ExposedHeaders defines the headers that can be read by the browser + ExposedHeaders []string + + // AllowCredentials defines whether cookies, HTTP authentication and client SSL + // certificates can be transmitted in cross-origin requests + AllowCredentials bool + + // MaxAge defines how long (in seconds) the results of a preflight request can be cached + MaxAge int +} + +// DefaultCORSConfig returns a default configuration for CORS middleware +func DefaultCORSConfig() CORSConfig { + return CORSConfig{ + AllowedOrigins: []string{}, + AllowedMethods: []string{"GET"}, + AllowedHeaders: []string{}, + ExposedHeaders: []string{}, + AllowCredentials: false, + MaxAge: 0, + } +} + +// CORS middleware handles Cross-Origin Resource Sharing +func CORS(config CORSConfig) Middleware { + // Precompute values for better performance + allowedOriginsMap := make(map[string]bool) + for _, origin := range config.AllowedOrigins { + allowedOriginsMap[origin] = true + } + + allowAllOrigins := allowedOriginsMap["*"] + allowedMethodsValue := strings.Join(config.AllowedMethods, ", ") + allowedHeadersValue := strings.Join(config.AllowedHeaders, ", ") + exposedHeadersValue := strings.Join(config.ExposedHeaders, ", ") + + // Convert MaxAge to a string once + var maxAgeValue string + if config.MaxAge > 0 { + maxAgeValue = strconv.Itoa(config.MaxAge) + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + + // Skip if no Origin header is present (same-origin request) + if origin == "" { + next.ServeHTTP(w, r) + return + } + + // Check if the origin is allowed + var allowedOrigin string + if allowAllOrigins { + allowedOrigin = origin + } else if allowedOriginsMap[origin] { + allowedOrigin = origin + } + + // Preflight OPTIONS request + if r.Method == http.MethodOptions { + // Set allowed methods + w.Header().Set("Access-Control-Allow-Methods", allowedMethodsValue) + + // Set allowed headers + if len(allowedHeadersValue) > 0 { + w.Header().Set("Access-Control-Allow-Headers", allowedHeadersValue) + } + + // Set max age for preflight request caching + if maxAgeValue != "" { + w.Header().Set("Access-Control-Max-Age", maxAgeValue) + } + } + + // Set allowed origin + if allowedOrigin != "" { + w.Header().Set("Access-Control-Allow-Origin", allowedOrigin) + } + + // Set exposed headers + if len(exposedHeadersValue) > 0 { + w.Header().Set("Access-Control-Expose-Headers", exposedHeadersValue) + } + + // Set allow credentials + if config.AllowCredentials { + w.Header().Set("Access-Control-Allow-Credentials", "true") + } + + // If it's a preflight request, we're done + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + + // Process the actual request + next.ServeHTTP(w, r) + }) + } +} diff --git a/internal/middleware/logging.go b/internal/middleware/logging.go index e7870ce..574cc45 100644 --- a/internal/middleware/logging.go +++ b/internal/middleware/logging.go @@ -19,47 +19,49 @@ func (w *wrappedWriter) WriteHeader(statusCode int) { } // Logging is a middleware function that logs requests with structured logging -func Logging(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - start := time.Now() +func Logging() Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() - // Prepare wrapped writer to capture status code - wrapped := &wrappedWriter{ - ResponseWriter: w, - statusCode: http.StatusOK, - } + // Prepare wrapped writer to capture status code + wrapped := &wrappedWriter{ + ResponseWriter: w, + statusCode: http.StatusOK, + } - // Get request ID from context - requestID := GetRequestID(r.Context()) + // Get request ID from context + requestID := GetRequestID(r.Context()) - // Get logger from the application and add request information - logger := logging.FromContext(r.Context()) - reqLogger := logging.WithValues(logger, - slog.String("request_id", requestID), - slog.String("method", r.Method), - slog.String("path", r.URL.Path), - slog.String("remote_ip", r.RemoteAddr), - slog.String("user_agent", r.UserAgent()), - ) + // Get logger from the application and add request information + logger := logging.FromContext(r.Context()) + reqLogger := logging.WithValues(logger, + slog.String("request_id", requestID), + slog.String("method", r.Method), + slog.String("path", r.URL.Path), + slog.String("remote_ip", r.RemoteAddr), + slog.String("user_agent", r.UserAgent()), + ) - // Store the request-specific logger in context - ctx := logging.WithContext(r.Context(), reqLogger) - r = r.WithContext(ctx) + // Store the request-specific logger in context + ctx := logging.WithContext(r.Context(), reqLogger) + r = r.WithContext(ctx) - // Log request start if in debug mode - reqLogger.Debug("request started") + // Log request start if in debug mode + reqLogger.Debug("request started") - // Process the request with updated context - next.ServeHTTP(wrapped, r) + // Process the request with updated context + next.ServeHTTP(wrapped, r) - // Calculate duration - duration := time.Since(start) + // Calculate duration + duration := time.Since(start) - // Log request completion with status and duration - reqLogger.Info("request completed", - slog.Int("status", wrapped.statusCode), - slog.Duration("duration", duration), - slog.String("duration_human", duration.String()), - ) - }) + // Log request completion with status and duration + reqLogger.Info("request completed", + slog.Int("status", wrapped.statusCode), + slog.Duration("duration", duration), + slog.String("duration_human", duration.String()), + ) + }) + } } diff --git a/internal/middleware/security.go b/internal/middleware/security.go index a390100..63f975f 100644 --- a/internal/middleware/security.go +++ b/internal/middleware/security.go @@ -7,49 +7,51 @@ import ( ) // SecurityHeaders adds security and cache-control headers to all responses -func SecureHeaders(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Caching headers - w.Header().Set("Cache-Control", "no-store") +func SecureHeaders() Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Caching headers + w.Header().Set("Cache-Control", "no-store") - // XSSProtection provides protection against cross-site scripting attack (XSS) - w.Header().Set("X-XSS-Protection", "1; mode=block") + // XSSProtection provides protection against cross-site scripting attack (XSS) + w.Header().Set("X-XSS-Protection", "1; mode=block") - // ContentTypeNosniff provides protection against overriding Content-Type - w.Header().Set("X-Content-Type-Options", "nosniff") + // ContentTypeNosniff provides protection against overriding Content-Type + w.Header().Set("X-Content-Type-Options", "nosniff") - // XFrameOptions prevents the page from being displayed in a frame - w.Header().Set("X-Frame-Options", "DENY") + // XFrameOptions prevents the page from being displayed in a frame + w.Header().Set("X-Frame-Options", "DENY") - // HSTS (HTTP Strict Transport Security) forces the browser to use HTTPS - w.Header().Set("Strict-Transport-Security", "max-age=3600; includeSubDomains") + // HSTS (HTTP Strict Transport Security) forces the browser to use HTTPS + w.Header().Set("Strict-Transport-Security", "max-age=3600; includeSubDomains") - // ReferrerPolicy sets the referrer information passed during navigation - w.Header().Set("Referrer-Policy", "no-referrer") + // ReferrerPolicy sets the referrer information passed during navigation + w.Header().Set("Referrer-Policy", "no-referrer") - // CSP controls the resources the user agent is allowed to load for a page - w.Header().Set("Content-Security-Policy", - "default-src 'self'; "+ - // Allow HTMX to load from unpkg.com - "script-src 'self' https://unpkg.com/htmx.org@*; "+ - "style-src 'self'; "+ - "img-src 'self' data:; "+ - "connect-src 'self'; "+ - "frame-ancestors 'none'; "+ - "form-action 'self'; "+ - "base-uri 'self';") + // CSP controls the resources the user agent is allowed to load for a page + w.Header().Set("Content-Security-Policy", + "default-src 'self'; "+ + // Allow HTMX to load from unpkg.com + "script-src 'self' https://unpkg.com/htmx.org@*; "+ + "style-src 'self'; "+ + "img-src 'self' data:; "+ + "connect-src 'self'; "+ + "frame-ancestors 'none'; "+ + "form-action 'self'; "+ + "base-uri 'self';") - // Cross-Origin-Embedder-Policy prevents cross-origin resources from being loaded - w.Header().Set("Cross-Origin-Embedder-Policy", "require-corp") + // Cross-Origin-Embedder-Policy prevents cross-origin resources from being loaded + w.Header().Set("Cross-Origin-Embedder-Policy", "require-corp") - // Cross-Origin-Opener-Policy prevents cross-origin documents from being loaded - w.Header().Set("Cross-Origin-Opener-Policy", "same-origin") + // Cross-Origin-Opener-Policy prevents cross-origin documents from being loaded + w.Header().Set("Cross-Origin-Opener-Policy", "same-origin") - // Cross-Origin-Resource-Policy prevents cross-origin resources from being loaded - w.Header().Set("Cross-Origin-Resource-Policy", "same-origin") + // Cross-Origin-Resource-Policy prevents cross-origin resources from being loaded + w.Header().Set("Cross-Origin-Resource-Policy", "same-origin") - next.ServeHTTP(w, r) - }) + next.ServeHTTP(w, r) + }) + } } // middleware/csrf.go