package middleware

import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"net/http"

	"github.com/coreos/go-oidc/v3/oidc"
	"github.com/gorilla/sessions"
	"golang.org/x/oauth2"
)

// AuthConfig bundles the dependencies shared by AuthMiddleware,
// LoginHandler and CallbackHandler.
type AuthConfig struct {
	// Store holds the "auth-session" session, which carries the login
	// flow's state, nonce and code_verifier and the "authenticated" flag.
	Store *sessions.CookieStore

	// OAuthConfig is the OAuth 2.0 client used to build the authorization
	// URL and to exchange the authorization code for tokens.
	OAuthConfig *oauth2.Config

	// Verifier validates the ID token returned by the token exchange.
	Verifier *oidc.IDTokenVerifier
}

// AuthMiddleware returns a Middleware that lets a request through only when
// its "auth-session" session has Values["authenticated"] set to true.
// Otherwise it responds with a 302 redirect to /login. Requests for /login
// and /callback are always passed through so the login flow itself can run.
func AuthMiddleware(config *AuthConfig) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Skip authentication for public routes.
			if r.URL.Path == "/login" || r.URL.Path == "/callback" {
				next.ServeHTTP(w, r)
				return
			}

			// A session that cannot be read (e.g. a tampered cookie) is
			// treated exactly like a missing one: not authenticated.
			authenticated := false
			if session, err := config.Store.Get(r, "auth-session"); err == nil {
				authenticated, _ = session.Values["authenticated"].(bool)
			}
			if !authenticated {
				http.Redirect(w, r, "/login", http.StatusFound)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// randomToken returns n bytes read from crypto/rand, encoded as unpadded
// base64url so the result is safe to place in URLs (43 characters for
// n == 32). It returns an error if the secure random source fails.
func randomToken(n int) (string, error) {
	b := make([]byte, n)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(b), nil
}

// generateRandomString is like randomToken but panics if the secure random
// source fails. A panic is used instead of silently returning an encoding
// of zero bytes, which would make the value predictable.
func generateRandomString(n int) string {
	s, err := randomToken(n)
	if err != nil {
		panic("middleware: crypto/rand: " + err.Error())
	}
	return s
}

// LoginHandler starts the OpenID Connect authorization-code flow with PKCE.
// It generates a fresh state, nonce and code verifier and stores them in the
// "auth-session" session. It then redirects (302) to the provider's
// authorization URL with the state, the nonce and the S256 code challenge
// derived from the verifier. CallbackHandler checks the stored values when
// the provider redirects back.
//
// It responds 500 if random generation or saving the session fails.
func LoginHandler(config *AuthConfig) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		secrets := make([]string, 3)
		for i := range secrets {
			s, err := randomToken(32)
			if err != nil {
				http.Error(w, "Failed to generate login state", http.StatusInternalServerError)
				return
			}
			secrets[i] = s
		}
		state, nonce, codeVerifier := secrets[0], secrets[1], secrets[2]

		// PKCE S256: challenge = BASE64URL(SHA256(verifier)), unpadded.
		hash := sha256.Sum256([]byte(codeVerifier))
		codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])

		// A decode error (e.g. a stale cookie) is deliberately ignored: a
		// login only needs a session to write into.
		// NOTE(review): relies on the store returning a non-nil (fresh)
		// session alongside the error, as gorilla/sessions' CookieStore does.
		session, _ := config.Store.Get(r, "auth-session")
		session.Values["state"] = state
		session.Values["nonce"] = nonce
		session.Values["code_verifier"] = codeVerifier
		if err := session.Save(r, w); err != nil {
			// Without these values the callback can never succeed, so fail
			// here instead of sending the user to the provider.
			http.Error(w, "Failed to save session", http.StatusInternalServerError)
			return
		}

		authURL := config.OAuthConfig.AuthCodeURL(
			state,
			oidc.Nonce(nonce),
			oauth2.SetAuthURLParam("code_challenge", codeChallenge),
			oauth2.SetAuthURLParam("code_challenge_method", "S256"),
		)
		http.Redirect(w, r, authURL, http.StatusFound)
	}
}

// CallbackHandler completes the flow started by LoginHandler. In order, it:
//   - checks the returned state against the one saved in the session;
//   - exchanges the authorization code, sending the saved PKCE code verifier;
//   - verifies the ID token and checks its nonce against the saved one;
//   - marks the session authenticated and discards the one-time login values;
//   - redirects (302) to "/".
//
// Failures are answered with http.Error:
//   - 400 for a bad state, a provider-reported error or missing code, an
//     incomplete session, or a nonce mismatch;
//   - 500 for token-exchange, ID-token or session-save failures.
func CallbackHandler(config *AuthConfig) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Tie the token exchange and ID-token verification to the lifetime
		// of this request.
		ctx := r.Context()
		q := r.URL.Query()

		session, err := config.Store.Get(r, "auth-session")
		if err != nil {
			// An unreadable session has no saved state to compare against.
			http.Error(w, "Invalid state", http.StatusBadRequest)
			return
		}

		// Validate state. The saved value must be non-empty, otherwise an
		// empty query parameter would match a missing or blank entry.
		savedState, _ := session.Values["state"].(string)
		if savedState == "" || q.Get("state") != savedState {
			http.Error(w, "Invalid state", http.StatusBadRequest)
			return
		}

		// Per RFC 6749 §4.1.2.1, a failed or denied authorization arrives as
		// an "error" parameter instead of a code. There is nothing to
		// exchange in that case.
		code := q.Get("code")
		if q.Get("error") != "" || code == "" {
			http.Error(w, "Authorization failed", http.StatusBadRequest)
			return
		}

		codeVerifier, _ := session.Values["code_verifier"].(string)
		if codeVerifier == "" {
			http.Error(w, "Invalid session", http.StatusBadRequest)
			return
		}

		// Exchange code.
		token, err := config.OAuthConfig.Exchange(ctx, code, oauth2.VerifierOption(codeVerifier))
		if err != nil {
			http.Error(w, "Token exchange failed", http.StatusInternalServerError)
			return
		}

		// Verify ID token.
		rawIDToken, ok := token.Extra("id_token").(string)
		if !ok {
			http.Error(w, "Missing ID token", http.StatusInternalServerError)
			return
		}
		idToken, err := config.Verifier.Verify(ctx, rawIDToken)
		if err != nil {
			http.Error(w, "Invalid ID token", http.StatusInternalServerError)
			return
		}

		// Verify nonce (an empty saved nonce is never acceptable).
		savedNonce, _ := session.Values["nonce"].(string)
		if savedNonce == "" || idToken.Nonce != savedNonce {
			http.Error(w, "Invalid nonce", http.StatusBadRequest)
			return
		}

		// The login values are single-use; drop them now that the flow has
		// completed, then mark the session authenticated.
		delete(session.Values, "state")
		delete(session.Values, "nonce")
		delete(session.Values, "code_verifier")
		session.Values["authenticated"] = true
		if err := session.Save(r, w); err != nil {
			http.Error(w, "Failed to save session", http.StatusInternalServerError)
			return
		}

		http.Redirect(w, r, "/", http.StatusFound)
	}
}