148 lines
3.9 KiB
Go

package middleware
import (
"html/template"
"net/http"
"github.com/gorilla/csrf"
)
// CSRFConfig defines the configuration options for CSRF middleware.
//
// CSRF forwards to gorilla/csrf only the fields that differ from their zero
// value, so unset fields keep the csrf package's own defaults. Cookie.Secure
// and Cookie.HttpOnly are the exceptions: a bool cannot express "unset", so
// false (including the zero value) is always forwarded and turns the
// corresponding cookie attribute off.
type CSRFConfig struct {
// Secret is the 32-byte secret key used to generate tokens.
// It is passed to csrf.Protect as-is; this package does not validate its
// length or check that it is non-nil.
Secret []byte
// Cookie defines cookie options.
// Name, Domain and Path are forwarded only when non-empty; MaxAge and
// SameSite only when non-zero.
Cookie struct {
Name string
Domain string
// HttpOnly: when false (the zero value) CSRF passes csrf.HttpOnly(false),
// overriding the csrf package default. Set to true to keep HttpOnly cookies.
HttpOnly bool
// MaxAge is forwarded via csrf.MaxAge when non-zero (gorilla/csrf
// documents it in seconds — confirm against the version in use).
MaxAge int
Path string
// SameSite is forwarded only when non-zero; the zero value leaves the
// csrf package default in place.
SameSite csrf.SameSiteMode
// Secure: when false (the zero value) CSRF passes csrf.Secure(false),
// overriding the csrf package default. Set to true for HTTPS deployments.
Secure bool
}
// ErrorHandler is a custom error handler for CSRF errors.
// If nil, the csrf package's default error handler is used.
ErrorHandler http.Handler
// FieldName is the name of the hidden form field containing the CSRF token.
// If empty, the csrf package default is used.
FieldName string
// RequestHeader is the name of the request header containing the CSRF token.
// If empty, the csrf package default is used.
RequestHeader string
// TrustedOrigins defines trusted origins for CSRF protection.
// Forwarded via csrf.TrustedOrigins only when non-empty; see that option for
// the expected entry format.
TrustedOrigins []string
// Path restricts CSRF protection to requests whose URL path starts with this
// string. It is a plain string-prefix match, so "/api" also matches "/apiv2".
// Requests outside the prefix bypass CSRF entirely.
// If empty, all paths are protected.
Path string
// Ignore holds predicates; if any of them returns true for a request that is
// within Path, that request skips CSRF protection.
Ignore []func(r *http.Request) bool
}
// DefaultCSRFConfig returns a minimal configuration for CSRF middleware.
//
// Secret is left nil and must be set by the application before the config
// is passed to CSRF.
//
// Cookie.Secure and Cookie.HttpOnly are set to true. CSRF forwards a false
// value for either field as an explicit override, so leaving them at their
// zero value would disable both cookie attributes instead of keeping the
// csrf package's secure defaults. Every other field is left at its zero
// value, which CSRF treats as "not configured" so the csrf package's own
// defaults apply.
func DefaultCSRFConfig() CSRFConfig {
	config := CSRFConfig{
		Secret: nil, // Must be set by the application
		Ignore: []func(r *http.Request) bool{},
	}
	// Cookie is an anonymous struct type, so its fields are assigned after
	// construction rather than in the composite literal.
	config.Cookie.Secure = true
	config.Cookie.HttpOnly = true
	return config
}
// CSRF returns middleware that provides Cross-Site Request Forgery protection
// backed by github.com/gorilla/csrf.
//
// Only options that are explicitly configured (non-empty strings, non-zero
// numbers, non-nil handlers, non-empty slices) are forwarded to csrf.Protect,
// so the csrf package keeps its own defaults for everything else. The
// exceptions are Cookie.Secure and Cookie.HttpOnly: a bool cannot express
// "unset", so a false value (including the zero value) is always forwarded
// and disables that cookie attribute. Set both to true unless you really want
// them off.
//
// Scoping:
//   - If config.Path is non-empty, only requests whose URL path starts with
//     config.Path (plain string prefix, no path-segment boundary) are
//     protected; all other requests go straight to next.
//   - For requests in scope, if any function in config.Ignore returns true,
//     the request goes straight to next without CSRF checks.
//
// Requests that bypass protection never pass through the gorilla/csrf
// handler, so presumably CSRFToken/CSRFTemplateField yield empty values for
// them (see csrf.Token docs).
func CSRF(config CSRFConfig) Middleware {
	// Only set options that are explicitly configured.
	var options []csrf.Option

	// Cookie options
	if config.Cookie.Name != "" {
		options = append(options, csrf.CookieName(config.Cookie.Name))
	}
	if config.Cookie.Path != "" {
		options = append(options, csrf.Path(config.Cookie.Path))
	}
	if config.Cookie.MaxAge != 0 {
		options = append(options, csrf.MaxAge(config.Cookie.MaxAge))
	}
	if config.Cookie.Domain != "" {
		options = append(options, csrf.Domain(config.Cookie.Domain))
	}
	if !config.Cookie.Secure {
		// The zero value cannot be told apart from an explicit false, so it
		// overrides the csrf package default here.
		options = append(options, csrf.Secure(false))
	}
	if !config.Cookie.HttpOnly {
		// Same caveat as Secure above.
		options = append(options, csrf.HttpOnly(false))
	}
	if config.Cookie.SameSite != 0 {
		options = append(options, csrf.SameSite(config.Cookie.SameSite))
	}

	// Other options
	if config.FieldName != "" {
		options = append(options, csrf.FieldName(config.FieldName))
	}
	if config.RequestHeader != "" {
		options = append(options, csrf.RequestHeader(config.RequestHeader))
	}
	if config.ErrorHandler != nil {
		options = append(options, csrf.ErrorHandler(config.ErrorHandler))
	}
	if len(config.TrustedOrigins) > 0 {
		options = append(options, csrf.TrustedOrigins(config.TrustedOrigins))
	}

	// Create CSRF protection middleware
	csrfHandler := csrf.Protect(config.Secret, options...)

	return func(next http.Handler) http.Handler {
		// csrfHandler(next) constructs a new gorilla/csrf handler on each
		// call, so build it once per wrapped handler instead of per request.
		protected := csrfHandler(next)

		// Apply CSRF to all routes if no path or ignores are specified.
		if config.Path == "" && len(config.Ignore) == 0 {
			return protected
		}

		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Plain prefix test (equivalent to strings.HasPrefix); an empty
			// config.Path puts every request in scope.
			inScope := len(r.URL.Path) >= len(config.Path) &&
				r.URL.Path[:len(config.Path)] == config.Path
			if !inScope {
				next.ServeHTTP(w, r)
				return
			}

			// Check if the request should be ignored.
			for _, ignoreFunc := range config.Ignore {
				if ignoreFunc(r) {
					next.ServeHTTP(w, r)
					return
				}
			}

			protected.ServeHTTP(w, r)
		})
	}
}
// CSRFToken returns the CSRF token that the gorilla/csrf handler placed in
// the request context, ready to embed in an HTML template or a JSON response.
//
// gorilla/csrf documents that an empty string comes back when the request did
// not pass through its protection handler.
func CSRFToken(r *http.Request) string {
	token := csrf.Token(r)
	return token
}
// CSRFTemplateField returns a hidden form <input> element that carries the
// request's CSRF token. The value is of type template.HTML, so html/template
// renders it without escaping; drop it directly into a form.
func CSRFTemplateField(r *http.Request) template.HTML {
	field := csrf.TemplateField(r)
	return field
}