Implement database integration with SQLite and refactor user handling for OIDC authentication
This commit is contained in:
@ -4,12 +4,14 @@ import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"git.coopcloud.tech/wiki-cafe/member-console/internal/db"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/spf13/viper"
|
||||
@ -23,10 +25,11 @@ type Config struct {
|
||||
Verifier *oidc.IDTokenVerifier
|
||||
Provider *oidc.Provider
|
||||
SessionName string
|
||||
DB db.Querier
|
||||
}
|
||||
|
||||
// Setup initializes the auth configuration
|
||||
func Setup() (*Config, error) {
|
||||
func Setup(database db.Querier) (*Config, error) {
|
||||
// Create cookie store
|
||||
store := sessions.NewCookieStore(
|
||||
[]byte(viper.GetString("session-secret")),
|
||||
@ -61,6 +64,7 @@ func Setup() (*Config, error) {
|
||||
Provider: provider,
|
||||
Verifier: provider.Verifier(&oidc.Config{ClientID: oauthConfig.ClientID}),
|
||||
SessionName: "auth-session",
|
||||
DB: database,
|
||||
}
|
||||
|
||||
return config, nil
|
||||
@ -80,9 +84,15 @@ func (c *Config) Middleware() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Skip authentication for public routes
|
||||
if r.URL.Path == "/login" || r.URL.Path == "/callback" ||
|
||||
r.URL.Path == "/logout" || r.URL.Path == "/logout-callback" ||
|
||||
r.URL.Path == "/register" {
|
||||
publicPaths := map[string]bool{
|
||||
"/login": true,
|
||||
"/callback": true,
|
||||
"/logout": true,
|
||||
"/logout-callback": true,
|
||||
"/register": true,
|
||||
}
|
||||
|
||||
if publicPaths[r.URL.Path] {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
@ -97,7 +107,6 @@ func (c *Config) Middleware() func(http.Handler) http.Handler {
|
||||
http.Redirect(w, r, "/login", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
@ -108,8 +117,8 @@ func (c *Config) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
state := generateRandomString(32)
|
||||
nonce := generateRandomString(32)
|
||||
codeVerifier := generateRandomString(32)
|
||||
hash := sha256.Sum256([]byte(codeVerifier))
|
||||
codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
hashVal := sha256.Sum256([]byte(codeVerifier)) // Renamed from hash to hashVal to avoid conflict
|
||||
codeChallenge := base64.RawURLEncoding.EncodeToString(hashVal[:])
|
||||
|
||||
session, _ := c.Store.Get(r, c.SessionName)
|
||||
session.Values["state"] = state
|
||||
@ -132,9 +141,10 @@ func (c *Config) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// CallbackHandler processes the OIDC callback
|
||||
func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := context.Background()
|
||||
ctx := r.Context() // Use request's context
|
||||
session, err := c.Store.Get(r, c.SessionName)
|
||||
if err != nil {
|
||||
log.Printf("CallbackHandler: Session error: %v", err)
|
||||
http.Error(w, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@ -142,7 +152,7 @@ func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Validate state
|
||||
state := r.URL.Query().Get("state")
|
||||
if savedState, ok := session.Values["state"].(string); !ok || state != savedState {
|
||||
log.Printf("State mismatch: got %s, expected %s", state, savedState)
|
||||
log.Printf("CallbackHandler: State mismatch: got %s, expected %s", state, savedState)
|
||||
http.Error(w, "Invalid state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
@ -151,6 +161,7 @@ func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
codeVerifier, ok := session.Values["code_verifier"].(string)
|
||||
if !ok {
|
||||
log.Printf("CallbackHandler: No code verifier found in session")
|
||||
http.Error(w, "No code verifier found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
@ -161,7 +172,7 @@ func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
oauth2.VerifierOption(codeVerifier),
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Token exchange failed: %v", err)
|
||||
log.Printf("CallbackHandler: Token exchange failed: %v", err)
|
||||
http.Error(w, "Token exchange failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@ -169,53 +180,104 @@ func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify ID token
|
||||
rawIDToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
log.Printf("CallbackHandler: Missing ID token from token extra")
|
||||
http.Error(w, "Missing ID token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
idToken, err := c.Verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
log.Printf("ID token verification failed: %v", err)
|
||||
log.Printf("CallbackHandler: ID token verification failed: %v", err)
|
||||
http.Error(w, "Invalid ID token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify nonce
|
||||
var claims struct {
|
||||
var nonceClaims struct {
|
||||
Nonce string `json:"nonce"`
|
||||
}
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
http.Error(w, "Could not parse claims", http.StatusInternalServerError)
|
||||
if err := idToken.Claims(&nonceClaims); err != nil {
|
||||
log.Printf("CallbackHandler: Could not parse nonce claims: %v", err)
|
||||
http.Error(w, "Could not parse claims for nonce", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if claims.Nonce != session.Values["nonce"].(string) {
|
||||
if nonceClaims.Nonce != session.Values["nonce"].(string) {
|
||||
log.Printf("CallbackHandler: Invalid nonce")
|
||||
http.Error(w, "Invalid nonce", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate session
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["id_token"] = rawIDToken
|
||||
|
||||
// Store user info
|
||||
var userInfo struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Username string `json:"preferred_username"`
|
||||
// ---- Database Interaction & User Info Extraction ----
|
||||
var userInfoClaims struct {
|
||||
Subject string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Name string `json:"name"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
GivenName string `json:"given_name"`
|
||||
FamilyName string `json:"family_name"`
|
||||
}
|
||||
if err := idToken.Claims(&userInfo); err == nil {
|
||||
session.Values["email"] = userInfo.Email
|
||||
session.Values["name"] = userInfo.Name
|
||||
session.Values["username"] = userInfo.Username
|
||||
}
|
||||
|
||||
if err := session.Save(r, w); err != nil {
|
||||
log.Printf("Error saving session: %v", err)
|
||||
http.Error(w, "Server error", http.StatusInternalServerError)
|
||||
if err := idToken.Claims(&userInfoClaims); err != nil {
|
||||
log.Printf("CallbackHandler: Error parsing user info claims: %v", err)
|
||||
http.Error(w, "Could not parse user info claims", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.DB.GetUserByOIDCSubject(ctx, userInfoClaims.Subject)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows { // User not found, create a new user
|
||||
newUserParams := db.CreateUserParams{
|
||||
OidcSubject: userInfoClaims.Subject,
|
||||
Username: userInfoClaims.PreferredUsername,
|
||||
Email: userInfoClaims.Email,
|
||||
}
|
||||
createdUser, errCreate := c.DB.CreateUser(ctx, newUserParams)
|
||||
if errCreate != nil {
|
||||
log.Printf("CallbackHandler: Error creating user in DB: %v", errCreate)
|
||||
http.Error(w, "Failed to create user record", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
user = createdUser
|
||||
log.Printf("CallbackHandler: New user created: %s (%s), DB ID: %d", user.Username, user.Email, user.ID)
|
||||
} else {
|
||||
log.Printf("CallbackHandler: Error fetching user by OIDC Subject ID '%s': %v", userInfoClaims.Subject, err)
|
||||
http.Error(w, "Database error while fetching user", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
} else { // User found in DB
|
||||
if user.Username != userInfoClaims.PreferredUsername || user.Email != userInfoClaims.Email {
|
||||
updatedUser, errUpdate := c.DB.UpdateUser(ctx, db.UpdateUserParams{
|
||||
ID: user.ID,
|
||||
Username: userInfoClaims.PreferredUsername,
|
||||
Email: userInfoClaims.Email,
|
||||
})
|
||||
if errUpdate != nil {
|
||||
log.Printf("CallbackHandler: Error updating user %d: %v", user.ID, errUpdate)
|
||||
} else {
|
||||
user = updatedUser
|
||||
log.Printf("CallbackHandler: User details updated for %s, DB ID: %d", user.Email, user.ID)
|
||||
}
|
||||
} else {
|
||||
log.Printf("CallbackHandler: User found: %s (%s), DB ID: %d", user.Username, user.Email, user.ID)
|
||||
}
|
||||
}
|
||||
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["id_token"] = rawIDToken
|
||||
session.Values["user_db_id"] = user.ID
|
||||
session.Values["oidc_subject"] = user.OidcSubject
|
||||
session.Values["email"] = user.Email
|
||||
session.Values["name"] = userInfoClaims.Name
|
||||
session.Values["username"] = user.Username
|
||||
|
||||
if err := session.Save(r, w); err != nil {
|
||||
log.Printf("CallbackHandler: Error saving session: %v", err)
|
||||
http.Error(w, "Server error saving session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("CallbackHandler: User %s authenticated successfully. Redirecting to /", user.Email)
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
}
|
||||
|
||||
@ -310,35 +372,35 @@ func (c *Config) LogoutCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// RegistrationHandler redirects to the OIDC registration page
|
||||
func (c *Config) RegistrationHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Generate random state, nonce, and code verifier for security
|
||||
state := generateRandomString(32)
|
||||
nonce := generateRandomString(32)
|
||||
codeVerifier := generateRandomString(32)
|
||||
hash := sha256.Sum256([]byte(codeVerifier))
|
||||
codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
hashVal := sha256.Sum256([]byte(codeVerifier)) // Renamed from hash to hashVal
|
||||
codeChallenge := base64.RawURLEncoding.EncodeToString(hashVal[:])
|
||||
|
||||
// Store state, nonce, and code verifier in session for verification
|
||||
session, _ := c.Store.Get(r, c.SessionName)
|
||||
session.Values["state"] = state
|
||||
session.Values["nonce"] = nonce
|
||||
session.Values["code_verifier"] = codeVerifier
|
||||
if err := session.Save(r, w); err != nil {
|
||||
log.Printf("Error saving session: %v", err)
|
||||
log.Printf("RegistrationHandler: Error saving session: %v", err)
|
||||
http.Error(w, "Server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Build the registration URL using the specified registrations endpoint
|
||||
baseURL := viper.GetString("oidc-idp-issuer-url")
|
||||
registrationURL, err := url.Parse(baseURL + "/protocol/openid-connect/registrations")
|
||||
registrationEndpoint := viper.GetString("oidc-idp-issuer-url") + "/protocol/openid-connect/registrations"
|
||||
// For some providers, you might need to use the authorization endpoint with a specific parameter like `kc_action=register`
|
||||
// e.g., registrationEndpoint := c.OAuthConfig.Endpoint.AuthURL
|
||||
// And then add q.Set("kc_action", "register") or similar. This depends on the OIDC provider.
|
||||
|
||||
parsedRegistrationURL, err := url.Parse(registrationEndpoint)
|
||||
if err != nil {
|
||||
log.Printf("Error parsing registration URL: %v", err)
|
||||
log.Printf("RegistrationHandler: Error parsing registration URL '%s': %v", registrationEndpoint, err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Add query parameters
|
||||
q := registrationURL.Query()
|
||||
q := parsedRegistrationURL.Query()
|
||||
q.Set("client_id", viper.GetString("oidc-sp-client-id"))
|
||||
q.Set("response_type", "code")
|
||||
q.Set("scope", "openid email profile")
|
||||
@ -348,19 +410,17 @@ func (c *Config) RegistrationHandler(w http.ResponseWriter, r *http.Request) {
|
||||
q.Set("code_challenge", codeChallenge)
|
||||
q.Set("code_challenge_method", "S256")
|
||||
|
||||
registrationURL.RawQuery = q.Encode()
|
||||
parsedRegistrationURL.RawQuery = q.Encode()
|
||||
|
||||
// Log for debugging
|
||||
log.Printf("Redirecting to registration URL: %s", registrationURL.String())
|
||||
|
||||
// Redirect to registration page
|
||||
http.Redirect(w, r, registrationURL.String(), http.StatusFound)
|
||||
log.Printf("RegistrationHandler: Redirecting to registration URL: %s", parsedRegistrationURL.String())
|
||||
http.Redirect(w, r, parsedRegistrationURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// Helper function to generate random strings
|
||||
func generateRandomString(n int) string {
|
||||
b := make([]byte, n)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
log.Fatalf("Failed to generate random string: %v", err)
|
||||
panic(err)
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b)
|
||||
|
Reference in New Issue
Block a user