Implement CSRF protection middleware with configurable secret and cookie options
This commit is contained in:
36
cmd/start.go
36
cmd/start.go
@ -2,6 +2,7 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -52,6 +53,39 @@ var startCmd = &cobra.Command{
|
|||||||
// Create CORS configuration
|
// Create CORS configuration
|
||||||
corsConfig := middleware.DefaultCORSConfig()
|
corsConfig := middleware.DefaultCORSConfig()
|
||||||
|
|
||||||
|
// Start with minimal default configuration
|
||||||
|
csrfConfig := middleware.DefaultCSRFConfig()
|
||||||
|
|
||||||
|
// Set CSRF secret from config or generate a random one
|
||||||
|
csrfSecret := viper.GetString("csrf-secret")
|
||||||
|
var csrfKey []byte
|
||||||
|
|
||||||
|
if csrfSecret != "" {
|
||||||
|
// Use configured secret - must be at least 32 bytes
|
||||||
|
csrfKey = []byte(csrfSecret)
|
||||||
|
if len(csrfKey) < 32 {
|
||||||
|
logger.Error("csrf-secret must be at least 32 bytes")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Generate a random secret
|
||||||
|
csrfKey = make([]byte, 32)
|
||||||
|
_, err = rand.Read(csrfKey)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("failed to generate CSRF key", slog.Any("error", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Info("generated random CSRF key, consider setting csrf-secret for stability across restarts")
|
||||||
|
}
|
||||||
|
|
||||||
|
csrfConfig.Secret = csrfKey
|
||||||
|
|
||||||
|
// Only override specific settings when needed
|
||||||
|
if env == "development" {
|
||||||
|
// In development, cookies often need to work without HTTPS
|
||||||
|
csrfConfig.Cookie.Secure = false
|
||||||
|
}
|
||||||
|
|
||||||
// Create middleware stack
|
// Create middleware stack
|
||||||
stack := middleware.CreateStack(
|
stack := middleware.CreateStack(
|
||||||
middleware.RequestID(), // Generate a unique request ID
|
middleware.RequestID(), // Generate a unique request ID
|
||||||
@ -61,6 +95,7 @@ var startCmd = &cobra.Command{
|
|||||||
middleware.MaxBodySize(1024*1024), // 1MB size limit
|
middleware.MaxBodySize(1024*1024), // 1MB size limit
|
||||||
middleware.SecureHeaders(), // Set secure headers
|
middleware.SecureHeaders(), // Set secure headers
|
||||||
middleware.CORS(corsConfig), // CORS configuration
|
middleware.CORS(corsConfig), // CORS configuration
|
||||||
|
middleware.CSRF(csrfConfig), // CSRF protection
|
||||||
authConfig.Middleware(), // OIDC authentication middleware
|
authConfig.Middleware(), // OIDC authentication middleware
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -99,6 +134,7 @@ func init() {
|
|||||||
startCmd.Flags().String("issuer-url", "", "Identity Provider Issuer URL")
|
startCmd.Flags().String("issuer-url", "", "Identity Provider Issuer URL")
|
||||||
startCmd.Flags().String("hostname", "", "Address at which the server is exposed")
|
startCmd.Flags().String("hostname", "", "Address at which the server is exposed")
|
||||||
startCmd.Flags().String("session-secret", "", "Session encryption secret")
|
startCmd.Flags().String("session-secret", "", "Session encryption secret")
|
||||||
|
startCmd.Flags().String("csrf-secret", "", "Secret key for CSRF protection (min 32 bytes)")
|
||||||
startCmd.Flags().String("env", "", "Environment (development/production)")
|
startCmd.Flags().String("env", "", "Environment (development/production)")
|
||||||
|
|
||||||
// Bind all flags to Viper
|
// Bind all flags to Viper
|
||||||
|
@ -14,38 +14,25 @@ type CSRFConfig struct {
|
|||||||
|
|
||||||
// Cookie defines cookie options
|
// Cookie defines cookie options
|
||||||
Cookie struct {
|
Cookie struct {
|
||||||
// Name of the CSRF cookie
|
|
||||||
Name string
|
Name string
|
||||||
|
|
||||||
// Path where the cookie is valid
|
|
||||||
Path string
|
|
||||||
|
|
||||||
// Domain where the cookie is valid
|
|
||||||
Domain string
|
Domain string
|
||||||
|
|
||||||
// MaxAge in seconds for the CSRF cookie
|
|
||||||
MaxAge int
|
|
||||||
|
|
||||||
// Secure sets the Secure flag on the cookie
|
|
||||||
Secure bool
|
|
||||||
|
|
||||||
// HttpOnly sets the HttpOnly flag on the cookie
|
|
||||||
HttpOnly bool
|
HttpOnly bool
|
||||||
|
MaxAge int
|
||||||
// SameSite sets the SameSite attribute on the cookie
|
Path string
|
||||||
SameSite csrf.SameSiteMode
|
SameSite csrf.SameSiteMode
|
||||||
|
Secure bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrorHandler to call when CSRF validation fails
|
// ErrorHandler is a custom error handler for CSRF errors
|
||||||
ErrorHandler http.Handler
|
ErrorHandler http.Handler
|
||||||
|
|
||||||
// FieldName is the name of the hidden input field used by frontend
|
// FieldName is the name of the hidden form field containing the CSRF token
|
||||||
FieldName string
|
FieldName string
|
||||||
|
|
||||||
// RequestHeader is the name of header used in AJAX requests
|
// RequestHeader is the name of the request header containing the CSRF token
|
||||||
RequestHeader string
|
RequestHeader string
|
||||||
|
|
||||||
// TrustedOrigins lists additional origins that are trusted
|
// TrustedOrigins defines trusted origins for CSRF protection
|
||||||
TrustedOrigins []string
|
TrustedOrigins []string
|
||||||
|
|
||||||
// Path defines URL paths where CSRF protection applies
|
// Path defines URL paths where CSRF protection applies
|
||||||
@ -56,47 +43,58 @@ type CSRFConfig struct {
|
|||||||
Ignore []func(r *http.Request) bool
|
Ignore []func(r *http.Request) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultCSRFConfig returns a default configuration for CSRF middleware
|
// DefaultCSRFConfig returns a minimal configuration for CSRF middleware
|
||||||
|
// It only sets values that must be provided, allowing the CSRF package
|
||||||
|
// to use its own defaults for everything else.
|
||||||
func DefaultCSRFConfig() CSRFConfig {
|
func DefaultCSRFConfig() CSRFConfig {
|
||||||
config := CSRFConfig{
|
return CSRFConfig{
|
||||||
Secret: nil, // Must be set by the application
|
Secret: nil, // Must be set by the application
|
||||||
FieldName: "gorilla.csrf.Token",
|
|
||||||
RequestHeader: "X-CSRF-Token",
|
|
||||||
Path: "",
|
|
||||||
Ignore: []func(r *http.Request) bool{},
|
Ignore: []func(r *http.Request) bool{},
|
||||||
}
|
}
|
||||||
|
|
||||||
config.Cookie.Name = "_csrf"
|
|
||||||
config.Cookie.Path = "/"
|
|
||||||
config.Cookie.MaxAge = 86400 // 24 hours
|
|
||||||
config.Cookie.Secure = true
|
|
||||||
config.Cookie.HttpOnly = true
|
|
||||||
config.Cookie.SameSite = csrf.SameSiteStrictMode
|
|
||||||
|
|
||||||
return config
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CSRF middleware provides Cross-Site Request Forgery protection
|
// CSRF middleware provides Cross-Site Request Forgery protection
|
||||||
func CSRF(config CSRFConfig) Middleware {
|
func CSRF(config CSRFConfig) Middleware {
|
||||||
options := []csrf.Option{
|
// Only set options that are explicitly configured
|
||||||
csrf.CookieName(config.Cookie.Name),
|
var options []csrf.Option
|
||||||
csrf.Path(config.Cookie.Path),
|
|
||||||
csrf.MaxAge(config.Cookie.MaxAge),
|
|
||||||
csrf.FieldName(config.FieldName),
|
|
||||||
csrf.RequestHeader(config.RequestHeader),
|
|
||||||
csrf.Secure(config.Cookie.Secure),
|
|
||||||
csrf.HttpOnly(config.Cookie.HttpOnly),
|
|
||||||
csrf.SameSite(config.Cookie.SameSite),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// Cookie options
|
||||||
|
if config.Cookie.Name != "" {
|
||||||
|
options = append(options, csrf.CookieName(config.Cookie.Name))
|
||||||
|
}
|
||||||
|
if config.Cookie.Path != "" {
|
||||||
|
options = append(options, csrf.Path(config.Cookie.Path))
|
||||||
|
}
|
||||||
|
if config.Cookie.MaxAge != 0 {
|
||||||
|
options = append(options, csrf.MaxAge(config.Cookie.MaxAge))
|
||||||
|
}
|
||||||
if config.Cookie.Domain != "" {
|
if config.Cookie.Domain != "" {
|
||||||
options = append(options, csrf.Domain(config.Cookie.Domain))
|
options = append(options, csrf.Domain(config.Cookie.Domain))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Only check boolean fields if they've been explicitly set
|
||||||
|
// For cookie security settings, we only set them if they're being turned off
|
||||||
|
// since the secure defaults should be used otherwise
|
||||||
|
if !config.Cookie.Secure {
|
||||||
|
options = append(options, csrf.Secure(false))
|
||||||
|
}
|
||||||
|
if !config.Cookie.HttpOnly {
|
||||||
|
options = append(options, csrf.HttpOnly(false))
|
||||||
|
}
|
||||||
|
if config.Cookie.SameSite != 0 {
|
||||||
|
options = append(options, csrf.SameSite(config.Cookie.SameSite))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Other options
|
||||||
|
if config.FieldName != "" {
|
||||||
|
options = append(options, csrf.FieldName(config.FieldName))
|
||||||
|
}
|
||||||
|
if config.RequestHeader != "" {
|
||||||
|
options = append(options, csrf.RequestHeader(config.RequestHeader))
|
||||||
|
}
|
||||||
if config.ErrorHandler != nil {
|
if config.ErrorHandler != nil {
|
||||||
options = append(options, csrf.ErrorHandler(config.ErrorHandler))
|
options = append(options, csrf.ErrorHandler(config.ErrorHandler))
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(config.TrustedOrigins) > 0 {
|
if len(config.TrustedOrigins) > 0 {
|
||||||
options = append(options, csrf.TrustedOrigins(config.TrustedOrigins))
|
options = append(options, csrf.TrustedOrigins(config.TrustedOrigins))
|
||||||
}
|
}
|
||||||
@ -151,63 +149,3 @@ func CSRFToken(r *http.Request) string {
|
|||||||
func CSRFTemplateField(r *http.Request) template.HTML {
|
func CSRFTemplateField(r *http.Request) template.HTML {
|
||||||
return csrf.TemplateField(r)
|
return csrf.TemplateField(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Usage example:
|
|
||||||
|
|
||||||
1. Create a secure key for CSRF protection:
|
|
||||||
|
|
||||||
key := make([]byte, 32)
|
|
||||||
_, err := rand.Read(key)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
2. Create CSRF middleware with default configuration:
|
|
||||||
|
|
||||||
csrfConfig := middleware.DefaultCSRFConfig()
|
|
||||||
csrfConfig.Secret = key
|
|
||||||
csrfMiddleware := middleware.CSRF(csrfConfig)
|
|
||||||
|
|
||||||
3. Add the middleware to your routing (example with standard http):
|
|
||||||
|
|
||||||
handler := middleware.CreateStack(
|
|
||||||
middleware.RequestID(),
|
|
||||||
middleware.Logging(logger),
|
|
||||||
middleware.CORS(corsConfig),
|
|
||||||
middleware.CSRF(csrfConfig), // Add CSRF protection
|
|
||||||
)(router)
|
|
||||||
|
|
||||||
4. In your HTML templates, include the CSRF token in forms:
|
|
||||||
|
|
||||||
<form method="post" action="/protected">
|
|
||||||
{{ .CSRFField }}
|
|
||||||
<input type="text" name="username">
|
|
||||||
<button type="submit">Submit</button>
|
|
||||||
</form>
|
|
||||||
|
|
||||||
Where .CSRFField is provided to the template as:
|
|
||||||
|
|
||||||
data := map[string]interface{}{
|
|
||||||
"CSRFField": middleware.CSRFTemplateField(r),
|
|
||||||
}
|
|
||||||
|
|
||||||
5. For AJAX requests, include the CSRF token in the request header:
|
|
||||||
|
|
||||||
// JavaScript example
|
|
||||||
fetch('/api/protected', {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
'X-CSRF-Token': document.querySelector('meta[name="csrf-token"]').getAttribute('content'),
|
|
||||||
},
|
|
||||||
body: JSON.stringify(data)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Include meta tag in your HTML:
|
|
||||||
// <meta name="csrf-token" content="{{ .CSRFToken }}">
|
|
||||||
|
|
||||||
Where .CSRFToken is provided to the template as:
|
|
||||||
|
|
||||||
data := map[string]interface{}{
|
|
||||||
"CSRFToken": middleware.CSRFToken(r),
|
|
||||||
}
|
|
||||||
*/
|
|
Reference in New Issue
Block a user