package middleware

import (
	"fmt"
	"html/template"
	"net/http"
	"strings"

	"github.com/gorilla/csrf"
)

// CSRFConfig defines the configuration options for the CSRF middleware.
//
// Fields left at their zero value are generally not forwarded to gorilla/csrf,
// so the library's own defaults apply to them. Cookie.Secure and
// Cookie.HttpOnly are the exception; see the Cookie field.
type CSRFConfig struct {
	// Secret is the 32-byte secret key used to generate tokens.
	// ParseCSRFKey can be used to validate the length of a string secret.
	Secret []byte

	// Cookie defines options for the CSRF cookie. Empty strings, a zero
	// MaxAge and a zero SameSite are not forwarded, leaving the library
	// defaults in place. Secure and HttpOnly behave differently: when false
	// (including when simply left unset), the middleware explicitly passes
	// csrf.Secure(false) / csrf.HttpOnly(false), turning those attributes off.
	Cookie struct {
		Name     string
		Domain   string
		HttpOnly bool
		MaxAge   int
		Path     string
		SameSite csrf.SameSiteMode
		Secure   bool
	}

	// ErrorHandler is a custom error handler for CSRF errors. It is passed
	// to csrf.ErrorHandler only when non-nil.
	ErrorHandler http.Handler

	// FieldName is the name of the hidden form field containing the CSRF
	// token. Empty means the library default.
	FieldName string

	// RequestHeader is the name of the request header containing the CSRF
	// token. Empty means the library default.
	RequestHeader string

	// TrustedOrigins lists origins passed to csrf.TrustedOrigins; it is not
	// forwarded when empty.
	TrustedOrigins []string

	// Path restricts CSRF protection to requests whose URL path starts with
	// this value. This is a plain string-prefix match, so "/api" also matches
	// "/apiary". If empty, all paths are protected.
	Path string

	// Ignore holds predicates evaluated in order for each request that is in
	// scope (i.e. matches Path). If any returns true, the request is passed
	// to the next handler without CSRF protection. Requests outside Path
	// never reach these predicates.
	Ignore []func(r *http.Request) bool
}

// CSRF returns middleware that provides Cross-Site Request Forgery protection
// backed by gorilla/csrf.
//
// Options are forwarded to csrf.Protect only for fields set in config (see
// csrfProtectOptions for the Secure/HttpOnly exception). A request is
// protected when its URL path starts with config.Path (every request, when
// Path is empty) and no predicate in config.Ignore returns true for it. All
// other requests are passed straight to the wrapped handler.
func CSRF(config CSRFConfig) Middleware {
	protect := csrf.Protect(config.Secret, csrfProtectOptions(config)...)

	return func(next http.Handler) http.Handler {
		// Build the protected handler once per wrapped handler rather than
		// once per request: each call to protect constructs a new
		// gorilla/csrf handler.
		protected := protect(next)

		// Nothing to filter: every request goes through CSRF protection.
		if config.Path == "" && len(config.Ignore) == 0 {
			return protected
		}

		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// HasPrefix against an empty Path is always true, so an unset
			// Path puts every request in scope. Ignore predicates are only
			// evaluated for in-scope requests (short-circuit on ||).
			if !strings.HasPrefix(r.URL.Path, config.Path) || csrfShouldSkip(config.Ignore, r) {
				next.ServeHTTP(w, r)
				return
			}
			protected.ServeHTTP(w, r)
		})
	}
}

// csrfProtectOptions translates config into csrf.Option values. Empty strings,
// zero MaxAge/SameSite, a nil ErrorHandler and empty TrustedOrigins are
// skipped so the library defaults apply.
//
// NOTE(review): Secure and HttpOnly are plain bools, so "unset" cannot be told
// apart from "false"; a false value explicitly disables the attribute. A
// zero-value Cookie config therefore sends the CSRF cookie without the Secure
// and HttpOnly flags. Kept as-is for backward compatibility.
func csrfProtectOptions(config CSRFConfig) []csrf.Option {
	var options []csrf.Option

	// Cookie options.
	if config.Cookie.Name != "" {
		options = append(options, csrf.CookieName(config.Cookie.Name))
	}
	if config.Cookie.Path != "" {
		options = append(options, csrf.Path(config.Cookie.Path))
	}
	if config.Cookie.MaxAge != 0 {
		options = append(options, csrf.MaxAge(config.Cookie.MaxAge))
	}
	if config.Cookie.Domain != "" {
		options = append(options, csrf.Domain(config.Cookie.Domain))
	}
	if !config.Cookie.Secure {
		options = append(options, csrf.Secure(false))
	}
	if !config.Cookie.HttpOnly {
		options = append(options, csrf.HttpOnly(false))
	}
	if config.Cookie.SameSite != 0 {
		options = append(options, csrf.SameSite(config.Cookie.SameSite))
	}

	// Token transport and error handling options.
	if config.FieldName != "" {
		options = append(options, csrf.FieldName(config.FieldName))
	}
	if config.RequestHeader != "" {
		options = append(options, csrf.RequestHeader(config.RequestHeader))
	}
	if config.ErrorHandler != nil {
		options = append(options, csrf.ErrorHandler(config.ErrorHandler))
	}
	if len(config.TrustedOrigins) > 0 {
		options = append(options, csrf.TrustedOrigins(config.TrustedOrigins))
	}

	return options
}

// csrfShouldSkip reports whether any predicate in ignore matches r. Predicates
// are evaluated in order and evaluation stops at the first match.
func csrfShouldSkip(ignore []func(*http.Request) bool, r *http.Request) bool {
	for _, skip := range ignore {
		if skip(r) {
			return true
		}
	}
	return false
}

// CSRFToken returns the CSRF token associated with r by delegating to
// csrf.Token.
//
// Requests that the CSRF middleware lets through unprotected (outside
// CSRFConfig.Path, or matched by an Ignore predicate) never pass through
// gorilla/csrf, so presumably no token is available for them — confirm
// against the gorilla/csrf documentation.
func CSRFToken(r *http.Request) string {
	token := csrf.Token(r)
	return token
}

// CSRFTemplateField returns the hidden form input carrying the CSRF token for
// r, as produced by csrf.TemplateField, ready to embed in an html/template.
func CSRFTemplateField(r *http.Request) template.HTML {
	field := csrf.TemplateField(r)
	return field
}

// ParseCSRFKey checks that secret is exactly 32 bytes long and returns it as
// a byte slice suitable for CSRFConfig.Secret.
//
// The length is measured in bytes, not runes, and the secret is used
// verbatim (no trimming or decoding). An empty secret and a secret of any
// other length produce distinct errors; on error the returned slice is nil.
func ParseCSRFKey(secret string) ([]byte, error) {
	switch size := len(secret); {
	case size == 0:
		return nil, fmt.Errorf("csrf secret is required and must be exactly 32 bytes")
	case size != 32:
		return nil, fmt.Errorf("csrf secret must be exactly 32 bytes (got %d bytes)", size)
	}
	return []byte(secret), nil
}