Implement CSRF protection middleware with configurable secret and cookie options

This commit is contained in:
2025-04-29 01:42:54 -05:00
parent 8cf7841f20
commit b2ead348e5
2 changed files with 85 additions and 111 deletions

View File

@ -2,6 +2,7 @@ package cmd
import ( import (
"context" "context"
"crypto/rand"
"log/slog" "log/slog"
"net" "net"
"net/http" "net/http"
@ -52,6 +53,39 @@ var startCmd = &cobra.Command{
// Create CORS configuration // Create CORS configuration
corsConfig := middleware.DefaultCORSConfig() corsConfig := middleware.DefaultCORSConfig()
// Start with minimal default configuration
csrfConfig := middleware.DefaultCSRFConfig()
// Set CSRF secret from config or generate a random one
csrfSecret := viper.GetString("csrf-secret")
var csrfKey []byte
if csrfSecret != "" {
// Use configured secret - must be at least 32 bytes
csrfKey = []byte(csrfSecret)
if len(csrfKey) < 32 {
logger.Error("csrf-secret must be at least 32 bytes")
return
}
} else {
// Generate a random secret
csrfKey = make([]byte, 32)
_, err = rand.Read(csrfKey)
if err != nil {
logger.Error("failed to generate CSRF key", slog.Any("error", err))
return
}
logger.Info("generated random CSRF key, consider setting csrf-secret for stability across restarts")
}
csrfConfig.Secret = csrfKey
// Only override specific settings when needed
if env == "development" {
// In development, cookies often need to work without HTTPS
csrfConfig.Cookie.Secure = false
}
// Create middleware stack // Create middleware stack
stack := middleware.CreateStack( stack := middleware.CreateStack(
middleware.RequestID(), // Generate a unique request ID middleware.RequestID(), // Generate a unique request ID
@ -61,6 +95,7 @@ var startCmd = &cobra.Command{
middleware.MaxBodySize(1024*1024), // 1MB size limit middleware.MaxBodySize(1024*1024), // 1MB size limit
middleware.SecureHeaders(), // Set secure headers middleware.SecureHeaders(), // Set secure headers
middleware.CORS(corsConfig), // CORS configuration middleware.CORS(corsConfig), // CORS configuration
middleware.CSRF(csrfConfig), // CSRF protection
authConfig.Middleware(), // OIDC authentication middleware authConfig.Middleware(), // OIDC authentication middleware
) )
@ -99,6 +134,7 @@ func init() {
startCmd.Flags().String("issuer-url", "", "Identity Provider Issuer URL") startCmd.Flags().String("issuer-url", "", "Identity Provider Issuer URL")
startCmd.Flags().String("hostname", "", "Address at which the server is exposed") startCmd.Flags().String("hostname", "", "Address at which the server is exposed")
startCmd.Flags().String("session-secret", "", "Session encryption secret") startCmd.Flags().String("session-secret", "", "Session encryption secret")
startCmd.Flags().String("csrf-secret", "", "Secret key for CSRF protection (min 32 bytes)")
startCmd.Flags().String("env", "", "Environment (development/production)") startCmd.Flags().String("env", "", "Environment (development/production)")
// Bind all flags to Viper // Bind all flags to Viper

View File

@ -14,38 +14,25 @@ type CSRFConfig struct {
// Cookie defines cookie options // Cookie defines cookie options
Cookie struct { Cookie struct {
// Name of the CSRF cookie
Name string Name string
// Path where the cookie is valid
Path string
// Domain where the cookie is valid
Domain string Domain string
// MaxAge in seconds for the CSRF cookie
MaxAge int
// Secure sets the Secure flag on the cookie
Secure bool
// HttpOnly sets the HttpOnly flag on the cookie
HttpOnly bool HttpOnly bool
MaxAge int
// SameSite sets the SameSite attribute on the cookie Path string
SameSite csrf.SameSiteMode SameSite csrf.SameSiteMode
Secure bool
} }
// ErrorHandler to call when CSRF validation fails // ErrorHandler is a custom error handler for CSRF errors
ErrorHandler http.Handler ErrorHandler http.Handler
// FieldName is the name of the hidden input field used by frontend // FieldName is the name of the hidden form field containing the CSRF token
FieldName string FieldName string
// RequestHeader is the name of header used in AJAX requests // RequestHeader is the name of the request header containing the CSRF token
RequestHeader string RequestHeader string
// TrustedOrigins lists additional origins that are trusted // TrustedOrigins defines trusted origins for CSRF protection
TrustedOrigins []string TrustedOrigins []string
// Path defines URL paths where CSRF protection applies // Path defines URL paths where CSRF protection applies
@ -56,47 +43,58 @@ type CSRFConfig struct {
Ignore []func(r *http.Request) bool Ignore []func(r *http.Request) bool
} }
// DefaultCSRFConfig returns a default configuration for CSRF middleware // DefaultCSRFConfig returns a minimal configuration for CSRF middleware
// It only sets values that must be provided, allowing the CSRF package
// to use its own defaults for everything else.
func DefaultCSRFConfig() CSRFConfig { func DefaultCSRFConfig() CSRFConfig {
config := CSRFConfig{ return CSRFConfig{
Secret: nil, // Must be set by the application Secret: nil, // Must be set by the application
FieldName: "gorilla.csrf.Token",
RequestHeader: "X-CSRF-Token",
Path: "",
Ignore: []func(r *http.Request) bool{}, Ignore: []func(r *http.Request) bool{},
} }
config.Cookie.Name = "_csrf"
config.Cookie.Path = "/"
config.Cookie.MaxAge = 86400 // 24 hours
config.Cookie.Secure = true
config.Cookie.HttpOnly = true
config.Cookie.SameSite = csrf.SameSiteStrictMode
return config
} }
// CSRF middleware provides Cross-Site Request Forgery protection // CSRF middleware provides Cross-Site Request Forgery protection
func CSRF(config CSRFConfig) Middleware { func CSRF(config CSRFConfig) Middleware {
options := []csrf.Option{ // Only set options that are explicitly configured
csrf.CookieName(config.Cookie.Name), var options []csrf.Option
csrf.Path(config.Cookie.Path),
csrf.MaxAge(config.Cookie.MaxAge),
csrf.FieldName(config.FieldName),
csrf.RequestHeader(config.RequestHeader),
csrf.Secure(config.Cookie.Secure),
csrf.HttpOnly(config.Cookie.HttpOnly),
csrf.SameSite(config.Cookie.SameSite),
}
// Cookie options
if config.Cookie.Name != "" {
options = append(options, csrf.CookieName(config.Cookie.Name))
}
if config.Cookie.Path != "" {
options = append(options, csrf.Path(config.Cookie.Path))
}
if config.Cookie.MaxAge != 0 {
options = append(options, csrf.MaxAge(config.Cookie.MaxAge))
}
if config.Cookie.Domain != "" { if config.Cookie.Domain != "" {
options = append(options, csrf.Domain(config.Cookie.Domain)) options = append(options, csrf.Domain(config.Cookie.Domain))
} }
// Only check boolean fields if they've been explicitly set
// For cookie security settings, we only set them if they're being turned off
// since the secure defaults should be used otherwise
if !config.Cookie.Secure {
options = append(options, csrf.Secure(false))
}
if !config.Cookie.HttpOnly {
options = append(options, csrf.HttpOnly(false))
}
if config.Cookie.SameSite != 0 {
options = append(options, csrf.SameSite(config.Cookie.SameSite))
}
// Other options
if config.FieldName != "" {
options = append(options, csrf.FieldName(config.FieldName))
}
if config.RequestHeader != "" {
options = append(options, csrf.RequestHeader(config.RequestHeader))
}
if config.ErrorHandler != nil { if config.ErrorHandler != nil {
options = append(options, csrf.ErrorHandler(config.ErrorHandler)) options = append(options, csrf.ErrorHandler(config.ErrorHandler))
} }
if len(config.TrustedOrigins) > 0 { if len(config.TrustedOrigins) > 0 {
options = append(options, csrf.TrustedOrigins(config.TrustedOrigins)) options = append(options, csrf.TrustedOrigins(config.TrustedOrigins))
} }
@ -151,63 +149,3 @@ func CSRFToken(r *http.Request) string {
func CSRFTemplateField(r *http.Request) template.HTML { func CSRFTemplateField(r *http.Request) template.HTML {
return csrf.TemplateField(r) return csrf.TemplateField(r)
} }
/* Usage example:
1. Create a secure key for CSRF protection:
key := make([]byte, 32)
_, err := rand.Read(key)
if err != nil {
log.Fatal(err)
}
2. Create CSRF middleware with default configuration:
csrfConfig := middleware.DefaultCSRFConfig()
csrfConfig.Secret = key
csrfMiddleware := middleware.CSRF(csrfConfig)
3. Add the middleware to your routing (example with standard http):
handler := middleware.CreateStack(
middleware.RequestID(),
middleware.Logging(logger),
middleware.CORS(corsConfig),
middleware.CSRF(csrfConfig), // Add CSRF protection
)(router)
4. In your HTML templates, include the CSRF token in forms:
<form method="post" action="/protected">
{{ .CSRFField }}
<input type="text" name="username">
<button type="submit">Submit</button>
</form>
Where .CSRFField is provided to the template as:
data := map[string]interface{}{
"CSRFField": middleware.CSRFTemplateField(r),
}
5. For AJAX requests, include the CSRF token in the request header:
// JavaScript example
fetch('/api/protected', {
method: 'POST',
headers: {
'X-CSRF-Token': document.querySelector('meta[name="csrf-token"]').getAttribute('content'),
},
body: JSON.stringify(data)
})
// Include meta tag in your HTML:
// <meta name="csrf-token" content="{{ .CSRFToken }}">
Where .CSRFToken is provided to the template as:
data := map[string]interface{}{
"CSRFToken": middleware.CSRFToken(r),
}
*/