// File: member-console/internal/auth/auth.go (623 lines, 21 KiB, Go)

package auth
import (
"context"
"crypto/rand"
"crypto/sha256"
"database/sql"
"encoding/base64"
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"strings"
"time"
"git.coopcloud.tech/wiki-cafe/member-console/internal/identity"
"git.coopcloud.tech/wiki-cafe/member-console/internal/organization"
"git.coopcloud.tech/wiki-cafe/member-console/internal/provisioning"
"github.com/alexedwards/scs/redisstore"
"github.com/alexedwards/scs/v2"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gomodule/redigo/redis"
"github.com/spf13/viper"
"golang.org/x/oauth2"
)
// Config holds all auth-related configuration.
//
// A Config is built by Setup and carries every dependency the HTTP handlers
// and session accessors in this file need.
type Config struct {
// SessionManager is the scs session manager used to read and write all
// session keys declared in this file.
SessionManager *scs.SessionManager
// OAuthConfig is the OAuth2 client configuration used to build the
// authorization URL and to exchange the authorization code.
OAuthConfig *oauth2.Config
// Verifier verifies ID tokens returned by the token exchange.
Verifier *oidc.IDTokenVerifier
// Provider is the OIDC provider created in Setup from the configured issuer URL.
Provider *oidc.Provider
// Database is handed to provisioning.AutoProvision when a first-time user logs in.
Database *sql.DB // Raw DB for transactions (auto-provisioning)
// IdentityQ is used by CallbackHandler to look up and update users and persons.
IdentityQ identity.Querier // Identity module queries
// OrgQ is used by CallbackHandler to load organizations and workspaces.
OrgQ organization.Querier // Organization module queries
}
// Setup builds the auth Config.
//
// It creates a Redis-protocol connection pool for the address configured
// under "valkey-addr", verifies connectivity with a PING, configures the
// session manager (one-week lifetime, cookie named "session"), discovers the
// OIDC provider from "oidc-idp-issuer-url" and assembles the OAuth2 client
// configuration.
//
// It returns an error if the PING fails or the OIDC provider cannot be
// initialized.
func Setup(database *sql.DB, identityQ identity.Querier, orgQ organization.Querier) (*Config, error) {
	// Connection pool backing the session store.
	valkeyPool := &redis.Pool{
		MaxIdle:     10,
		IdleTimeout: 240 * time.Second,
		Dial: func() (redis.Conn, error) {
			return redis.Dial("tcp", viper.GetString("valkey-addr"))
		},
		TestOnBorrow: func(borrowed redis.Conn, lastUsed time.Time) error {
			// Only connections idle for a minute or more are health-checked.
			if time.Since(lastUsed) >= time.Minute {
				_, pingErr := borrowed.Do("PING")
				return pingErr
			}
			return nil
		},
	}

	// Fail fast if the store is unreachable.
	probe := valkeyPool.Get()
	defer probe.Close()
	if _, pingErr := probe.Do("PING"); pingErr != nil {
		return nil, fmt.Errorf("failed to connect to Valkey: %w", pingErr)
	}
	log.Printf("Connected to Valkey at %s", viper.GetString("valkey-addr"))

	// Session manager persisted in the pool above.
	sessions := scs.New()
	sessions.Store = redisstore.New(valkeyPool)
	sessions.Lifetime = 7 * 24 * time.Hour // 1 week

	cookie := &sessions.Cookie
	cookie.Name = "session"
	cookie.Path = "/"
	cookie.HttpOnly = true
	// The Secure flag is only set when the "env" setting is "production".
	cookie.Secure = viper.GetString("env") == "production"
	cookie.SameSite = http.SameSiteLaxMode

	// OIDC discovery against the configured issuer.
	provider, err := oidc.NewProvider(context.Background(), viper.GetString("oidc-idp-issuer-url"))
	if err != nil {
		return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err)
	}

	// OAuth2 client settings; the callback lives at <base-url>/callback.
	clientConfig := &oauth2.Config{
		ClientID:     viper.GetString("oidc-sp-client-id"),
		ClientSecret: viper.GetString("oidc-sp-client-secret"),
		RedirectURL:  viper.GetString("base-url") + "/callback",
		Endpoint:     provider.Endpoint(),
		Scopes:       []string{oidc.ScopeOpenID, "profile", "email"},
	}

	return &Config{
		SessionManager: sessions,
		OAuthConfig:    clientConfig,
		Provider:       provider,
		Verifier:       provider.Verifier(&oidc.Config{ClientID: clientConfig.ClientID}),
		Database:       database,
		IdentityQ:      identityQ,
		OrgQ:           orgQ,
	}, nil
}
// RegisterHandlers mounts the login, callback, logout, logout-callback and
// registration endpoints on mux.
func (c *Config) RegisterHandlers(mux *http.ServeMux) {
	routes := []struct {
		pattern string
		handler http.HandlerFunc
	}{
		{"/login", c.LoginHandler},
		{"/callback", c.CallbackHandler},
		{"/logout", c.LogoutHandler},
		{"/logout-callback", c.LogoutCallbackHandler},
		{"/register", c.RegistrationHandler},
	}
	for _, route := range routes {
		mux.HandleFunc(route.pattern, route.handler)
	}
}
// Middleware returns an HTTP middleware that enforces authentication.
//
// Requests for the public paths listed below and for anything under
// "/static/" are passed straight to next. Every other request is passed on
// only when the session's authenticated flag is set; otherwise the client is
// redirected to /login with 302 Found.
func (c *Config) Middleware() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			path := r.URL.Path

			// Public routes. A switch over constant strings avoids building a
			// lookup map on every request.
			switch path {
			case "/login",
				"/callback",
				"/logout",
				"/logout-callback",
				"/register",
				"/favicon.ico",     // Browser-fired in parallel with page loads; must not bounce through /login
				"/webhooks/stripe": // Webhooks use signature verification, not session auth
				next.ServeHTTP(w, r)
				return
			}

			// Static assets are public, no session needed.
			if strings.HasPrefix(path, "/static/") {
				next.ServeHTTP(w, r)
				return
			}

			if !c.SessionManager.GetBool(r.Context(), sessionKeyAuthenticated) {
				http.Redirect(w, r, "/login", http.StatusFound)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}
// LoginHandler starts the OIDC authorization-code flow with PKCE.
//
// It generates a fresh state, nonce and code verifier, stores all three in
// the session for CallbackHandler to check, and redirects the browser to the
// provider's authorization URL carrying the S256 code challenge.
func (c *Config) LoginHandler(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	state := generateRandomString(32)
	nonce := generateRandomString(32)
	verifier := generateRandomString(32)

	// PKCE S256: challenge = base64url(sha256(verifier)), unpadded.
	digest := sha256.Sum256([]byte(verifier))
	challenge := base64.RawURLEncoding.EncodeToString(digest[:])

	// Persist the flow data so the callback can validate the response.
	c.SessionManager.Put(ctx, sessionKeyState, state)
	c.SessionManager.Put(ctx, sessionKeyNonce, nonce)
	c.SessionManager.Put(ctx, sessionKeyCodeVerifier, verifier)

	authOptions := []oauth2.AuthCodeOption{
		oidc.Nonce(nonce),
		oauth2.SetAuthURLParam("code_challenge", challenge),
		oauth2.SetAuthURLParam("code_challenge_method", "S256"),
	}
	destination := c.OAuthConfig.AuthCodeURL(state, authOptions...)
	http.Redirect(w, r, destination, http.StatusFound)
}
// CallbackHandler completes the OIDC authorization-code flow started by
// LoginHandler or RegistrationHandler.
//
// Steps, in order:
//  1. Check the returned state against the one saved in the session.
//  2. Exchange the authorization code using the saved PKCE code verifier.
//  3. Verify the ID token and its nonce.
//  4. Look the user up by OIDC subject. Unknown subjects are auto-provisioned;
//     known users get their login timestamp/IP recorded, their person record
//     refreshed from the claims, and their first organization and first
//     workspace loaded.
//  5. Read roles from the access token payload (best effort).
//  6. Renew the session token, store the authenticated session data and
//     redirect to "/".
//
// Any validation failure answers with 400, any backend failure with 500; in
// both cases nothing is written to the authenticated session keys.
func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	query := r.URL.Query()

	// Validate state. An empty saved state means no login flow was started in
	// this session; reject it explicitly so a request without a state
	// parameter cannot satisfy the comparison with "" == "".
	savedState := c.SessionManager.GetString(ctx, sessionKeyState)
	if savedState == "" || query.Get("state") != savedState {
		// The values themselves are not logged: one is a session secret, the
		// other is attacker-controlled input.
		log.Printf("CallbackHandler: State mismatch or no login flow in progress")
		http.Error(w, "Invalid state", http.StatusBadRequest)
		return
	}

	// Exchange code
	codeVerifier := c.SessionManager.GetString(ctx, sessionKeyCodeVerifier)
	if codeVerifier == "" {
		log.Printf("CallbackHandler: No code verifier found in session")
		http.Error(w, "No code verifier found", http.StatusBadRequest)
		return
	}
	token, err := c.OAuthConfig.Exchange(
		ctx,
		query.Get("code"),
		oauth2.VerifierOption(codeVerifier),
	)
	if err != nil {
		log.Printf("CallbackHandler: Token exchange failed: %v", err)
		http.Error(w, "Token exchange failed", http.StatusInternalServerError)
		return
	}

	// Verify ID token
	rawIDToken, ok := token.Extra("id_token").(string)
	if !ok {
		log.Printf("CallbackHandler: Missing ID token from token extra")
		http.Error(w, "Missing ID token", http.StatusInternalServerError)
		return
	}
	idToken, err := c.Verifier.Verify(ctx, rawIDToken)
	if err != nil {
		log.Printf("CallbackHandler: ID token verification failed: %v", err)
		http.Error(w, "Invalid ID token", http.StatusInternalServerError)
		return
	}

	// Verify nonce
	var nonceClaims struct {
		Nonce string `json:"nonce"`
	}
	if err := idToken.Claims(&nonceClaims); err != nil {
		log.Printf("CallbackHandler: Could not parse nonce claims: %v", err)
		http.Error(w, "Could not parse claims for nonce", http.StatusInternalServerError)
		return
	}
	// As with state, an empty saved nonce must never match an absent claim.
	savedNonce := c.SessionManager.GetString(ctx, sessionKeyNonce)
	if savedNonce == "" || nonceClaims.Nonce != savedNonce {
		log.Printf("CallbackHandler: Invalid nonce")
		http.Error(w, "Invalid nonce", http.StatusBadRequest)
		return
	}

	// ---- Database Interaction & User Info Extraction ----
	var userInfoClaims struct {
		Subject           string `json:"sub"`
		Email             string `json:"email"`
		EmailVerified     bool   `json:"email_verified"`
		Name              string `json:"name"`
		PreferredUsername string `json:"preferred_username"`
		GivenName         string `json:"given_name"`
		FamilyName        string `json:"family_name"`
	}
	if err := idToken.Claims(&userInfoClaims); err != nil {
		log.Printf("CallbackHandler: Error parsing user info claims: %v", err)
		http.Error(w, "Could not parse user info claims", http.StatusInternalServerError)
		return
	}

	// Session variables to populate
	var personID, orgID, workspaceID string

	user, err := c.IdentityQ.GetUserByOIDCSubject(ctx, userInfoClaims.Subject)
	switch {
	case err == nil:
		// Returning user — record the login and load existing governance records.
		// A failure to record the login is logged but does not block sign-in.
		_, errLogin := c.IdentityQ.UpdateUserLogin(ctx, identity.UpdateUserLoginParams{
			LastLoginAt: sql.NullTime{Time: time.Now(), Valid: true},
			LastLoginIp: sql.NullString{String: r.RemoteAddr, Valid: true},
			UserID:      user.UserID,
		})
		if errLogin != nil {
			log.Printf("CallbackHandler: Error updating user login: %v", errLogin)
		}

		// Load person
		person, errPerson := c.IdentityQ.GetPersonByUserID(ctx, user.UserID)
		if errPerson != nil {
			log.Printf("CallbackHandler: Error loading person for user %s: %v", user.UserID, errPerson)
			http.Error(w, "Failed to load user profile", http.StatusInternalServerError)
			return
		}
		personID = person.PersonID

		// Update person if OIDC claims changed; best effort, failure is only logged.
		if person.DisplayName != userInfoClaims.Name || person.PrimaryEmail != userInfoClaims.Email {
			_, errUpdate := c.IdentityQ.UpdatePerson(ctx, identity.UpdatePersonParams{
				DisplayName:          userInfoClaims.Name,
				PrimaryEmail:         userInfoClaims.Email,
				PrimaryEmailVerified: userInfoClaims.EmailVerified,
				PersonID:             person.PersonID,
			})
			if errUpdate != nil {
				log.Printf("CallbackHandler: Error updating person %s: %v", person.PersonID, errUpdate)
			}
		}

		// Load personal org (first org owned by this person)
		orgs, errOrgs := c.OrgQ.GetOrganizationsByOwner(ctx, personID)
		if errOrgs != nil {
			log.Printf("CallbackHandler: Error loading org for person %s: %v", personID, errOrgs)
			http.Error(w, "Failed to load organization", http.StatusInternalServerError)
			return
		}
		if len(orgs) == 0 {
			log.Printf("CallbackHandler: No organization found for person %s", personID)
			http.Error(w, "Failed to load organization", http.StatusInternalServerError)
			return
		}
		orgID = orgs[0].OrgID

		// Load default workspace (first workspace of that org)
		workspaces, errWS := c.OrgQ.GetWorkspacesByOrgID(ctx, orgID)
		if errWS != nil {
			log.Printf("CallbackHandler: Error loading workspace for org %s: %v", orgID, errWS)
			http.Error(w, "Failed to load workspace", http.StatusInternalServerError)
			return
		}
		if len(workspaces) == 0 {
			log.Printf("CallbackHandler: No workspace found for org %s", orgID)
			http.Error(w, "Failed to load workspace", http.StatusInternalServerError)
			return
		}
		workspaceID = workspaces[0].WorkspaceID
		log.Printf("CallbackHandler: Returning user: %s (person=%s, org=%s)", userInfoClaims.Email, personID, orgID)

	case err == sql.ErrNoRows:
		// New user — auto-provision all governance structures
		result, errProv := provisioning.AutoProvision(ctx, c.Database, provisioning.OIDCClaims{
			Subject:           userInfoClaims.Subject,
			Email:             userInfoClaims.Email,
			EmailVerified:     userInfoClaims.EmailVerified,
			Name:              userInfoClaims.Name,
			PreferredUsername: userInfoClaims.PreferredUsername,
		})
		if errProv != nil {
			log.Printf("CallbackHandler: Error auto-provisioning user: %v", errProv)
			http.Error(w, "Failed to create user record", http.StatusInternalServerError)
			return
		}
		user = result.User
		personID = result.Person.PersonID
		orgID = result.Org.OrgID
		workspaceID = result.Workspace.WorkspaceID
		log.Printf("CallbackHandler: New user provisioned: %s (person=%s, org=%s)", userInfoClaims.Email, personID, orgID)

	default:
		log.Printf("CallbackHandler: Error fetching user by OIDC Subject ID '%s': %v", userInfoClaims.Subject, err)
		http.Error(w, "Database error while fetching user", http.StatusInternalServerError)
		return
	}

	// ---- Extract roles from access token (IDP-agnostic) ----
	// The access token payload is decoded without signature verification. If
	// it is not a three-part JWT or cannot be decoded, the user simply ends
	// up with no roles.
	var rolesClaims struct {
		Roles       []string `json:"roles"`
		Groups      []string `json:"groups"`
		RealmAccess struct {
			Roles []string `json:"roles"`
		} `json:"realm_access"`
	}
	if accessToken := token.AccessToken; accessToken != "" {
		if parts := splitJWT(accessToken); len(parts) == 3 {
			if payload, err := base64.RawURLEncoding.DecodeString(parts[1]); err == nil {
				if err := json.Unmarshal(payload, &rolesClaims); err != nil {
					log.Printf("CallbackHandler: Could not parse roles from access token: %v", err)
				}
			}
		}
	}
	roles := extractRoles(rolesClaims)
	log.Printf("CallbackHandler: User %s has IdP roles: %v", userInfoClaims.Email, roles)

	// Clear OIDC flow data
	c.SessionManager.Remove(ctx, sessionKeyState)
	c.SessionManager.Remove(ctx, sessionKeyNonce)
	c.SessionManager.Remove(ctx, sessionKeyCodeVerifier)

	// Issue a new session token before elevating the session to authenticated,
	// so a token that was known before login cannot be reused afterwards
	// (session fixation).
	if err := c.SessionManager.RenewToken(ctx); err != nil {
		log.Printf("CallbackHandler: Error renewing session token: %v", err)
		http.Error(w, "Server error", http.StatusInternalServerError)
		return
	}

	// Set authenticated session data
	c.SessionManager.Put(ctx, sessionKeyAuthenticated, true)
	c.SessionManager.Put(ctx, sessionKeyIDToken, rawIDToken)
	c.SessionManager.Put(ctx, sessionKeyPersonID, personID)
	c.SessionManager.Put(ctx, sessionKeyOrgID, orgID)
	c.SessionManager.Put(ctx, sessionKeyWorkspaceID, workspaceID)
	c.SessionManager.Put(ctx, sessionKeyOIDCSubject, user.OidcSubject)
	c.SessionManager.Put(ctx, sessionKeyEmail, userInfoClaims.Email)
	c.SessionManager.Put(ctx, sessionKeyName, userInfoClaims.Name)
	c.SessionManager.Put(ctx, sessionKeyUsername, userInfoClaims.PreferredUsername)
	c.SessionManager.Put(ctx, sessionKeyRoles, roles)

	log.Printf("CallbackHandler: User %s authenticated successfully. Redirecting to /", userInfoClaims.Email)
	http.Redirect(w, r, "/", http.StatusFound)
}
// LogoutHandler starts logout at the identity provider.
//
// It stores a fresh logout state in the session (checked later by
// LogoutCallbackHandler) and redirects the browser to the provider's logout
// endpoint, built as <oidc-idp-issuer-url>/protocol/openid-connect/logout.
// The redirect carries the post-logout redirect URI, the client ID, the
// state and, when the session holds one, the ID token as id_token_hint.
//
// The local session is not destroyed here; that happens in
// LogoutCallbackHandler.
func (c *Config) LogoutHandler(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	// Generate logout state for verification
	state := generateRandomString(32)
	c.SessionManager.Put(ctx, sessionKeyLogoutState, state)

	// Build logout URL
	keycloakLogoutURL, err := url.Parse(viper.GetString("oidc-idp-issuer-url") + "/protocol/openid-connect/logout")
	if err != nil {
		log.Printf("Error parsing logout URL: %v", err)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	// Add query parameters
	q := keycloakLogoutURL.Query()
	q.Set("post_logout_redirect_uri", viper.GetString("base-url")+"/logout-callback")
	q.Set("client_id", viper.GetString("oidc-sp-client-id"))
	q.Set("state", state)
	// Add id_token_hint if available
	if idToken := c.SessionManager.GetString(ctx, sessionKeyIDToken); idToken != "" {
		q.Set("id_token_hint", idToken)
	}
	keycloakLogoutURL.RawQuery = q.Encode()

	// Log the endpoint only. The query string carries the raw ID token and
	// the logout state, neither of which belongs in log files.
	loggedURL := *keycloakLogoutURL
	loggedURL.RawQuery = ""
	loggedURL.Fragment = ""
	log.Printf("Redirecting to logout URL: %s", loggedURL.String())

	http.Redirect(w, r, keycloakLogoutURL.String(), http.StatusFound)
}
// LogoutCallbackHandler finishes logout after the identity provider redirects
// back.
//
// The returned state must match the logout state saved by LogoutHandler;
// a missing or mismatching state answers with 400. On success the session is
// destroyed and the browser is sent to /login.
func (c *Config) LogoutCallbackHandler(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	got := r.URL.Query().Get("state")
	expected := c.SessionManager.GetString(ctx, sessionKeyLogoutState)

	switch {
	case expected == "":
		// No logout was started from this session.
		log.Printf("No logout state found in session")
		http.Error(w, "Invalid session state", http.StatusBadRequest)
		return
	case got != expected:
		log.Printf("State mismatch: %s != %s", got, expected)
		http.Error(w, "Invalid state", http.StatusBadRequest)
		return
	}

	// State verified: drop the whole session.
	destroyErr := c.SessionManager.Destroy(ctx)
	if destroyErr != nil {
		log.Printf("Error destroying session: %v", destroyErr)
		http.Error(w, "Server error", http.StatusInternalServerError)
		return
	}

	http.Redirect(w, r, "/login", http.StatusFound)
}
// RegistrationHandler sends the browser to the identity provider's
// registration page.
//
// Like LoginHandler it generates state, nonce and a PKCE code verifier and
// stores them in the session, so the subsequent redirect to /callback is
// validated the same way. The target is built by hand as
// <oidc-idp-issuer-url>/protocol/openid-connect/registrations.
func (c *Config) RegistrationHandler(w http.ResponseWriter, r *http.Request) {
	sessionCtx := r.Context()

	state := generateRandomString(32)
	nonce := generateRandomString(32)
	verifier := generateRandomString(32)

	// PKCE S256 challenge derived from the verifier.
	digest := sha256.Sum256([]byte(verifier))
	challenge := base64.RawURLEncoding.EncodeToString(digest[:])

	c.SessionManager.Put(sessionCtx, sessionKeyState, state)
	c.SessionManager.Put(sessionCtx, sessionKeyNonce, nonce)
	c.SessionManager.Put(sessionCtx, sessionKeyCodeVerifier, verifier)

	endpoint := viper.GetString("oidc-idp-issuer-url") + "/protocol/openid-connect/registrations"
	target, parseErr := url.Parse(endpoint)
	if parseErr != nil {
		log.Printf("RegistrationHandler: Error parsing registration URL '%s': %v", endpoint, parseErr)
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	// Authorization request parameters, layered over any query already
	// present on the endpoint URL.
	authParams := [][2]string{
		{"client_id", viper.GetString("oidc-sp-client-id")},
		{"response_type", "code"},
		{"scope", "openid email profile"},
		{"redirect_uri", viper.GetString("base-url") + "/callback"},
		{"state", state},
		{"nonce", nonce},
		{"code_challenge", challenge},
		{"code_challenge_method", "S256"},
	}
	params := target.Query()
	for _, pair := range authParams {
		params.Set(pair[0], pair[1])
	}
	target.RawQuery = params.Encode()

	destination := target.String()
	log.Printf("RegistrationHandler: Redirecting to registration URL: %s", destination)
	http.Redirect(w, r, destination, http.StatusFound)
}
// Session key constants - single source of truth for session data.
//
// The first group is written by CallbackHandler once login succeeds and read
// back through the accessor methods on Config. The second group holds
// short-lived values for an in-progress login or logout round trip.
const (
// Boolean flag checked by Middleware, IsAuthenticated and GetUserSession.
sessionKeyAuthenticated = "authenticated"
// Raw ID token; LogoutHandler sends it as id_token_hint.
sessionKeyIDToken = "id_token"
sessionKeyPersonID = "person_id"
sessionKeyOrgID = "org_id"
sessionKeyWorkspaceID = "workspace_id"
sessionKeyOIDCSubject = "oidc_subject"
sessionKeyEmail = "email"
sessionKeyName = "name"
sessionKeyUsername = "username"
// Stored as a []string; read by getRoles.
sessionKeyRoles = "roles"
// OIDC flow keys (temporary)
// State, nonce and code verifier are written by LoginHandler and
// RegistrationHandler and removed by CallbackHandler on success.
sessionKeyState = "state"
sessionKeyNonce = "nonce"
sessionKeyCodeVerifier = "code_verifier"
// Written by LogoutHandler and checked by LogoutCallbackHandler.
sessionKeyLogoutState = "logout_state"
)
// UserSession contains the authenticated user's session data.
// This provides type-safe access to session values.
//
// Values are filled in by GetUserSession from the session keys that
// CallbackHandler stores after a successful login.
type UserSession struct {
// PersonID, OrgID and WorkspaceID identify the person, organization and
// workspace selected at login.
PersonID string
OrgID string
WorkspaceID string
// OIDCSubject is the OIDC subject stored on the user record.
OIDCSubject string
// Email, Name and Username come from the ID token's email, name and
// preferred_username claims.
Email string
Name string
Username string
// Roles is the role list extracted from the access token at login; it is
// nil when the session holds no []string under the roles key.
Roles []string
}
// GetUserSession assembles a UserSession snapshot from the current session.
// It returns nil when the session's authenticated flag is not set.
func (c *Config) GetUserSession(ctx context.Context) *UserSession {
	sessions := c.SessionManager
	if authenticated := sessions.GetBool(ctx, sessionKeyAuthenticated); !authenticated {
		return nil
	}

	// Shorthand for reading a string-valued session key.
	str := func(key string) string {
		return sessions.GetString(ctx, key)
	}

	snapshot := new(UserSession)
	snapshot.PersonID = str(sessionKeyPersonID)
	snapshot.OrgID = str(sessionKeyOrgID)
	snapshot.WorkspaceID = str(sessionKeyWorkspaceID)
	snapshot.OIDCSubject = str(sessionKeyOIDCSubject)
	snapshot.Email = str(sessionKeyEmail)
	snapshot.Name = str(sessionKeyName)
	snapshot.Username = str(sessionKeyUsername)
	snapshot.Roles = c.getRoles(ctx)
	return snapshot
}
// GetPersonID reads the person ID stored in the session at login.
// The result is whatever the session manager's GetString yields for that key.
func (c *Config) GetPersonID(ctx context.Context) string {
	personID := c.SessionManager.GetString(ctx, sessionKeyPersonID)
	return personID
}
// GetOrgID reads the organization ID stored in the session at login.
func (c *Config) GetOrgID(ctx context.Context) string {
	orgID := c.SessionManager.GetString(ctx, sessionKeyOrgID)
	return orgID
}
// GetWorkspaceID reads the workspace ID stored in the session at login.
func (c *Config) GetWorkspaceID(ctx context.Context) string {
	workspaceID := c.SessionManager.GetString(ctx, sessionKeyWorkspaceID)
	return workspaceID
}
// GetUserEmail reads the email address stored in the session at login.
func (c *Config) GetUserEmail(ctx context.Context) string {
	email := c.SessionManager.GetString(ctx, sessionKeyEmail)
	return email
}
// GetUserName reads the display name stored in the session at login.
func (c *Config) GetUserName(ctx context.Context) string {
	displayName := c.SessionManager.GetString(ctx, sessionKeyName)
	return displayName
}
// GetUsername reads the username (the preferred_username claim) stored in
// the session at login.
func (c *Config) GetUsername(ctx context.Context) string {
	username := c.SessionManager.GetString(ctx, sessionKeyUsername)
	return username
}
// IsAuthenticated reports whether the session's authenticated flag is set.
func (c *Config) IsAuthenticated(ctx context.Context) bool {
	authenticated := c.SessionManager.GetBool(ctx, sessionKeyAuthenticated)
	return authenticated
}
// HasRole reports whether role appears, by exact string match, in the role
// list stored in the session belonging to request r.
func (c *Config) HasRole(r *http.Request, role string) bool {
	assigned := c.getRoles(r.Context())
	for i := range assigned {
		if assigned[i] == role {
			return true
		}
	}
	return false
}
// getRoles returns the role list stored in the session, or nil when the
// roles key does not hold a []string.
func (c *Config) getRoles(ctx context.Context) []string {
	if stored, isStringSlice := c.SessionManager.Get(ctx, sessionKeyRoles).([]string); isStringSlice {
		return stored
	}
	return nil
}
// generateRandomString returns n cryptographically random bytes encoded as
// unpadded URL-safe base64, so the result is longer than n characters
// (43 characters for n = 32).
//
// If the system random source fails, the process is terminated via
// log.Fatalf: continuing with predictable state, nonce or PKCE values would
// be unsafe.
func generateRandomString(n int) string {
	b := make([]byte, n)
	if _, err := rand.Read(b); err != nil {
		// log.Fatalf exits the process, so nothing after it can run.
		log.Fatalf("Failed to generate random string: %v", err)
	}
	return base64.RawURLEncoding.EncodeToString(b)
}
// splitJWT breaks token into its dot-separated segments. No validation or
// decoding is performed; callers must check the segment count themselves.
func splitJWT(token string) []string {
	const separator = "."
	segments := strings.Split(token, separator)
	return segments
}
// extractRoles picks the role list out of the decoded claims. Sources are
// tried in a fixed order and the first non-empty one wins: the top-level
// "roles" claim, then "groups", then "realm_access.roles". The lists are not
// merged. When all three are empty the (possibly nil) realm_access list is
// returned.
func extractRoles(claims struct {
	Roles       []string `json:"roles"`
	Groups      []string `json:"groups"`
	RealmAccess struct {
		Roles []string `json:"roles"`
	} `json:"realm_access"`
}) []string {
	switch {
	case len(claims.Roles) != 0:
		return claims.Roles
	case len(claims.Groups) != 0:
		return claims.Groups
	default:
		return claims.RealmAccess.Roles
	}
}