213 lines
5.5 KiB
Go

package middleware
import (
	"html/template"
	"net/http"
	"strings"

	"github.com/gorilla/csrf"
)
// CSRFConfig defines the configuration options for the CSRF middleware.
// It is consumed by CSRF, which forwards most fields to gorilla/csrf as
// options. DefaultCSRFConfig returns a populated starting point.
type CSRFConfig struct {
// Secret is the key handed to csrf.Protect as its authentication key.
// It is expected to be 32 bytes long. DefaultCSRFConfig leaves it nil, so
// the application must set it.
Secret []byte
// Cookie groups the options applied to the CSRF cookie.
Cookie struct {
// Name of the CSRF cookie.
Name string
// Path for which the cookie is valid (the cookie path, not to be confused
// with CSRFConfig.Path below, which selects the protected URLs).
Path string
// Domain for which the cookie is valid. It is only forwarded to
// gorilla/csrf when non-empty.
Domain string
// MaxAge is the lifetime of the CSRF cookie, in seconds.
MaxAge int
// Secure sets the Secure flag on the cookie.
Secure bool
// HttpOnly sets the HttpOnly flag on the cookie.
HttpOnly bool
// SameSite sets the SameSite attribute on the cookie.
SameSite csrf.SameSiteMode
}
// ErrorHandler is called when CSRF validation fails. It is only forwarded
// to gorilla/csrf when non-nil.
ErrorHandler http.Handler
// FieldName is the name of the hidden form input that carries the token.
FieldName string
// RequestHeader is the name of the header that carries the token in AJAX
// requests.
RequestHeader string
// TrustedOrigins lists additional origins that are trusted. It is only
// forwarded to gorilla/csrf when it contains at least one entry.
TrustedOrigins []string
// Path restricts CSRF protection to requests whose URL path starts with
// this value. The check is a plain string-prefix comparison, so "/api"
// also matches "/apiary". Requests outside the prefix bypass protection.
// If empty, all paths are protected.
Path string
// Ignore holds predicates that exempt individual requests: if any of them
// returns true, the request skips CSRF protection. When Path is set, the
// predicates are only consulted for requests under that prefix.
Ignore []func(r *http.Request) bool
}
// DefaultCSRFConfig builds a CSRFConfig pre-populated with the package
// defaults: a "_csrf" cookie valid on "/" for 24 hours with the Secure and
// HttpOnly flags set and SameSite in strict mode, the "gorilla.csrf.Token"
// form field, and the "X-CSRF-Token" request header.
//
// Secret is left nil and must be supplied by the application. Path is left
// empty, so protection applies to every path.
func DefaultCSRFConfig() CSRFConfig {
	var cfg CSRFConfig

	// Names under which the token travels in forms and AJAX requests.
	cfg.FieldName = "gorilla.csrf.Token"
	cfg.RequestHeader = "X-CSRF-Token"

	// Start with an empty, non-nil list of ignore predicates.
	cfg.Ignore = []func(r *http.Request) bool{}

	// Cookie defaults.
	cookie := &cfg.Cookie
	cookie.Name = "_csrf"
	cookie.Path = "/"
	cookie.MaxAge = 24 * 60 * 60 // 24 hours, in seconds
	cookie.Secure = true
	cookie.HttpOnly = true
	cookie.SameSite = csrf.SameSiteStrictMode

	return cfg
}
// CSRF builds a Middleware that applies gorilla/csrf Cross-Site Request
// Forgery protection according to config.
//
// Cookie, field-name and header settings are always forwarded to
// gorilla/csrf; Cookie.Domain, ErrorHandler and TrustedOrigins are forwarded
// only when set. config.Secret is passed to csrf.Protect unchanged and is not
// validated here.
//
// Request routing:
//   - If config.Path is non-empty, only requests whose URL path starts with
//     it are protected; all other requests go straight to the next handler.
//   - For requests in scope, the config.Ignore predicates run in order and
//     the first one returning true sends the request to the next handler
//     without protection.
//   - Every remaining request is served through the CSRF-protected handler.
func CSRF(config CSRFConfig) Middleware {
	options := []csrf.Option{
		csrf.CookieName(config.Cookie.Name),
		csrf.Path(config.Cookie.Path),
		csrf.MaxAge(config.Cookie.MaxAge),
		csrf.FieldName(config.FieldName),
		csrf.RequestHeader(config.RequestHeader),
		csrf.Secure(config.Cookie.Secure),
		csrf.HttpOnly(config.Cookie.HttpOnly),
		csrf.SameSite(config.Cookie.SameSite),
	}
	if config.Cookie.Domain != "" {
		options = append(options, csrf.Domain(config.Cookie.Domain))
	}
	if config.ErrorHandler != nil {
		options = append(options, csrf.ErrorHandler(config.ErrorHandler))
	}
	if len(config.TrustedOrigins) > 0 {
		options = append(options, csrf.TrustedOrigins(config.TrustedOrigins))
	}

	protect := csrf.Protect(config.Secret, options...)

	return func(next http.Handler) http.Handler {
		// Wrap next exactly once. Doing this inside the request closure would
		// rebuild the protected handler on every request.
		protected := protect(next)

		// Nothing to filter on: protect every request directly.
		if config.Path == "" && len(config.Ignore) == 0 {
			return protected
		}

		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// An empty config.Path is a prefix of every path, so it puts all
			// requests in scope. The ignore predicates are only evaluated for
			// requests that are in scope.
			if !strings.HasPrefix(r.URL.Path, config.Path) || csrfSkipped(r, config.Ignore) {
				next.ServeHTTP(w, r)
				return
			}
			protected.ServeHTTP(w, r)
		})
	}
}

// csrfSkipped reports whether any of the ignore predicates returns true for r.
// Nil predicates are skipped rather than called.
func csrfSkipped(r *http.Request, ignore []func(r *http.Request) bool) bool {
	for _, fn := range ignore {
		if fn != nil && fn(r) {
			return true
		}
	}
	return false
}
// CSRFToken returns the CSRF token that gorilla/csrf associates with r.
// It is a thin wrapper around csrf.Token so that callers do not need to
// import gorilla/csrf themselves.
func CSRFToken(r *http.Request) string {
	token := csrf.Token(r)
	return token
}
// CSRFTemplateField returns the hidden form input carrying the CSRF token
// for r, ready to be embedded in an HTML template. It is a thin wrapper
// around csrf.TemplateField.
func CSRFTemplateField(r *http.Request) template.HTML {
	field := csrf.TemplateField(r)
	return field
}
/* Usage example:
1. Create a secure 32-byte key for CSRF protection (rand here is crypto/rand):
key := make([]byte, 32)
_, err := rand.Read(key)
if err != nil {
log.Fatal(err)
}
2. Create CSRF middleware with default configuration:
csrfConfig := middleware.DefaultCSRFConfig()
csrfConfig.Secret = key
csrfMiddleware := middleware.CSRF(csrfConfig)
3. Add the middleware to your routing (example with standard http):
handler := middleware.CreateStack(
middleware.RequestID(),
middleware.Logging(logger),
middleware.CORS(corsConfig),
csrfMiddleware, // Add CSRF protection (the middleware created in step 2)
)(router)
4. In your HTML templates, include the CSRF token in forms:
<form method="post" action="/protected">
{{ .CSRFField }}
<input type="text" name="username">
<button type="submit">Submit</button>
</form>
Where .CSRFField is provided to the template as:
data := map[string]interface{}{
"CSRFField": middleware.CSRFTemplateField(r),
}
5. For AJAX requests, include the CSRF token in the request header:
// JavaScript example
fetch('/api/protected', {
method: 'POST',
headers: {
'X-CSRF-Token': document.querySelector('meta[name="csrf-token"]').getAttribute('content'),
},
body: JSON.stringify(data)
})
// Include meta tag in your HTML:
// <meta name="csrf-token" content="{{ .CSRFToken }}">
Where .CSRFToken is provided to the template as:
data := map[string]interface{}{
"CSRFToken": middleware.CSRFToken(r),
}
*/