package middleware import ( "fmt" "html/template" "net/http" "github.com/gorilla/csrf" ) // CSRFConfig defines the configuration options for CSRF middleware type CSRFConfig struct { // Secret is the 32-byte secret key used to generate tokens Secret []byte // Cookie defines cookie options Cookie struct { Name string Domain string HttpOnly bool MaxAge int Path string SameSite csrf.SameSiteMode Secure bool } // ErrorHandler is a custom error handler for CSRF errors ErrorHandler http.Handler // FieldName is the name of the hidden form field containing the CSRF token FieldName string // RequestHeader is the name of the request header containing the CSRF token RequestHeader string // TrustedOrigins defines trusted origins for CSRF protection TrustedOrigins []string // Path defines URL paths where CSRF protection applies // If empty, all paths are protected Path string // Ignore functions determine if a request should skip CSRF protection Ignore []func(r *http.Request) bool } // CSRF middleware provides Cross-Site Request Forgery protection func CSRF(config CSRFConfig) Middleware { // Only set options that are explicitly configured var options []csrf.Option // Cookie options if config.Cookie.Name != "" { options = append(options, csrf.CookieName(config.Cookie.Name)) } if config.Cookie.Path != "" { options = append(options, csrf.Path(config.Cookie.Path)) } if config.Cookie.MaxAge != 0 { options = append(options, csrf.MaxAge(config.Cookie.MaxAge)) } if config.Cookie.Domain != "" { options = append(options, csrf.Domain(config.Cookie.Domain)) } if !config.Cookie.Secure { options = append(options, csrf.Secure(false)) } if !config.Cookie.HttpOnly { options = append(options, csrf.HttpOnly(false)) } if config.Cookie.SameSite != 0 { options = append(options, csrf.SameSite(config.Cookie.SameSite)) } // Other options if config.FieldName != "" { options = append(options, csrf.FieldName(config.FieldName)) } if config.RequestHeader != "" { options = append(options, csrf.RequestHeader(config.RequestHeader)) } if config.ErrorHandler != nil { options = append(options, csrf.ErrorHandler(config.ErrorHandler)) } if len(config.TrustedOrigins) > 0 { options = append(options, csrf.TrustedOrigins(config.TrustedOrigins)) } // Create CSRF protection middleware csrfHandler := csrf.Protect(config.Secret, options...) return func(next http.Handler) http.Handler { // Handle protection path if config.Path != "" { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == config.Path || (len(r.URL.Path) >= len(config.Path) && r.URL.Path[:len(config.Path)] == config.Path) { // Check if the request should be ignored if config.Ignore != nil { for _, ignoreFunc := range config.Ignore { if ignoreFunc(r) { next.ServeHTTP(w, r) return } } } csrfHandler(next).ServeHTTP(w, r) return } next.ServeHTTP(w, r) }) } // Handle ignore functions if len(config.Ignore) > 0 { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { for _, ignoreFunc := range config.Ignore { if ignoreFunc(r) { next.ServeHTTP(w, r) return } } csrfHandler(next).ServeHTTP(w, r) }) } // Apply CSRF to all routes if no path or ignores specified return csrfHandler(next) } } // CSRFToken gets the CSRF token from the request context func CSRFToken(r *http.Request) string { return csrf.Token(r) } // CSRFTemplateField gets the hidden input field containing the CSRF token func CSRFTemplateField(r *http.Request) template.HTML { return csrf.TemplateField(r) } // ParseCSRFKey validates and converts a CSRF secret string to the required 32-byte key // It returns the key as a byte slice and an error if the key is invalid func ParseCSRFKey(secret string) ([]byte, error) { if secret == "" { return nil, fmt.Errorf("csrf secret is required and must be exactly 32 bytes") } key := []byte(secret) if len(key) != 32 { return nil, fmt.Errorf("csrf secret must be exactly 32 bytes (got %d bytes)", len(key)) } return key, nil }