diff --git a/cmd/start.go b/cmd/start.go index 49676a8..791601e 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -2,7 +2,6 @@ package cmd import ( "context" - "crypto/rand" "log/slog" "net" "net/http" @@ -60,26 +59,14 @@ var startCmd = &cobra.Command{ // Create empty CSRF configuration with default values var csrfConfig middleware.CSRFConfig - // Set CSRF secret from config or generate a random one + // Get and validate CSRF secret from config csrfSecret := viper.GetString("csrf-secret") - var csrfKey []byte - - if csrfSecret != "" { - // Use configured secret - must be at least 32 bytes - csrfKey = []byte(csrfSecret) - if len(csrfKey) < 32 { - logger.Error("csrf-secret must be at least 32 bytes") - return - } - } else { - // Generate a random secret - csrfKey = make([]byte, 32) - _, err = rand.Read(csrfKey) - if err != nil { - logger.Error("failed to generate CSRF key", slog.Any("error", err)) - return - } - logger.Info("generated random CSRF key, consider setting csrf-secret for stability across restarts") + csrfKey, err := middleware.ParseCSRFKey(csrfSecret) + if err != nil { + logger.Error("invalid csrf-secret", + slog.String("error", err.Error()), + slog.String("hint", "must be exactly 32 bytes and persist across restarts")) + return } csrfConfig.Secret = csrfKey @@ -138,7 +125,7 @@ func init() { startCmd.Flags().String("issuer-url", "", "Identity Provider Issuer URL") startCmd.Flags().String("hostname", "", "Address at which the server is exposed") startCmd.Flags().String("session-secret", "", "Session encryption secret") - startCmd.Flags().String("csrf-secret", "", "Secret key for CSRF protection (min 32 bytes)") + startCmd.Flags().String("csrf-secret", "", "Secret key for CSRF protection (must be exactly 32 bytes)") startCmd.Flags().String("env", "", "Environment (development/production)") // Bind all flags to Viper diff --git a/internal/middleware/csrf.go b/internal/middleware/csrf.go index 55050aa..4523252 100644 --- a/internal/middleware/csrf.go +++ b/internal/middleware/csrf.go @@ -1,6 +1,7 @@ package middleware import ( + "fmt" "html/template" "net/http" @@ -137,3 +138,18 @@ func CSRFToken(r *http.Request) string { func CSRFTemplateField(r *http.Request) template.HTML { return csrf.TemplateField(r) } + +// ParseCSRFKey validates and converts a CSRF secret string to the required 32-byte key +// It returns the key as a byte slice and an error if the key is invalid +func ParseCSRFKey(secret string) ([]byte, error) { + if secret == "" { + return nil, fmt.Errorf("csrf secret is required and must be exactly 32 bytes") + } + + key := []byte(secret) + if len(key) != 32 { + return nil, fmt.Errorf("csrf secret must be exactly 32 bytes (got %d bytes)", len(key)) + } + + return key, nil +}