package middleware

import (
	"html/template"
	"net/http"
	"strings"

	"github.com/gorilla/csrf"
)

// CSRFConfig defines the configuration options for the CSRF middleware.
type CSRFConfig struct {
	// Secret is the 32-byte secret key used to generate tokens. It is handed
	// to csrf.Protect as-is; CSRF does not validate it.
	Secret []byte

	// Cookie defines the options of the CSRF cookie.
	Cookie struct {
		// Name is the name of the CSRF cookie.
		Name string
		// Path is the path for which the cookie is valid.
		Path string
		// Domain is the domain for which the cookie is valid. The option is
		// only applied when the value is non-empty.
		Domain string
		// MaxAge is the lifetime of the CSRF cookie in seconds.
		MaxAge int
		// Secure sets the Secure flag on the cookie.
		Secure bool
		// HttpOnly sets the HttpOnly flag on the cookie.
		HttpOnly bool
		// SameSite sets the SameSite attribute on the cookie.
		SameSite csrf.SameSiteMode
	}

	// ErrorHandler is called when CSRF validation fails. When nil, no error
	// handler option is passed to gorilla/csrf.
	ErrorHandler http.Handler

	// FieldName is the name of the hidden input field used by the frontend.
	FieldName string

	// RequestHeader is the name of the header used in AJAX requests.
	RequestHeader string

	// TrustedOrigins lists additional origins that are trusted. The option is
	// only applied when the slice is non-empty.
	TrustedOrigins []string

	// Path restricts CSRF protection to requests whose URL path starts with
	// this value. The match is a plain string prefix, so "/api" also covers
	// "/apiary". If empty, all paths are protected.
	Path string

	// Ignore functions determine whether a request should skip CSRF
	// protection. They are called in order for requests that fall under Path;
	// the first one returning true lets the request through unprotected.
	Ignore []func(r *http.Request) bool
}

// DefaultCSRFConfig returns a default configuration for the CSRF middleware.
//
// The returned configuration has no Secret; the application must set one
// before passing the configuration to CSRF.
func DefaultCSRFConfig() CSRFConfig {
	config := CSRFConfig{
		Secret:        nil, // Must be set by the application.
		FieldName:     "gorilla.csrf.Token",
		RequestHeader: "X-CSRF-Token",
		Path:          "",
		Ignore:        []func(r *http.Request) bool{},
	}

	config.Cookie.Name = "_csrf"
	config.Cookie.Path = "/"
	config.Cookie.MaxAge = 86400 // 24 hours
	config.Cookie.Secure = true
	config.Cookie.HttpOnly = true
	config.Cookie.SameSite = csrf.SameSiteStrictMode

	return config
}

// CSRF returns a middleware that provides Cross-Site Request Forgery
// protection backed by gorilla/csrf.
//
// The cookie, field-name and header settings of config are always forwarded
// to gorilla/csrf; Cookie.Domain, ErrorHandler and TrustedOrigins are only
// forwarded when set. A request is passed to the next handler without CSRF
// protection when its URL path does not start with config.Path, or when any
// function in config.Ignore returns true for it. All other requests go
// through the gorilla/csrf handler.
func CSRF(config CSRFConfig) Middleware {
	options := []csrf.Option{
		csrf.CookieName(config.Cookie.Name),
		csrf.Path(config.Cookie.Path),
		csrf.MaxAge(config.Cookie.MaxAge),
		csrf.FieldName(config.FieldName),
		csrf.RequestHeader(config.RequestHeader),
		csrf.Secure(config.Cookie.Secure),
		csrf.HttpOnly(config.Cookie.HttpOnly),
		csrf.SameSite(config.Cookie.SameSite),
	}

	if config.Cookie.Domain != "" {
		options = append(options, csrf.Domain(config.Cookie.Domain))
	}

	if config.ErrorHandler != nil {
		options = append(options, csrf.ErrorHandler(config.ErrorHandler))
	}

	if len(config.TrustedOrigins) > 0 {
		options = append(options, csrf.TrustedOrigins(config.TrustedOrigins))
	}

	// NOTE(review): config.Secret is not validated here; confirm callers
	// always set it (DefaultCSRFConfig leaves it nil).
	protect := csrf.Protect(config.Secret, options...)

	return func(next http.Handler) http.Handler {
		// Build the protected handler once per wrapped handler instead of on
		// every request; it only depends on next and the options above.
		protected := protect(next)

		// No path restriction and no ignore functions: protect everything.
		if config.Path == "" && len(config.Ignore) == 0 {
			return protected
		}

		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// An empty config.Path is a prefix of every path, so the prefix
			// test only filters when a path restriction is configured. The
			// ignore functions are consulted only for paths that match.
			if !strings.HasPrefix(r.URL.Path, config.Path) || shouldSkipCSRF(config.Ignore, r) {
				next.ServeHTTP(w, r)
				return
			}
			protected.ServeHTTP(w, r)
		})
	}
}

// shouldSkipCSRF reports whether any of the ignore functions returns true for
// r. The functions are called in order and evaluation stops at the first
// match.
func shouldSkipCSRF(ignore []func(r *http.Request) bool, r *http.Request) bool {
	for _, fn := range ignore {
		if fn(r) {
			return true
		}
	}
	return false
}

// CSRFToken returns the CSRF token for the request, as provided by
// csrf.Token.
func CSRFToken(r *http.Request) string {
	return csrf.Token(r)
}

// CSRFTemplateField returns the hidden input field containing the CSRF
// token, as provided by csrf.TemplateField.
func CSRFTemplateField(r *http.Request) template.HTML {
	return csrf.TemplateField(r)
}

/*
Usage example:

1. Create a secure key for CSRF protection:

	key := make([]byte, 32)
	_, err := rand.Read(key)
	if err != nil {
		log.Fatal(err)
	}

2. Create CSRF middleware with default configuration:

	csrfConfig := middleware.DefaultCSRFConfig()
	csrfConfig.Secret = key
	csrfMiddleware := middleware.CSRF(csrfConfig)

3. Add the middleware to your routing (example with standard http):

	handler := middleware.CreateStack(
		middleware.RequestID(),
		middleware.Logging(logger),
		middleware.CORS(corsConfig),
		middleware.CSRF(csrfConfig), // Add CSRF protection
	)(router)

4. In your HTML templates, include the CSRF token in forms:

	<form method="POST" action="/submit">
		{{ .CSRFField }}
		<!-- other form fields -->
	</form>

Where .CSRFField is provided to the template as:

	data := map[string]interface{}{
		"CSRFField": middleware.CSRFTemplateField(r),
	}

5. For AJAX requests, include the CSRF token in the request header:

	// JavaScript example
	fetch('/api/protected', {
		method: 'POST',
		headers: {
			'X-CSRF-Token': document.querySelector('meta[name="csrf-token"]').getAttribute('content'),
		},
		body: JSON.stringify(data)
	})

	// Include meta tag in your HTML:
	<meta name="csrf-token" content="{{ .CSRFToken }}">

	// Where .CSRFToken is provided to the template as:
	data := map[string]interface{}{
		"CSRFToken": middleware.CSRFToken(r),
	}
*/