package auth import ( "context" "crypto/rand" "crypto/sha256" "encoding/base64" "fmt" "log" "net/http" "net/url" "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/sessions" "github.com/spf13/viper" "golang.org/x/oauth2" ) // Config holds all auth-related configuration type Config struct { Store *sessions.CookieStore OAuthConfig *oauth2.Config Verifier *oidc.IDTokenVerifier Provider *oidc.Provider SessionName string } // Setup initializes the auth configuration func Setup() (*Config, error) { // Create cookie store store := sessions.NewCookieStore( []byte(viper.GetString("session-secret")), ) store.Options = &sessions.Options{ HttpOnly: true, Secure: viper.GetString("env") == "production", SameSite: http.SameSiteLaxMode, MaxAge: 86400 * 7, // 1 week } // Initialize OIDC provider ctx := context.Background() provider, err := oidc.NewProvider(ctx, viper.GetString("oidc-idp-issuer-url")) if err != nil { return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err) } // Create OAuth2 config oauthConfig := &oauth2.Config{ ClientID: viper.GetString("oidc-sp-client-id"), ClientSecret: viper.GetString("oidc-sp-client-secret"), RedirectURL: viper.GetString("hostname") + "/callback", Endpoint: provider.Endpoint(), Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, } // Create auth config config := &Config{ Store: store, OAuthConfig: oauthConfig, Provider: provider, Verifier: provider.Verifier(&oidc.Config{ClientID: oauthConfig.ClientID}), SessionName: "auth-session", } return config, nil } // RegisterHandlers adds all auth-related handlers to the router func (c *Config) RegisterHandlers(mux *http.ServeMux) { mux.HandleFunc("/login", c.LoginHandler) mux.HandleFunc("/callback", c.CallbackHandler) mux.HandleFunc("/logout", c.LogoutHandler) mux.HandleFunc("/logout-callback", c.LogoutCallbackHandler) mux.HandleFunc("/register", c.RegistrationHandler) } // Middleware returns an auth middleware function func (c *Config) Middleware() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Skip authentication for public routes if r.URL.Path == "/login" || r.URL.Path == "/callback" || r.URL.Path == "/logout" || r.URL.Path == "/logout-callback" || r.URL.Path == "/register" { next.ServeHTTP(w, r) return } session, err := c.Store.Get(r, c.SessionName) if err != nil { http.Redirect(w, r, "/login", http.StatusFound) return } if auth, ok := session.Values["authenticated"].(bool); !ok || !auth { http.Redirect(w, r, "/login", http.StatusFound) return } next.ServeHTTP(w, r) }) } } // LoginHandler initiates the OIDC authentication flow func (c *Config) LoginHandler(w http.ResponseWriter, r *http.Request) { state := generateRandomString(32) nonce := generateRandomString(32) codeVerifier := generateRandomString(32) hash := sha256.Sum256([]byte(codeVerifier)) codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:]) session, _ := c.Store.Get(r, c.SessionName) session.Values["state"] = state session.Values["nonce"] = nonce session.Values["code_verifier"] = codeVerifier if err := session.Save(r, w); err != nil { log.Printf("Error saving session: %v", err) http.Error(w, "Server error", http.StatusInternalServerError) return } authURL := c.OAuthConfig.AuthCodeURL( state, oidc.Nonce(nonce), oauth2.SetAuthURLParam("code_challenge", codeChallenge), oauth2.SetAuthURLParam("code_challenge_method", "S256"), ) http.Redirect(w, r, authURL, http.StatusFound) } // CallbackHandler processes the OIDC callback func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) { ctx := context.Background() session, err := c.Store.Get(r, c.SessionName) if err != nil { http.Error(w, "Session error", http.StatusInternalServerError) return } // Validate state state := r.URL.Query().Get("state") if savedState, ok := session.Values["state"].(string); !ok || state != savedState { log.Printf("State mismatch: got %s, expected %s", state, savedState) http.Error(w, "Invalid state", http.StatusBadRequest) return } // Exchange code code := r.URL.Query().Get("code") codeVerifier, ok := session.Values["code_verifier"].(string) if !ok { http.Error(w, "No code verifier found", http.StatusBadRequest) return } token, err := c.OAuthConfig.Exchange( ctx, code, oauth2.VerifierOption(codeVerifier), ) if err != nil { log.Printf("Token exchange failed: %v", err) http.Error(w, "Token exchange failed", http.StatusInternalServerError) return } // Verify ID token rawIDToken, ok := token.Extra("id_token").(string) if !ok { http.Error(w, "Missing ID token", http.StatusInternalServerError) return } idToken, err := c.Verifier.Verify(ctx, rawIDToken) if err != nil { log.Printf("ID token verification failed: %v", err) http.Error(w, "Invalid ID token", http.StatusInternalServerError) return } // Verify nonce var claims struct { Nonce string `json:"nonce"` } if err := idToken.Claims(&claims); err != nil { http.Error(w, "Could not parse claims", http.StatusInternalServerError) return } if claims.Nonce != session.Values["nonce"].(string) { http.Error(w, "Invalid nonce", http.StatusBadRequest) return } // Authenticate session session.Values["authenticated"] = true session.Values["id_token"] = rawIDToken // Store user info var userInfo struct { Email string `json:"email"` Name string `json:"name"` Username string `json:"preferred_username"` } if err := idToken.Claims(&userInfo); err == nil { session.Values["email"] = userInfo.Email session.Values["name"] = userInfo.Name session.Values["username"] = userInfo.Username } if err := session.Save(r, w); err != nil { log.Printf("Error saving session: %v", err) http.Error(w, "Server error", http.StatusInternalServerError) return } http.Redirect(w, r, "/", http.StatusFound) } // LogoutHandler handles user logout func (c *Config) LogoutHandler(w http.ResponseWriter, r *http.Request) { // Generate logout state for verification state := generateRandomString(32) // Get session session, err := c.Store.Get(r, c.SessionName) if err != nil { http.Error(w, "Session error", http.StatusInternalServerError) return } // Store logout state for verification session.Values["logout_state"] = state if err := session.Save(r, w); err != nil { log.Printf("Error saving logout state: %v", err) http.Error(w, "Server error", http.StatusInternalServerError) return } // Build logout URL keycloakLogoutURL, err := url.Parse(viper.GetString("oidc-idp-issuer-url") + "/protocol/openid-connect/logout") if err != nil { log.Printf("Error parsing logout URL: %v", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } // Add query parameters q := keycloakLogoutURL.Query() // Use logout-callback for completing the logout flow q.Set("post_logout_redirect_uri", viper.GetString("hostname")+"/logout-callback") q.Set("client_id", viper.GetString("oidc-sp-client-id")) q.Set("state", state) // Add id_token_hint if available if idToken, ok := session.Values["id_token"].(string); ok { q.Set("id_token_hint", idToken) } keycloakLogoutURL.RawQuery = q.Encode() // Log for debugging log.Printf("Redirecting to logout URL: %s", keycloakLogoutURL.String()) // Redirect to Keycloak http.Redirect(w, r, keycloakLogoutURL.String(), http.StatusFound) } // LogoutCallbackHandler handles the redirect after Keycloak logout func (c *Config) LogoutCallbackHandler(w http.ResponseWriter, r *http.Request) { // Get session session, err := c.Store.Get(r, c.SessionName) if err != nil { http.Error(w, "Session error", http.StatusInternalServerError) return } // Verify state parameter returnedState := r.URL.Query().Get("state") log.Printf("Received logout state: %s", returnedState) savedState, ok := session.Values["logout_state"].(string) if !ok { log.Printf("No logout state found in session") http.Error(w, "Invalid session state", http.StatusBadRequest) return } log.Printf("Saved logout state: %s", savedState) if returnedState != savedState { log.Printf("State mismatch: %s != %s", returnedState, savedState) http.Error(w, "Invalid state", http.StatusBadRequest) return } // Clear the session completely session.Options.MaxAge = -1 if err := session.Save(r, w); err != nil { log.Printf("Error clearing session: %v", err) http.Error(w, "Server error", http.StatusInternalServerError) return } // Redirect to login page http.Redirect(w, r, "/login", http.StatusFound) } // RegistrationHandler redirects to the OIDC registration page func (c *Config) RegistrationHandler(w http.ResponseWriter, r *http.Request) { // Generate random state, nonce, and code verifier for security state := generateRandomString(32) nonce := generateRandomString(32) codeVerifier := generateRandomString(32) hash := sha256.Sum256([]byte(codeVerifier)) codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:]) // Store state, nonce, and code verifier in session for verification session, _ := c.Store.Get(r, c.SessionName) session.Values["state"] = state session.Values["nonce"] = nonce session.Values["code_verifier"] = codeVerifier if err := session.Save(r, w); err != nil { log.Printf("Error saving session: %v", err) http.Error(w, "Server error", http.StatusInternalServerError) return } // Build the registration URL using the specified registrations endpoint baseURL := viper.GetString("oidc-idp-issuer-url") registrationURL, err := url.Parse(baseURL + "/protocol/openid-connect/registrations") if err != nil { log.Printf("Error parsing registration URL: %v", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } // Add query parameters q := registrationURL.Query() q.Set("client_id", viper.GetString("oidc-sp-client-id")) q.Set("response_type", "code") q.Set("scope", "openid email profile") q.Set("redirect_uri", viper.GetString("hostname")+"/callback") q.Set("state", state) q.Set("nonce", nonce) q.Set("code_challenge", codeChallenge) q.Set("code_challenge_method", "S256") registrationURL.RawQuery = q.Encode() // Log for debugging log.Printf("Redirecting to registration URL: %s", registrationURL.String()) // Redirect to registration page http.Redirect(w, r, registrationURL.String(), http.StatusFound) } // Helper function to generate random strings func generateRandomString(n int) string { b := make([]byte, n) if _, err := rand.Read(b); err != nil { panic(err) } return base64.RawURLEncoding.EncodeToString(b) }