// Package auth wires OIDC login, logout and registration to a server-side
// session stored in Valkey, and exposes typed accessors for the session data.
package auth

import (
	"context"
	"crypto/rand"
	"crypto/sha256"
	"database/sql"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"git.coopcloud.tech/wiki-cafe/member-console/internal/identity"
	"git.coopcloud.tech/wiki-cafe/member-console/internal/organization"
	"git.coopcloud.tech/wiki-cafe/member-console/internal/provisioning"
	"github.com/alexedwards/scs/redisstore"
	"github.com/alexedwards/scs/v2"
	"github.com/coreos/go-oidc/v3/oidc"
	"github.com/gomodule/redigo/redis"
	"github.com/spf13/viper"
	"golang.org/x/oauth2"
)

// Config holds all auth-related configuration.
type Config struct {
	SessionManager *scs.SessionManager
	OAuthConfig    *oauth2.Config
	Verifier       *oidc.IDTokenVerifier
	Provider       *oidc.Provider
	Database       *sql.DB              // Raw DB for transactions (auto-provisioning)
	IdentityQ      identity.Querier     // Identity module queries
	OrgQ           organization.Querier // Organization module queries
}

// Setup initializes the auth configuration.
//
// It builds a connection pool for the Valkey address in the "valkey-addr"
// setting and pings it once, creates the session manager on top of that pool,
// discovers the OIDC provider at "oidc-idp-issuer-url", and assembles the
// OAuth2 client configuration. An error is returned if Valkey cannot be
// reached or the OIDC provider cannot be initialized.
func Setup(database *sql.DB, identityQ identity.Querier, orgQ organization.Querier) (*Config, error) {
	// Create Redis pool for Valkey
	pool := &redis.Pool{
		MaxIdle:     10,
		IdleTimeout: 240 * time.Second,
		Dial: func() (redis.Conn, error) {
			return redis.Dial("tcp", viper.GetString("valkey-addr"))
		},
		// Only re-validate connections that have been idle for a minute or more.
		TestOnBorrow: func(c redis.Conn, t time.Time) error {
			if time.Since(t) < time.Minute {
				return nil
			}
			_, err := c.Do("PING")
			return err
		},
	}

	// Test connection
	conn := pool.Get()
	defer conn.Close()
	if _, err := conn.Do("PING"); err != nil {
		return nil, fmt.Errorf("failed to connect to Valkey: %w", err)
	}
	log.Printf("Connected to Valkey at %s", viper.GetString("valkey-addr"))

	// Create session manager
	sessionManager := scs.New()
	sessionManager.Store = redisstore.New(pool)
	sessionManager.Lifetime = 7 * 24 * time.Hour // 1 week
	sessionManager.Cookie.Name = "session"
	sessionManager.Cookie.Path = "/"
	sessionManager.Cookie.HttpOnly = true
	sessionManager.Cookie.Secure = viper.GetString("env") == "production"
	sessionManager.Cookie.SameSite = http.SameSiteLaxMode

	// Initialize OIDC provider
	ctx := context.Background()
	provider, err := oidc.NewProvider(ctx, viper.GetString("oidc-idp-issuer-url"))
	if err != nil {
		return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err)
	}

	// Create OAuth2 config
	oauthConfig := &oauth2.Config{
		ClientID:     viper.GetString("oidc-sp-client-id"),
		ClientSecret: viper.GetString("oidc-sp-client-secret"),
		RedirectURL:  viper.GetString("base-url") + "/callback",
		Endpoint:     provider.Endpoint(),
		Scopes:       []string{oidc.ScopeOpenID, "profile", "email"},
	}

	// Create auth config
	config := &Config{
		SessionManager: sessionManager,
		OAuthConfig:    oauthConfig,
		Provider:       provider,
		Verifier:       provider.Verifier(&oidc.Config{ClientID: oauthConfig.ClientID}),
		Database:       database,
		IdentityQ:      identityQ,
		OrgQ:           orgQ,
	}
	return config, nil
}

// RegisterHandlers adds all auth-related handlers to the router.
func (c *Config) RegisterHandlers(mux *http.ServeMux) {
	mux.HandleFunc("/login", c.LoginHandler)
	mux.HandleFunc("/callback", c.CallbackHandler)
	mux.HandleFunc("/logout", c.LogoutHandler)
	mux.HandleFunc("/logout-callback", c.LogoutCallbackHandler)
	mux.HandleFunc("/register", c.RegistrationHandler)
}

// Middleware returns an auth middleware function.
//
// Requests for the public paths listed below and for anything under
// "/static/" are passed through untouched. Every other request must carry an
// authenticated session; otherwise it is redirected to "/login".
func (c *Config) Middleware() func(http.Handler) http.Handler {
	// Built once per middleware instance and only ever read afterwards, so it
	// is safe to share between concurrent requests.
	publicPaths := map[string]bool{
		"/login":           true,
		"/callback":        true,
		"/logout":          true,
		"/logout-callback": true,
		"/register":        true,
		"/favicon.ico":     true, // Browser-fired in parallel with page loads; must not bounce through /login
		"/webhooks/stripe": true, // Webhooks use signature verification, not session auth
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Skip authentication for public routes and static assets
			// (public, no session needed).
			if publicPaths[r.URL.Path] || strings.HasPrefix(r.URL.Path, "/static/") {
				next.ServeHTTP(w, r)
				return
			}

			// Check if authenticated
			if !c.SessionManager.GetBool(r.Context(), sessionKeyAuthenticated) {
				http.Redirect(w, r, "/login", http.StatusFound)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// startAuthFlow generates a fresh state, nonce and PKCE code verifier, stores
// all three in the session, and returns the state, the nonce and the S256 code
// challenge derived from the verifier.
func (c *Config) startAuthFlow(ctx context.Context) (state, nonce, codeChallenge string) {
	state = generateRandomString(32)
	nonce = generateRandomString(32)
	codeVerifier := generateRandomString(32)

	sum := sha256.Sum256([]byte(codeVerifier))
	codeChallenge = base64.RawURLEncoding.EncodeToString(sum[:])

	// Store OIDC flow data in session
	c.SessionManager.Put(ctx, sessionKeyState, state)
	c.SessionManager.Put(ctx, sessionKeyNonce, nonce)
	c.SessionManager.Put(ctx, sessionKeyCodeVerifier, codeVerifier)

	return state, nonce, codeChallenge
}

// LoginHandler initiates the OIDC authentication flow by redirecting the
// browser to the provider's authorization endpoint.
func (c *Config) LoginHandler(w http.ResponseWriter, r *http.Request) {
	state, nonce, codeChallenge := c.startAuthFlow(r.Context())

	authURL := c.OAuthConfig.AuthCodeURL(
		state,
		oidc.Nonce(nonce),
		oauth2.SetAuthURLParam("code_challenge", codeChallenge),
		oauth2.SetAuthURLParam("code_challenge_method", "S256"),
	)
	http.Redirect(w, r, authURL, http.StatusFound)
}

// CallbackHandler processes the OIDC callback.
//
// It validates the state, exchanges the authorization code using the stored
// PKCE verifier, verifies the ID token and its nonce, looks up (or
// auto-provisions) the local user, extracts roles from the access token, and
// finally populates the session and redirects to "/".
func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	// Validate state. An empty saved state means no login flow was started in
	// this session, so it must never be accepted, even when the request also
	// carries no state. The values themselves are deliberately not logged.
	state := r.URL.Query().Get("state")
	savedState := c.SessionManager.GetString(ctx, sessionKeyState)
	if savedState == "" || state != savedState {
		log.Printf("CallbackHandler: State mismatch")
		http.Error(w, "Invalid state", http.StatusBadRequest)
		return
	}

	// Exchange code
	code := r.URL.Query().Get("code")
	codeVerifier := c.SessionManager.GetString(ctx, sessionKeyCodeVerifier)
	if codeVerifier == "" {
		log.Printf("CallbackHandler: No code verifier found in session")
		http.Error(w, "No code verifier found", http.StatusBadRequest)
		return
	}
	token, err := c.OAuthConfig.Exchange(
		ctx,
		code,
		oauth2.VerifierOption(codeVerifier),
	)
	if err != nil {
		log.Printf("CallbackHandler: Token exchange failed: %v", err)
		http.Error(w, "Token exchange failed", http.StatusInternalServerError)
		return
	}

	// Verify ID token
	rawIDToken, ok := token.Extra("id_token").(string)
	if !ok {
		log.Printf("CallbackHandler: Missing ID token from token extra")
		http.Error(w, "Missing ID token", http.StatusInternalServerError)
		return
	}
	idToken, err := c.Verifier.Verify(ctx, rawIDToken)
	if err != nil {
		log.Printf("CallbackHandler: ID token verification failed: %v", err)
		http.Error(w, "Invalid ID token", http.StatusInternalServerError)
		return
	}

	// Verify nonce
	var nonceClaims struct {
		Nonce string `json:"nonce"`
	}
	if err := idToken.Claims(&nonceClaims); err != nil {
		log.Printf("CallbackHandler: Could not parse nonce claims: %v", err)
		http.Error(w, "Could not parse claims for nonce", http.StatusInternalServerError)
		return
	}
	// As with state, a missing saved nonce must not match a missing claim.
	savedNonce := c.SessionManager.GetString(ctx, sessionKeyNonce)
	if savedNonce == "" || nonceClaims.Nonce != savedNonce {
		log.Printf("CallbackHandler: Invalid nonce")
		http.Error(w, "Invalid nonce", http.StatusBadRequest)
		return
	}

	// ---- Database Interaction & User Info Extraction ----
	var userInfoClaims struct {
		Subject           string `json:"sub"`
		Email             string `json:"email"`
		EmailVerified     bool   `json:"email_verified"`
		Name              string `json:"name"`
		PreferredUsername string `json:"preferred_username"`
		GivenName         string `json:"given_name"`
		FamilyName        string `json:"family_name"`
	}
	if err := idToken.Claims(&userInfoClaims); err != nil {
		log.Printf("CallbackHandler: Error parsing user info claims: %v", err)
		http.Error(w, "Could not parse user info claims", http.StatusInternalServerError)
		return
	}

	// Session variables to populate
	var personID, orgID, workspaceID string

	user, err := c.IdentityQ.GetUserByOIDCSubject(ctx, userInfoClaims.Subject)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		// New user — auto-provision all governance structures
		result, errProv := provisioning.AutoProvision(ctx, c.Database, provisioning.OIDCClaims{
			Subject:           userInfoClaims.Subject,
			Email:             userInfoClaims.Email,
			EmailVerified:     userInfoClaims.EmailVerified,
			Name:              userInfoClaims.Name,
			PreferredUsername: userInfoClaims.PreferredUsername,
		})
		if errProv != nil {
			log.Printf("CallbackHandler: Error auto-provisioning user: %v", errProv)
			http.Error(w, "Failed to create user record", http.StatusInternalServerError)
			return
		}
		user = result.User
		personID = result.Person.PersonID
		orgID = result.Org.OrgID
		workspaceID = result.Workspace.WorkspaceID
		log.Printf("CallbackHandler: New user provisioned: %s (person=%s, org=%s)", userInfoClaims.Email, personID, orgID)

	case err != nil:
		log.Printf("CallbackHandler: Error fetching user by OIDC Subject ID '%s': %v", userInfoClaims.Subject, err)
		http.Error(w, "Database error while fetching user", http.StatusInternalServerError)
		return

	default:
		// Returning user — load existing governance records and update login.
		// net/http sets RemoteAddr to "IP:port"; store only the IP part, and
		// fall back to the raw value if it is not in that form.
		loginIP := r.RemoteAddr
		if host, _, errSplit := net.SplitHostPort(r.RemoteAddr); errSplit == nil {
			loginIP = host
		}
		_, errLogin := c.IdentityQ.UpdateUserLogin(ctx, identity.UpdateUserLoginParams{
			LastLoginAt: sql.NullTime{Time: time.Now(), Valid: true},
			LastLoginIp: sql.NullString{String: loginIP, Valid: true},
			UserID:      user.UserID,
		})
		if errLogin != nil {
			// Best effort: a failed bookkeeping update must not block login.
			log.Printf("CallbackHandler: Error updating user login: %v", errLogin)
		}

		// Load person
		person, errPerson := c.IdentityQ.GetPersonByUserID(ctx, user.UserID)
		if errPerson != nil {
			log.Printf("CallbackHandler: Error loading person for user %s: %v", user.UserID, errPerson)
			http.Error(w, "Failed to load user profile", http.StatusInternalServerError)
			return
		}
		personID = person.PersonID

		// Update person if OIDC claims changed
		if person.DisplayName != userInfoClaims.Name || person.PrimaryEmail != userInfoClaims.Email {
			_, errUpdate := c.IdentityQ.UpdatePerson(ctx, identity.UpdatePersonParams{
				DisplayName:          userInfoClaims.Name,
				PrimaryEmail:         userInfoClaims.Email,
				PrimaryEmailVerified: userInfoClaims.EmailVerified,
				PersonID:             person.PersonID,
			})
			if errUpdate != nil {
				// Best effort: stale profile data must not block login.
				log.Printf("CallbackHandler: Error updating person %s: %v", person.PersonID, errUpdate)
			}
		}

		// Load personal org (first org owned by this person)
		orgs, errOrgs := c.OrgQ.GetOrganizationsByOwner(ctx, personID)
		if errOrgs != nil {
			log.Printf("CallbackHandler: Error loading org for person %s: %v", personID, errOrgs)
			http.Error(w, "Failed to load organization", http.StatusInternalServerError)
			return
		}
		if len(orgs) == 0 {
			log.Printf("CallbackHandler: No organization found for person %s", personID)
			http.Error(w, "Failed to load organization", http.StatusInternalServerError)
			return
		}
		orgID = orgs[0].OrgID

		// Load default workspace
		workspaces, errWS := c.OrgQ.GetWorkspacesByOrgID(ctx, orgID)
		if errWS != nil {
			log.Printf("CallbackHandler: Error loading workspace for org %s: %v", orgID, errWS)
			http.Error(w, "Failed to load workspace", http.StatusInternalServerError)
			return
		}
		if len(workspaces) == 0 {
			log.Printf("CallbackHandler: No workspace found for org %s", orgID)
			http.Error(w, "Failed to load workspace", http.StatusInternalServerError)
			return
		}
		workspaceID = workspaces[0].WorkspaceID
		log.Printf("CallbackHandler: Returning user: %s (person=%s, org=%s)", userInfoClaims.Email, personID, orgID)
	}

	// ---- Extract roles from access token (IDP-agnostic) ----
	// The payload is decoded without verifying the access token's signature.
	// Tokens that are not a three-part JWT, or whose payload does not decode,
	// simply yield no roles.
	var rolesClaims struct {
		Roles       []string `json:"roles"`
		Groups      []string `json:"groups"`
		RealmAccess struct {
			Roles []string `json:"roles"`
		} `json:"realm_access"`
	}
	if accessToken := token.AccessToken; accessToken != "" {
		if parts := splitJWT(accessToken); len(parts) == 3 {
			if payload, err := base64.RawURLEncoding.DecodeString(parts[1]); err == nil {
				if err := json.Unmarshal(payload, &rolesClaims); err != nil {
					log.Printf("CallbackHandler: Could not parse roles from access token: %v", err)
				}
			}
		}
	}
	roles := extractRoles(rolesClaims)
	log.Printf("CallbackHandler: User %s has IdP roles: %v", userInfoClaims.Email, roles)

	// Issue a new session token before elevating privileges so that a token
	// known before login cannot be reused afterwards (session fixation).
	if err := c.SessionManager.RenewToken(ctx); err != nil {
		log.Printf("CallbackHandler: Error renewing session token: %v", err)
		http.Error(w, "Server error", http.StatusInternalServerError)
		return
	}

	// Clear OIDC flow data and set authenticated session data
	c.SessionManager.Remove(ctx, sessionKeyState)
	c.SessionManager.Remove(ctx, sessionKeyNonce)
	c.SessionManager.Remove(ctx, sessionKeyCodeVerifier)

	c.SessionManager.Put(ctx, sessionKeyAuthenticated, true)
	c.SessionManager.Put(ctx, sessionKeyIDToken, rawIDToken)
	c.SessionManager.Put(ctx, sessionKeyPersonID, personID)
	c.SessionManager.Put(ctx, sessionKeyOrgID, orgID)
	c.SessionManager.Put(ctx, sessionKeyWorkspaceID, workspaceID)
	c.SessionManager.Put(ctx, sessionKeyOIDCSubject, user.OidcSubject)
	c.SessionManager.Put(ctx, sessionKeyEmail, userInfoClaims.Email)
	c.SessionManager.Put(ctx, sessionKeyName, userInfoClaims.Name)
	c.SessionManager.Put(ctx, sessionKeyUsername, userInfoClaims.PreferredUsername)
	c.SessionManager.Put(ctx, sessionKeyRoles, roles)

	log.Printf("CallbackHandler: User %s authenticated successfully. Redirecting to /", userInfoClaims.Email)
	http.Redirect(w, r, "/", http.StatusFound)
}

// LogoutHandler handles user logout by redirecting the browser to the
// provider's logout endpoint. The local session is destroyed later, in
// LogoutCallbackHandler, once the provider redirects back.
func (c *Config) LogoutHandler(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	// Generate logout state for verification
	state := generateRandomString(32)
	c.SessionManager.Put(ctx, sessionKeyLogoutState, state)

	// Build logout URL
	keycloakLogoutURL, err := url.Parse(viper.GetString("oidc-idp-issuer-url") + "/protocol/openid-connect/logout")
	if err != nil {
		log.Printf("Error parsing logout URL: %v", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	// Log the endpoint before the query is attached: the query carries the
	// logout state and the raw ID token, neither of which belongs in logs.
	log.Printf("Redirecting to logout URL: %s", keycloakLogoutURL.String())

	// Add query parameters
	q := keycloakLogoutURL.Query()
	q.Set("post_logout_redirect_uri", viper.GetString("base-url")+"/logout-callback")
	q.Set("client_id", viper.GetString("oidc-sp-client-id"))
	q.Set("state", state)

	// Add id_token_hint if available
	if idToken := c.SessionManager.GetString(ctx, sessionKeyIDToken); idToken != "" {
		q.Set("id_token_hint", idToken)
	}
	keycloakLogoutURL.RawQuery = q.Encode()

	http.Redirect(w, r, keycloakLogoutURL.String(), http.StatusFound)
}

// LogoutCallbackHandler handles the redirect after Keycloak logout: it checks
// the returned state against the one saved by LogoutHandler, destroys the
// session, and redirects to "/login".
func (c *Config) LogoutCallbackHandler(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	// Verify state parameter
	returnedState := r.URL.Query().Get("state")
	savedState := c.SessionManager.GetString(ctx, sessionKeyLogoutState)
	if savedState == "" {
		log.Printf("No logout state found in session")
		http.Error(w, "Invalid session state", http.StatusBadRequest)
		return
	}
	if returnedState != savedState {
		// The state values are intentionally not logged.
		log.Printf("LogoutCallbackHandler: State mismatch")
		http.Error(w, "Invalid state", http.StatusBadRequest)
		return
	}

	// Destroy the session
	if err := c.SessionManager.Destroy(ctx); err != nil {
		log.Printf("Error destroying session: %v", err)
		http.Error(w, "Server error", http.StatusInternalServerError)
		return
	}

	// Redirect to login page
	http.Redirect(w, r, "/login", http.StatusFound)
}

// RegistrationHandler redirects to the OIDC registration page, starting the
// same state/nonce/PKCE flow as LoginHandler so that the provider's redirect
// back to "/callback" can be validated.
func (c *Config) RegistrationHandler(w http.ResponseWriter, r *http.Request) {
	state, nonce, codeChallenge := c.startAuthFlow(r.Context())

	registrationEndpoint := viper.GetString("oidc-idp-issuer-url") + "/protocol/openid-connect/registrations"
	parsedRegistrationURL, err := url.Parse(registrationEndpoint)
	if err != nil {
		log.Printf("RegistrationHandler: Error parsing registration URL '%s': %v", registrationEndpoint, err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	// Log the endpoint before the query is attached so that the state and
	// nonce do not end up in logs.
	log.Printf("RegistrationHandler: Redirecting to registration URL: %s", parsedRegistrationURL.String())

	q := parsedRegistrationURL.Query()
	q.Set("client_id", viper.GetString("oidc-sp-client-id"))
	q.Set("response_type", "code")
	q.Set("scope", "openid email profile")
	q.Set("redirect_uri", viper.GetString("base-url")+"/callback")
	q.Set("state", state)
	q.Set("nonce", nonce)
	q.Set("code_challenge", codeChallenge)
	q.Set("code_challenge_method", "S256")
	parsedRegistrationURL.RawQuery = q.Encode()

	http.Redirect(w, r, parsedRegistrationURL.String(), http.StatusFound)
}

// Session key constants - single source of truth for session data
const (
	sessionKeyAuthenticated = "authenticated"
	sessionKeyIDToken       = "id_token"
	sessionKeyPersonID      = "person_id"
	sessionKeyOrgID         = "org_id"
	sessionKeyWorkspaceID   = "workspace_id"
	sessionKeyOIDCSubject   = "oidc_subject"
	sessionKeyEmail         = "email"
	sessionKeyName          = "name"
	sessionKeyUsername      = "username"
	sessionKeyRoles         = "roles"

	// OIDC flow keys (temporary)
	sessionKeyState        = "state"
	sessionKeyNonce        = "nonce"
	sessionKeyCodeVerifier = "code_verifier"
	sessionKeyLogoutState  = "logout_state"
)

// UserSession contains the authenticated user's session data.
// This provides type-safe access to session values.
type UserSession struct {
	PersonID    string
	OrgID       string
	WorkspaceID string
	OIDCSubject string
	Email       string
	Name        string
	Username    string
	Roles       []string
}

// GetUserSession retrieves the authenticated user's session data.
// Returns nil if the user is not authenticated.
func (c *Config) GetUserSession(ctx context.Context) *UserSession {
	if !c.SessionManager.GetBool(ctx, sessionKeyAuthenticated) {
		return nil
	}
	return &UserSession{
		PersonID:    c.SessionManager.GetString(ctx, sessionKeyPersonID),
		OrgID:       c.SessionManager.GetString(ctx, sessionKeyOrgID),
		WorkspaceID: c.SessionManager.GetString(ctx, sessionKeyWorkspaceID),
		OIDCSubject: c.SessionManager.GetString(ctx, sessionKeyOIDCSubject),
		Email:       c.SessionManager.GetString(ctx, sessionKeyEmail),
		Name:        c.SessionManager.GetString(ctx, sessionKeyName),
		Username:    c.SessionManager.GetString(ctx, sessionKeyUsername),
		Roles:       c.getRoles(ctx),
	}
}

// GetPersonID returns the person UUID of the authenticated user.
// Returns empty string if the user is not authenticated.
func (c *Config) GetPersonID(ctx context.Context) string {
	return c.SessionManager.GetString(ctx, sessionKeyPersonID)
}

// GetOrgID returns the active organization UUID.
func (c *Config) GetOrgID(ctx context.Context) string {
	return c.SessionManager.GetString(ctx, sessionKeyOrgID)
}

// GetWorkspaceID returns the active workspace UUID.
func (c *Config) GetWorkspaceID(ctx context.Context) string {
	return c.SessionManager.GetString(ctx, sessionKeyWorkspaceID)
}

// GetUserEmail returns the email of the authenticated user.
func (c *Config) GetUserEmail(ctx context.Context) string {
	return c.SessionManager.GetString(ctx, sessionKeyEmail)
}

// GetUserName returns the display name of the authenticated user.
func (c *Config) GetUserName(ctx context.Context) string {
	return c.SessionManager.GetString(ctx, sessionKeyName)
}

// GetUsername returns the username of the authenticated user.
func (c *Config) GetUsername(ctx context.Context) string {
	return c.SessionManager.GetString(ctx, sessionKeyUsername)
}

// IsAuthenticated returns true if the user has an active session.
func (c *Config) IsAuthenticated(ctx context.Context) bool {
	return c.SessionManager.GetBool(ctx, sessionKeyAuthenticated)
}

// HasRole checks if the current user has the specified role.
// The comparison is an exact, case-sensitive string match against the roles
// stored in the session for the request's context.
func (c *Config) HasRole(r *http.Request, role string) bool {
	for _, have := range c.getRoles(r.Context()) {
		if have == role {
			return true
		}
	}
	return false
}

// getRoles retrieves the roles from session.
// It returns nil when no roles are stored or the stored value is not a
// []string.
func (c *Config) getRoles(ctx context.Context) []string {
	roles, ok := c.SessionManager.Get(ctx, sessionKeyRoles).([]string)
	if !ok {
		return nil
	}
	return roles
}

// generateRandomString returns n bytes read from crypto/rand, encoded as
// unpadded URL-safe base64. Note that n is the number of random bytes, not the
// length of the returned string.
//
// If the random source fails, the process is terminated via log.Fatalf.
func generateRandomString(n int) string {
	b := make([]byte, n)
	if _, err := rand.Read(b); err != nil {
		log.Fatalf("Failed to generate random string: %v", err)
	}
	return base64.RawURLEncoding.EncodeToString(b)
}

// splitJWT splits a JWT into its dot-separated parts without validation.
// Callers must check the number of parts themselves.
func splitJWT(token string) []string {
	return strings.Split(token, ".")
}

// extractRoles extracts roles from claims, checking multiple locations for IDP
// interoperability. The first non-empty source wins, in this order: the
// top-level "roles" claim, the "groups" claim, then "realm_access.roles".
// The result is nil when none of them is present.
func extractRoles(claims struct {
	Roles       []string `json:"roles"`
	Groups      []string `json:"groups"`
	RealmAccess struct {
		Roles []string `json:"roles"`
	} `json:"realm_access"`
}) []string {
	if len(claims.Roles) > 0 {
		return claims.Roles
	}
	if len(claims.Groups) > 0 {
		return claims.Groups
	}
	return claims.RealmAccess.Roles
}