Implement database integration with SQLite and refactor user handling for OIDC authentication

This commit is contained in:
2025-06-02 19:35:11 -05:00
parent 77b2e6c24e
commit a7d3822f94
11 changed files with 206 additions and 76 deletions

View File

@ -4,12 +4,14 @@ import (
"context"
"crypto/rand"
"crypto/sha256"
"database/sql"
"encoding/base64"
"fmt"
"log"
"net/http"
"net/url"
"git.coopcloud.tech/wiki-cafe/member-console/internal/db"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/sessions"
"github.com/spf13/viper"
@ -23,10 +25,11 @@ type Config struct {
Verifier *oidc.IDTokenVerifier
Provider *oidc.Provider
SessionName string
DB db.Querier
}
// Setup initializes the auth configuration
func Setup() (*Config, error) {
func Setup(database db.Querier) (*Config, error) {
// Create cookie store
store := sessions.NewCookieStore(
[]byte(viper.GetString("session-secret")),
@ -61,6 +64,7 @@ func Setup() (*Config, error) {
Provider: provider,
Verifier: provider.Verifier(&oidc.Config{ClientID: oauthConfig.ClientID}),
SessionName: "auth-session",
DB: database,
}
return config, nil
@ -80,9 +84,15 @@ func (c *Config) Middleware() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Skip authentication for public routes
if r.URL.Path == "/login" || r.URL.Path == "/callback" ||
r.URL.Path == "/logout" || r.URL.Path == "/logout-callback" ||
r.URL.Path == "/register" {
publicPaths := map[string]bool{
"/login": true,
"/callback": true,
"/logout": true,
"/logout-callback": true,
"/register": true,
}
if publicPaths[r.URL.Path] {
next.ServeHTTP(w, r)
return
}
@ -97,7 +107,6 @@ func (c *Config) Middleware() func(http.Handler) http.Handler {
http.Redirect(w, r, "/login", http.StatusFound)
return
}
next.ServeHTTP(w, r)
})
}
@ -108,8 +117,8 @@ func (c *Config) LoginHandler(w http.ResponseWriter, r *http.Request) {
state := generateRandomString(32)
nonce := generateRandomString(32)
codeVerifier := generateRandomString(32)
hash := sha256.Sum256([]byte(codeVerifier))
codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
hashVal := sha256.Sum256([]byte(codeVerifier)) // Renamed from hash to hashVal to avoid conflict
codeChallenge := base64.RawURLEncoding.EncodeToString(hashVal[:])
session, _ := c.Store.Get(r, c.SessionName)
session.Values["state"] = state
@ -132,9 +141,10 @@ func (c *Config) LoginHandler(w http.ResponseWriter, r *http.Request) {
// CallbackHandler processes the OIDC callback
func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) {
ctx := context.Background()
ctx := r.Context() // Use request's context
session, err := c.Store.Get(r, c.SessionName)
if err != nil {
log.Printf("CallbackHandler: Session error: %v", err)
http.Error(w, "Session error", http.StatusInternalServerError)
return
}
@ -142,7 +152,7 @@ func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) {
// Validate state
state := r.URL.Query().Get("state")
if savedState, ok := session.Values["state"].(string); !ok || state != savedState {
log.Printf("State mismatch: got %s, expected %s", state, savedState)
log.Printf("CallbackHandler: State mismatch: got %s, expected %s", state, savedState)
http.Error(w, "Invalid state", http.StatusBadRequest)
return
}
@ -151,6 +161,7 @@ func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
codeVerifier, ok := session.Values["code_verifier"].(string)
if !ok {
log.Printf("CallbackHandler: No code verifier found in session")
http.Error(w, "No code verifier found", http.StatusBadRequest)
return
}
@ -161,7 +172,7 @@ func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) {
oauth2.VerifierOption(codeVerifier),
)
if err != nil {
log.Printf("Token exchange failed: %v", err)
log.Printf("CallbackHandler: Token exchange failed: %v", err)
http.Error(w, "Token exchange failed", http.StatusInternalServerError)
return
}
@ -169,53 +180,104 @@ func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) {
// Verify ID token
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
log.Printf("CallbackHandler: Missing ID token from token extra")
http.Error(w, "Missing ID token", http.StatusInternalServerError)
return
}
idToken, err := c.Verifier.Verify(ctx, rawIDToken)
if err != nil {
log.Printf("ID token verification failed: %v", err)
log.Printf("CallbackHandler: ID token verification failed: %v", err)
http.Error(w, "Invalid ID token", http.StatusInternalServerError)
return
}
// Verify nonce
var claims struct {
var nonceClaims struct {
Nonce string `json:"nonce"`
}
if err := idToken.Claims(&claims); err != nil {
http.Error(w, "Could not parse claims", http.StatusInternalServerError)
if err := idToken.Claims(&nonceClaims); err != nil {
log.Printf("CallbackHandler: Could not parse nonce claims: %v", err)
http.Error(w, "Could not parse claims for nonce", http.StatusInternalServerError)
return
}
if claims.Nonce != session.Values["nonce"].(string) {
if nonceClaims.Nonce != session.Values["nonce"].(string) {
log.Printf("CallbackHandler: Invalid nonce")
http.Error(w, "Invalid nonce", http.StatusBadRequest)
return
}
// Authenticate session
session.Values["authenticated"] = true
session.Values["id_token"] = rawIDToken
// Store user info
var userInfo struct {
Email string `json:"email"`
Name string `json:"name"`
Username string `json:"preferred_username"`
// ---- Database Interaction & User Info Extraction ----
var userInfoClaims struct {
Subject string `json:"sub"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Name string `json:"name"`
PreferredUsername string `json:"preferred_username"`
GivenName string `json:"given_name"`
FamilyName string `json:"family_name"`
}
if err := idToken.Claims(&userInfo); err == nil {
session.Values["email"] = userInfo.Email
session.Values["name"] = userInfo.Name
session.Values["username"] = userInfo.Username
}
if err := session.Save(r, w); err != nil {
log.Printf("Error saving session: %v", err)
http.Error(w, "Server error", http.StatusInternalServerError)
if err := idToken.Claims(&userInfoClaims); err != nil {
log.Printf("CallbackHandler: Error parsing user info claims: %v", err)
http.Error(w, "Could not parse user info claims", http.StatusInternalServerError)
return
}
user, err := c.DB.GetUserByOIDCSubject(ctx, userInfoClaims.Subject)
if err != nil {
if err == sql.ErrNoRows { // User not found, create a new user
newUserParams := db.CreateUserParams{
OidcSubject: userInfoClaims.Subject,
Username: userInfoClaims.PreferredUsername,
Email: userInfoClaims.Email,
}
createdUser, errCreate := c.DB.CreateUser(ctx, newUserParams)
if errCreate != nil {
log.Printf("CallbackHandler: Error creating user in DB: %v", errCreate)
http.Error(w, "Failed to create user record", http.StatusInternalServerError)
return
}
user = createdUser
log.Printf("CallbackHandler: New user created: %s (%s), DB ID: %d", user.Username, user.Email, user.ID)
} else {
log.Printf("CallbackHandler: Error fetching user by OIDC Subject ID '%s': %v", userInfoClaims.Subject, err)
http.Error(w, "Database error while fetching user", http.StatusInternalServerError)
return
}
} else { // User found in DB
if user.Username != userInfoClaims.PreferredUsername || user.Email != userInfoClaims.Email {
updatedUser, errUpdate := c.DB.UpdateUser(ctx, db.UpdateUserParams{
ID: user.ID,
Username: userInfoClaims.PreferredUsername,
Email: userInfoClaims.Email,
})
if errUpdate != nil {
log.Printf("CallbackHandler: Error updating user %d: %v", user.ID, errUpdate)
} else {
user = updatedUser
log.Printf("CallbackHandler: User details updated for %s, DB ID: %d", user.Email, user.ID)
}
} else {
log.Printf("CallbackHandler: User found: %s (%s), DB ID: %d", user.Username, user.Email, user.ID)
}
}
session.Values["authenticated"] = true
session.Values["id_token"] = rawIDToken
session.Values["user_db_id"] = user.ID
session.Values["oidc_subject"] = user.OidcSubject
session.Values["email"] = user.Email
session.Values["name"] = userInfoClaims.Name
session.Values["username"] = user.Username
if err := session.Save(r, w); err != nil {
log.Printf("CallbackHandler: Error saving session: %v", err)
http.Error(w, "Server error saving session", http.StatusInternalServerError)
return
}
log.Printf("CallbackHandler: User %s authenticated successfully. Redirecting to /", user.Email)
http.Redirect(w, r, "/", http.StatusFound)
}
@ -310,35 +372,35 @@ func (c *Config) LogoutCallbackHandler(w http.ResponseWriter, r *http.Request) {
// RegistrationHandler redirects to the OIDC registration page
func (c *Config) RegistrationHandler(w http.ResponseWriter, r *http.Request) {
// Generate random state, nonce, and code verifier for security
state := generateRandomString(32)
nonce := generateRandomString(32)
codeVerifier := generateRandomString(32)
hash := sha256.Sum256([]byte(codeVerifier))
codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
hashVal := sha256.Sum256([]byte(codeVerifier)) // Renamed from hash to hashVal
codeChallenge := base64.RawURLEncoding.EncodeToString(hashVal[:])
// Store state, nonce, and code verifier in session for verification
session, _ := c.Store.Get(r, c.SessionName)
session.Values["state"] = state
session.Values["nonce"] = nonce
session.Values["code_verifier"] = codeVerifier
if err := session.Save(r, w); err != nil {
log.Printf("Error saving session: %v", err)
log.Printf("RegistrationHandler: Error saving session: %v", err)
http.Error(w, "Server error", http.StatusInternalServerError)
return
}
// Build the registration URL using the specified registrations endpoint
baseURL := viper.GetString("oidc-idp-issuer-url")
registrationURL, err := url.Parse(baseURL + "/protocol/openid-connect/registrations")
registrationEndpoint := viper.GetString("oidc-idp-issuer-url") + "/protocol/openid-connect/registrations"
// For some providers, you might need to use the authorization endpoint with a specific parameter like `kc_action=register`
// e.g., registrationEndpoint := c.OAuthConfig.Endpoint.AuthURL
// And then add q.Set("kc_action", "register") or similar. This depends on the OIDC provider.
parsedRegistrationURL, err := url.Parse(registrationEndpoint)
if err != nil {
log.Printf("Error parsing registration URL: %v", err)
log.Printf("RegistrationHandler: Error parsing registration URL '%s': %v", registrationEndpoint, err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
// Add query parameters
q := registrationURL.Query()
q := parsedRegistrationURL.Query()
q.Set("client_id", viper.GetString("oidc-sp-client-id"))
q.Set("response_type", "code")
q.Set("scope", "openid email profile")
@ -348,19 +410,17 @@ func (c *Config) RegistrationHandler(w http.ResponseWriter, r *http.Request) {
q.Set("code_challenge", codeChallenge)
q.Set("code_challenge_method", "S256")
registrationURL.RawQuery = q.Encode()
parsedRegistrationURL.RawQuery = q.Encode()
// Log for debugging
log.Printf("Redirecting to registration URL: %s", registrationURL.String())
// Redirect to registration page
http.Redirect(w, r, registrationURL.String(), http.StatusFound)
log.Printf("RegistrationHandler: Redirecting to registration URL: %s", parsedRegistrationURL.String())
http.Redirect(w, r, parsedRegistrationURL.String(), http.StatusFound)
}
// Helper function to generate random strings
func generateRandomString(n int) string {
b := make([]byte, n)
if _, err := rand.Read(b); err != nil {
log.Fatalf("Failed to generate random string: %v", err)
panic(err)
}
return base64.RawURLEncoding.EncodeToString(b)