package auth import ( "context" "crypto/rand" "crypto/sha256" "database/sql" "encoding/base64" "fmt" "log" "net/http" "net/url" "git.coopcloud.tech/wiki-cafe/member-console/internal/db" "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/sessions" "github.com/spf13/viper" "golang.org/x/oauth2" ) // Config holds all auth-related configuration type Config struct { Store *sessions.CookieStore OAuthConfig *oauth2.Config Verifier *oidc.IDTokenVerifier Provider *oidc.Provider SessionName string DB db.Querier } // Setup initializes the auth configuration func Setup(database db.Querier) (*Config, error) { // Create cookie store store := sessions.NewCookieStore( []byte(viper.GetString("session-secret")), ) store.Options = &sessions.Options{ HttpOnly: true, Secure: viper.GetString("env") == "production", SameSite: http.SameSiteLaxMode, MaxAge: 86400 * 7, // 1 week } // Initialize OIDC provider ctx := context.Background() provider, err := oidc.NewProvider(ctx, viper.GetString("oidc-idp-issuer-url")) if err != nil { return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err) } // Create OAuth2 config oauthConfig := &oauth2.Config{ ClientID: viper.GetString("oidc-sp-client-id"), ClientSecret: viper.GetString("oidc-sp-client-secret"), RedirectURL: viper.GetString("base-url") + "/callback", Endpoint: provider.Endpoint(), Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, } // Create auth config config := &Config{ Store: store, OAuthConfig: oauthConfig, Provider: provider, Verifier: provider.Verifier(&oidc.Config{ClientID: oauthConfig.ClientID}), SessionName: "auth-session", DB: database, } return config, nil } // RegisterHandlers adds all auth-related handlers to the router func (c *Config) RegisterHandlers(mux *http.ServeMux) { mux.HandleFunc("/login", c.LoginHandler) mux.HandleFunc("/callback", c.CallbackHandler) mux.HandleFunc("/logout", c.LogoutHandler) mux.HandleFunc("/logout-callback", c.LogoutCallbackHandler) mux.HandleFunc("/register", c.RegistrationHandler) } // Middleware returns an auth middleware function func (c *Config) Middleware() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Skip authentication for public routes publicPaths := map[string]bool{ "/login": true, "/callback": true, "/logout": true, "/logout-callback": true, "/register": true, } if publicPaths[r.URL.Path] { next.ServeHTTP(w, r) return } session, err := c.Store.Get(r, c.SessionName) if err != nil { http.Redirect(w, r, "/login", http.StatusFound) return } if auth, ok := session.Values["authenticated"].(bool); !ok || !auth { http.Redirect(w, r, "/login", http.StatusFound) return } next.ServeHTTP(w, r) }) } } // LoginHandler initiates the OIDC authentication flow func (c *Config) LoginHandler(w http.ResponseWriter, r *http.Request) { state := generateRandomString(32) nonce := generateRandomString(32) codeVerifier := generateRandomString(32) hashVal := sha256.Sum256([]byte(codeVerifier)) // Renamed from hash to hashVal to avoid conflict codeChallenge := base64.RawURLEncoding.EncodeToString(hashVal[:]) session, _ := c.Store.Get(r, c.SessionName) session.Values["state"] = state session.Values["nonce"] = nonce session.Values["code_verifier"] = codeVerifier if err := session.Save(r, w); err != nil { log.Printf("Error saving session: %v", err) http.Error(w, "Server error", http.StatusInternalServerError) return } authURL := c.OAuthConfig.AuthCodeURL( state, oidc.Nonce(nonce), oauth2.SetAuthURLParam("code_challenge", codeChallenge), oauth2.SetAuthURLParam("code_challenge_method", "S256"), ) http.Redirect(w, r, authURL, http.StatusFound) } // CallbackHandler processes the OIDC callback func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) { ctx := r.Context() // Use request's context session, err := c.Store.Get(r, c.SessionName) if err != nil { log.Printf("CallbackHandler: Session error: %v", err) http.Error(w, "Session error", http.StatusInternalServerError) return } // Validate state state := r.URL.Query().Get("state") if savedState, ok := session.Values["state"].(string); !ok || state != savedState { log.Printf("CallbackHandler: State mismatch: got %s, expected %s", state, savedState) http.Error(w, "Invalid state", http.StatusBadRequest) return } // Exchange code code := r.URL.Query().Get("code") codeVerifier, ok := session.Values["code_verifier"].(string) if !ok { log.Printf("CallbackHandler: No code verifier found in session") http.Error(w, "No code verifier found", http.StatusBadRequest) return } token, err := c.OAuthConfig.Exchange( ctx, code, oauth2.VerifierOption(codeVerifier), ) if err != nil { log.Printf("CallbackHandler: Token exchange failed: %v", err) http.Error(w, "Token exchange failed", http.StatusInternalServerError) return } // Verify ID token rawIDToken, ok := token.Extra("id_token").(string) if !ok { log.Printf("CallbackHandler: Missing ID token from token extra") http.Error(w, "Missing ID token", http.StatusInternalServerError) return } idToken, err := c.Verifier.Verify(ctx, rawIDToken) if err != nil { log.Printf("CallbackHandler: ID token verification failed: %v", err) http.Error(w, "Invalid ID token", http.StatusInternalServerError) return } // Verify nonce var nonceClaims struct { Nonce string `json:"nonce"` } if err := idToken.Claims(&nonceClaims); err != nil { log.Printf("CallbackHandler: Could not parse nonce claims: %v", err) http.Error(w, "Could not parse claims for nonce", http.StatusInternalServerError) return } if nonceClaims.Nonce != session.Values["nonce"].(string) { log.Printf("CallbackHandler: Invalid nonce") http.Error(w, "Invalid nonce", http.StatusBadRequest) return } // ---- Database Interaction & User Info Extraction ---- var userInfoClaims struct { Subject string `json:"sub"` Email string `json:"email"` EmailVerified bool `json:"email_verified"` Name string `json:"name"` PreferredUsername string `json:"preferred_username"` GivenName string `json:"given_name"` FamilyName string `json:"family_name"` } if err := idToken.Claims(&userInfoClaims); err != nil { log.Printf("CallbackHandler: Error parsing user info claims: %v", err) http.Error(w, "Could not parse user info claims", http.StatusInternalServerError) return } user, err := c.DB.GetUserByOIDCSubject(ctx, userInfoClaims.Subject) if err != nil { if err == sql.ErrNoRows { // User not found, create a new user newUserParams := db.CreateUserParams{ OidcSubject: userInfoClaims.Subject, Username: userInfoClaims.PreferredUsername, Email: userInfoClaims.Email, } createdUser, errCreate := c.DB.CreateUser(ctx, newUserParams) if errCreate != nil { log.Printf("CallbackHandler: Error creating user in DB: %v", errCreate) http.Error(w, "Failed to create user record", http.StatusInternalServerError) return } user = createdUser log.Printf("CallbackHandler: New user created: %s (%s), DB ID: %d", user.Username, user.Email, user.ID) } else { log.Printf("CallbackHandler: Error fetching user by OIDC Subject ID '%s': %v", userInfoClaims.Subject, err) http.Error(w, "Database error while fetching user", http.StatusInternalServerError) return } } else { // User found in DB if user.Username != userInfoClaims.PreferredUsername || user.Email != userInfoClaims.Email { updatedUser, errUpdate := c.DB.UpdateUser(ctx, db.UpdateUserParams{ ID: user.ID, Username: userInfoClaims.PreferredUsername, Email: userInfoClaims.Email, }) if errUpdate != nil { log.Printf("CallbackHandler: Error updating user %d: %v", user.ID, errUpdate) } else { user = updatedUser log.Printf("CallbackHandler: User details updated for %s, DB ID: %d", user.Email, user.ID) } } else { log.Printf("CallbackHandler: User found: %s (%s), DB ID: %d", user.Username, user.Email, user.ID) } } session.Values["authenticated"] = true session.Values["id_token"] = rawIDToken session.Values["user_db_id"] = user.ID session.Values["oidc_subject"] = user.OidcSubject session.Values["email"] = user.Email session.Values["name"] = userInfoClaims.Name session.Values["username"] = user.Username if err := session.Save(r, w); err != nil { log.Printf("CallbackHandler: Error saving session: %v", err) http.Error(w, "Server error saving session", http.StatusInternalServerError) return } log.Printf("CallbackHandler: User %s authenticated successfully. Redirecting to /", user.Email) http.Redirect(w, r, "/", http.StatusFound) } // LogoutHandler handles user logout func (c *Config) LogoutHandler(w http.ResponseWriter, r *http.Request) { // Generate logout state for verification state := generateRandomString(32) // Get session session, err := c.Store.Get(r, c.SessionName) if err != nil { http.Error(w, "Session error", http.StatusInternalServerError) return } // Store logout state for verification session.Values["logout_state"] = state if err := session.Save(r, w); err != nil { log.Printf("Error saving logout state: %v", err) http.Error(w, "Server error", http.StatusInternalServerError) return } // Build logout URL keycloakLogoutURL, err := url.Parse(viper.GetString("oidc-idp-issuer-url") + "/protocol/openid-connect/logout") if err != nil { log.Printf("Error parsing logout URL: %v", err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } // Add query parameters q := keycloakLogoutURL.Query() // Use logout-callback for completing the logout flow q.Set("post_logout_redirect_uri", viper.GetString("base-url")+"/logout-callback") q.Set("client_id", viper.GetString("oidc-sp-client-id")) q.Set("state", state) // Add id_token_hint if available if idToken, ok := session.Values["id_token"].(string); ok { q.Set("id_token_hint", idToken) } keycloakLogoutURL.RawQuery = q.Encode() // Log for debugging log.Printf("Redirecting to logout URL: %s", keycloakLogoutURL.String()) // Redirect to Keycloak http.Redirect(w, r, keycloakLogoutURL.String(), http.StatusFound) } // LogoutCallbackHandler handles the redirect after Keycloak logout func (c *Config) LogoutCallbackHandler(w http.ResponseWriter, r *http.Request) { // Get session session, err := c.Store.Get(r, c.SessionName) if err != nil { http.Error(w, "Session error", http.StatusInternalServerError) return } // Verify state parameter returnedState := r.URL.Query().Get("state") log.Printf("Received logout state: %s", returnedState) savedState, ok := session.Values["logout_state"].(string) if !ok { log.Printf("No logout state found in session") http.Error(w, "Invalid session state", http.StatusBadRequest) return } log.Printf("Saved logout state: %s", savedState) if returnedState != savedState { log.Printf("State mismatch: %s != %s", returnedState, savedState) http.Error(w, "Invalid state", http.StatusBadRequest) return } // Clear the session completely session.Options.MaxAge = -1 if err := session.Save(r, w); err != nil { log.Printf("Error clearing session: %v", err) http.Error(w, "Server error", http.StatusInternalServerError) return } // Redirect to login page http.Redirect(w, r, "/login", http.StatusFound) } // RegistrationHandler redirects to the OIDC registration page func (c *Config) RegistrationHandler(w http.ResponseWriter, r *http.Request) { state := generateRandomString(32) nonce := generateRandomString(32) codeVerifier := generateRandomString(32) hashVal := sha256.Sum256([]byte(codeVerifier)) // Renamed from hash to hashVal codeChallenge := base64.RawURLEncoding.EncodeToString(hashVal[:]) session, _ := c.Store.Get(r, c.SessionName) session.Values["state"] = state session.Values["nonce"] = nonce session.Values["code_verifier"] = codeVerifier if err := session.Save(r, w); err != nil { log.Printf("RegistrationHandler: Error saving session: %v", err) http.Error(w, "Server error", http.StatusInternalServerError) return } registrationEndpoint := viper.GetString("oidc-idp-issuer-url") + "/protocol/openid-connect/registrations" // For some providers, you might need to use the authorization endpoint with a specific parameter like `kc_action=register` // e.g., registrationEndpoint := c.OAuthConfig.Endpoint.AuthURL // And then add q.Set("kc_action", "register") or similar. This depends on the OIDC provider. parsedRegistrationURL, err := url.Parse(registrationEndpoint) if err != nil { log.Printf("RegistrationHandler: Error parsing registration URL '%s': %v", registrationEndpoint, err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } q := parsedRegistrationURL.Query() q.Set("client_id", viper.GetString("oidc-sp-client-id")) q.Set("response_type", "code") q.Set("scope", "openid email profile") q.Set("redirect_uri", viper.GetString("base-url")+"/callback") q.Set("state", state) q.Set("nonce", nonce) q.Set("code_challenge", codeChallenge) q.Set("code_challenge_method", "S256") parsedRegistrationURL.RawQuery = q.Encode() log.Printf("RegistrationHandler: Redirecting to registration URL: %s", parsedRegistrationURL.String()) http.Redirect(w, r, parsedRegistrationURL.String(), http.StatusFound) } // Helper function to generate random strings func generateRandomString(n int) string { b := make([]byte, n) if _, err := rand.Read(b); err != nil { log.Fatalf("Failed to generate random string: %v", err) panic(err) } return base64.RawURLEncoding.EncodeToString(b) }