diff --git a/cmd/start.go b/cmd/start.go index a4f4934..bc84f45 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -2,6 +2,7 @@ package cmd import ( "context" + "crypto/rand" "log/slog" "net" "net/http" @@ -51,6 +52,39 @@ var startCmd = &cobra.Command{ // Create CORS configuration corsConfig := middleware.DefaultCORSConfig() + + // Start with minimal default configuration + csrfConfig := middleware.DefaultCSRFConfig() + + // Set CSRF secret from config or generate a random one + csrfSecret := viper.GetString("csrf-secret") + var csrfKey []byte + + if csrfSecret != "" { + // Use configured secret - must be at least 32 bytes + csrfKey = []byte(csrfSecret) + if len(csrfKey) < 32 { + logger.Error("csrf-secret must be at least 32 bytes") + return + } + } else { + // Generate a random secret + csrfKey = make([]byte, 32) + _, err = rand.Read(csrfKey) + if err != nil { + logger.Error("failed to generate CSRF key", slog.Any("error", err)) + return + } + logger.Info("generated random CSRF key, consider setting csrf-secret for stability across restarts") + } + + csrfConfig.Secret = csrfKey + + // Only override specific settings when needed + if env == "development" { + // In development, cookies often need to work without HTTPS + csrfConfig.Cookie.Secure = false + } // Create middleware stack stack := middleware.CreateStack( @@ -61,6 +95,7 @@ var startCmd = &cobra.Command{ middleware.MaxBodySize(1024*1024), // 1MB size limit middleware.SecureHeaders(), // Set secure headers middleware.CORS(corsConfig), // CORS configuration + middleware.CSRF(csrfConfig), // CSRF protection authConfig.Middleware(), // OIDC authentication middleware ) @@ -99,6 +134,7 @@ func init() { startCmd.Flags().String("issuer-url", "", "Identity Provider Issuer URL") startCmd.Flags().String("hostname", "", "Address at which the server is exposed") startCmd.Flags().String("session-secret", "", "Session encryption secret") + startCmd.Flags().String("csrf-secret", "", "Secret key for CSRF protection (min 32 bytes)") startCmd.Flags().String("env", "", "Environment (development/production)") // Bind all flags to Viper @@ -110,4 +146,4 @@ func init() { // Add the command to the root command rootCmd.AddCommand(startCmd) -} +} \ No newline at end of file diff --git a/internal/middleware/csrf.go b/internal/middleware/csrf.go index 5280e37..fa91cce 100644 --- a/internal/middleware/csrf.go +++ b/internal/middleware/csrf.go @@ -14,38 +14,25 @@ type CSRFConfig struct { // Cookie defines cookie options Cookie struct { - // Name of the CSRF cookie - Name string - - // Path where the cookie is valid - Path string - - // Domain where the cookie is valid - Domain string - - // MaxAge in seconds for the CSRF cookie - MaxAge int - - // Secure sets the Secure flag on the cookie - Secure bool - - // HttpOnly sets the HttpOnly flag on the cookie + Name string + Domain string HttpOnly bool - - // SameSite sets the SameSite attribute on the cookie + MaxAge int + Path string SameSite csrf.SameSiteMode + Secure bool } - // ErrorHandler to call when CSRF validation fails + // ErrorHandler is a custom error handler for CSRF errors ErrorHandler http.Handler - // FieldName is the name of the hidden input field used by frontend + // FieldName is the name of the hidden form field containing the CSRF token FieldName string - // RequestHeader is the name of header used in AJAX requests + // RequestHeader is the name of the request header containing the CSRF token RequestHeader string - // TrustedOrigins lists additional origins that are trusted + // TrustedOrigins defines trusted origins for CSRF protection TrustedOrigins []string // Path defines URL paths where CSRF protection applies @@ -56,47 +43,58 @@ type CSRFConfig struct { Ignore []func(r *http.Request) bool } -// DefaultCSRFConfig returns a default configuration for CSRF middleware +// DefaultCSRFConfig returns a minimal configuration for CSRF middleware +// It only sets values that must be provided, allowing the CSRF package +// to use its own defaults for everything else. func DefaultCSRFConfig() CSRFConfig { - config := CSRFConfig{ - Secret: nil, // Must be set by the application - FieldName: "gorilla.csrf.Token", - RequestHeader: "X-CSRF-Token", - Path: "", - Ignore: []func(r *http.Request) bool{}, + return CSRFConfig{ + Secret: nil, // Must be set by the application + Ignore: []func(r *http.Request) bool{}, } - - config.Cookie.Name = "_csrf" - config.Cookie.Path = "/" - config.Cookie.MaxAge = 86400 // 24 hours - config.Cookie.Secure = true - config.Cookie.HttpOnly = true - config.Cookie.SameSite = csrf.SameSiteStrictMode - - return config } // CSRF middleware provides Cross-Site Request Forgery protection func CSRF(config CSRFConfig) Middleware { - options := []csrf.Option{ - csrf.CookieName(config.Cookie.Name), - csrf.Path(config.Cookie.Path), - csrf.MaxAge(config.Cookie.MaxAge), - csrf.FieldName(config.FieldName), - csrf.RequestHeader(config.RequestHeader), - csrf.Secure(config.Cookie.Secure), - csrf.HttpOnly(config.Cookie.HttpOnly), - csrf.SameSite(config.Cookie.SameSite), - } + // Only set options that are explicitly configured + var options []csrf.Option + // Cookie options + if config.Cookie.Name != "" { + options = append(options, csrf.CookieName(config.Cookie.Name)) + } + if config.Cookie.Path != "" { + options = append(options, csrf.Path(config.Cookie.Path)) + } + if config.Cookie.MaxAge != 0 { + options = append(options, csrf.MaxAge(config.Cookie.MaxAge)) + } if config.Cookie.Domain != "" { options = append(options, csrf.Domain(config.Cookie.Domain)) } + + // Only check boolean fields if they've been explicitly set + // For cookie security settings, we only set them if they're being turned off + // since the secure defaults should be used otherwise + if !config.Cookie.Secure { + options = append(options, csrf.Secure(false)) + } + if !config.Cookie.HttpOnly { + options = append(options, csrf.HttpOnly(false)) + } + if config.Cookie.SameSite != 0 { + options = append(options, csrf.SameSite(config.Cookie.SameSite)) + } + // Other options + if config.FieldName != "" { + options = append(options, csrf.FieldName(config.FieldName)) + } + if config.RequestHeader != "" { + options = append(options, csrf.RequestHeader(config.RequestHeader)) + } if config.ErrorHandler != nil { options = append(options, csrf.ErrorHandler(config.ErrorHandler)) } - if len(config.TrustedOrigins) > 0 { options = append(options, csrf.TrustedOrigins(config.TrustedOrigins)) } @@ -108,7 +106,7 @@ func CSRF(config CSRFConfig) Middleware { // Handle protection path if config.Path != "" { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == config.Path || (len(r.URL.Path) >= len(config.Path) && + if r.URL.Path == config.Path || (len(r.URL.Path) >= len(config.Path) && r.URL.Path[:len(config.Path)] == config.Path) { // Check if the request should be ignored for _, ignoreFunc := range config.Ignore { @@ -151,63 +149,3 @@ func CSRFToken(r *http.Request) string { func CSRFTemplateField(r *http.Request) template.HTML { return csrf.TemplateField(r) } - -/* Usage example: - -1. Create a secure key for CSRF protection: - - key := make([]byte, 32) - _, err := rand.Read(key) - if err != nil { - log.Fatal(err) - } - -2. Create CSRF middleware with default configuration: - - csrfConfig := middleware.DefaultCSRFConfig() - csrfConfig.Secret = key - csrfMiddleware := middleware.CSRF(csrfConfig) - -3. Add the middleware to your routing (example with standard http): - - handler := middleware.CreateStack( - middleware.RequestID(), - middleware.Logging(logger), - middleware.CORS(corsConfig), - middleware.CSRF(csrfConfig), // Add CSRF protection - )(router) - -4. In your HTML templates, include the CSRF token in forms: - -
- - Where .CSRFField is provided to the template as: - - data := map[string]interface{}{ - "CSRFField": middleware.CSRFTemplateField(r), - } - -5. For AJAX requests, include the CSRF token in the request header: - - // JavaScript example - fetch('/api/protected', { - method: 'POST', - headers: { - 'X-CSRF-Token': document.querySelector('meta[name="csrf-token"]').getAttribute('content'), - }, - body: JSON.stringify(data) - }) - - // Include meta tag in your HTML: - // - - Where .CSRFToken is provided to the template as: - - data := map[string]interface{}{ - "CSRFToken": middleware.CSRFToken(r), - } -*/ \ No newline at end of file