From a7d3822f9424b7351286242e5cf8cb28336b6a97 Mon Sep 17 00:00:00 2001 From: Christian Galo Date: Mon, 2 Jun 2025 19:35:11 -0500 Subject: [PATCH] Implement database integration with SQLite and refactor user handling for OIDC authentication --- README.md | 1 + cmd/start.go | 15 ++++ go.mod | 1 + go.sum | 2 + internal/auth/auth.go | 160 +++++++++++++++++++++++----------- internal/db/database.go | 49 +++++++++++ internal/db/models.go | 2 +- internal/db/queries/users.sql | 6 +- internal/db/schema.sql | 2 +- internal/db/users.sql.go | 38 ++++---- internal/server/server.go | 6 +- 11 files changed, 206 insertions(+), 76 deletions(-) create mode 100644 internal/db/database.go diff --git a/README.md b/README.md index 4baed0e..2b1c8bc 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,7 @@ Member console application for users to create, acccess, and manage their accoun ## Development notes: +- [ ] Integrate a database migration library. `internal/db/database.go`. - [ ] Make sure viper's 'env' key will work correctly in production - [ ] Should session-secret and csrf-secret be generated on startup instead of in the config file? They should be persisted nonetheless. Do they need to be rotated? - [ ] Add remove trailing slash middleware if we start using more custom handlers that don't end with a slash diff --git a/cmd/start.go b/cmd/start.go index cbf7c76..12adaf3 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -6,6 +6,7 @@ import ( "os" "strings" + "git.coopcloud.tech/wiki-cafe/member-console/internal/db" "git.coopcloud.tech/wiki-cafe/member-console/internal/logging" "git.coopcloud.tech/wiki-cafe/member-console/internal/server" "github.com/spf13/cobra" @@ -31,6 +32,17 @@ var startCmd = &cobra.Command{ // Store logger in context ctx = logging.WithContext(ctx, logger) + // Database Setup + dbDSN := viper.GetString("db-dsn") + database, err := db.NewDB(ctx, logger, dbDSN) + if err != nil { + logger.Error("failed to initialize database", slog.Any("error", err)) + os.Exit(1) + } + defer database.Close() + + // You'll pass 'database' (or your sqlc Querier) to your server/handlers. + // For example, by adding it to the server.Config // Validate and load configurations from files if specified configPairs := []struct { value string @@ -76,6 +88,7 @@ var startCmd = &cobra.Command{ Env: env, CSRFSecret: csrfSecret, Logger: logger, + DB: db.New(database), // Pass the sqlc Querier } // Start the server @@ -100,6 +113,7 @@ func init() { startCmd.Flags().String("session-secret-file", "", "Path to file containing session secret key") startCmd.Flags().String("csrf-secret", "", "Secret key for CSRF protection (must be exactly 32 bytes)") startCmd.Flags().String("csrf-secret-file", "", "Path to file containing CSRF secret key") + startCmd.Flags().String("db-dsn", "", "Database DSN (e.g., ./member_console.db or file:/path/to/data.db?_foreign_keys=on)") // Bind all flags to Viper viper.BindPFlags(startCmd.Flags()) @@ -107,6 +121,7 @@ func init() { // Set default values viper.SetDefault("port", "8080") viper.SetDefault("env", "development") + viper.SetDefault("db-dsn", "./member_console.db") // Add the command to the root command rootCmd.AddCommand(startCmd) diff --git a/go.mod b/go.mod index 4d20ca0..d86c36b 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/magiconair/properties v1.8.7 // indirect + github.com/mattn/go-sqlite3 v1.14.28 github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect diff --git a/go.sum b/go.sum index e6358fa..856b48a 100644 --- a/go.sum +++ b/go.sum @@ -43,6 +43,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A= +github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= diff --git a/internal/auth/auth.go b/internal/auth/auth.go index fe5d5e1..8e99805 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -4,12 +4,14 @@ import ( "context" "crypto/rand" "crypto/sha256" + "database/sql" "encoding/base64" "fmt" "log" "net/http" "net/url" + "git.coopcloud.tech/wiki-cafe/member-console/internal/db" "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/sessions" "github.com/spf13/viper" @@ -23,10 +25,11 @@ type Config struct { Verifier *oidc.IDTokenVerifier Provider *oidc.Provider SessionName string + DB db.Querier } // Setup initializes the auth configuration -func Setup() (*Config, error) { +func Setup(database db.Querier) (*Config, error) { // Create cookie store store := sessions.NewCookieStore( []byte(viper.GetString("session-secret")), @@ -61,6 +64,7 @@ func Setup() (*Config, error) { Provider: provider, Verifier: provider.Verifier(&oidc.Config{ClientID: oauthConfig.ClientID}), SessionName: "auth-session", + DB: database, } return config, nil @@ -80,9 +84,15 @@ func (c *Config) Middleware() func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Skip authentication for public routes - if r.URL.Path == "/login" || r.URL.Path == "/callback" || - r.URL.Path == "/logout" || r.URL.Path == "/logout-callback" || - r.URL.Path == "/register" { + publicPaths := map[string]bool{ + "/login": true, + "/callback": true, + "/logout": true, + "/logout-callback": true, + "/register": true, + } + + if publicPaths[r.URL.Path] { next.ServeHTTP(w, r) return } @@ -97,7 +107,6 @@ func (c *Config) Middleware() func(http.Handler) http.Handler { http.Redirect(w, r, "/login", http.StatusFound) return } - next.ServeHTTP(w, r) }) } @@ -108,8 +117,8 @@ func (c *Config) LoginHandler(w http.ResponseWriter, r *http.Request) { state := generateRandomString(32) nonce := generateRandomString(32) codeVerifier := generateRandomString(32) - hash := sha256.Sum256([]byte(codeVerifier)) - codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:]) + hashVal := sha256.Sum256([]byte(codeVerifier)) // Renamed from hash to hashVal to avoid conflict + codeChallenge := base64.RawURLEncoding.EncodeToString(hashVal[:]) session, _ := c.Store.Get(r, c.SessionName) session.Values["state"] = state @@ -132,9 +141,10 @@ func (c *Config) LoginHandler(w http.ResponseWriter, r *http.Request) { // CallbackHandler processes the OIDC callback func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) { - ctx := context.Background() + ctx := r.Context() // Use request's context session, err := c.Store.Get(r, c.SessionName) if err != nil { + log.Printf("CallbackHandler: Session error: %v", err) http.Error(w, "Session error", http.StatusInternalServerError) return } @@ -142,7 +152,7 @@ func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) { // Validate state state := r.URL.Query().Get("state") if savedState, ok := session.Values["state"].(string); !ok || state != savedState { - log.Printf("State mismatch: got %s, expected %s", state, savedState) + log.Printf("CallbackHandler: State mismatch: got %s, expected %s", state, savedState) http.Error(w, "Invalid state", http.StatusBadRequest) return } @@ -151,6 +161,7 @@ func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) { code := r.URL.Query().Get("code") codeVerifier, ok := session.Values["code_verifier"].(string) if !ok { + log.Printf("CallbackHandler: No code verifier found in session") http.Error(w, "No code verifier found", http.StatusBadRequest) return } @@ -161,7 +172,7 @@ func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) { oauth2.VerifierOption(codeVerifier), ) if err != nil { - log.Printf("Token exchange failed: %v", err) + log.Printf("CallbackHandler: Token exchange failed: %v", err) http.Error(w, "Token exchange failed", http.StatusInternalServerError) return } @@ -169,53 +180,104 @@ func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) { // Verify ID token rawIDToken, ok := token.Extra("id_token").(string) if !ok { + log.Printf("CallbackHandler: Missing ID token from token extra") http.Error(w, "Missing ID token", http.StatusInternalServerError) return } idToken, err := c.Verifier.Verify(ctx, rawIDToken) if err != nil { - log.Printf("ID token verification failed: %v", err) + log.Printf("CallbackHandler: ID token verification failed: %v", err) http.Error(w, "Invalid ID token", http.StatusInternalServerError) return } // Verify nonce - var claims struct { + var nonceClaims struct { Nonce string `json:"nonce"` } - if err := idToken.Claims(&claims); err != nil { - http.Error(w, "Could not parse claims", http.StatusInternalServerError) + if err := idToken.Claims(&nonceClaims); err != nil { + log.Printf("CallbackHandler: Could not parse nonce claims: %v", err) + http.Error(w, "Could not parse claims for nonce", http.StatusInternalServerError) return } - if claims.Nonce != session.Values["nonce"].(string) { + if nonceClaims.Nonce != session.Values["nonce"].(string) { + log.Printf("CallbackHandler: Invalid nonce") http.Error(w, "Invalid nonce", http.StatusBadRequest) return } - // Authenticate session - session.Values["authenticated"] = true - session.Values["id_token"] = rawIDToken - - // Store user info - var userInfo struct { - Email string `json:"email"` - Name string `json:"name"` - Username string `json:"preferred_username"` + // ---- Database Interaction & User Info Extraction ---- + var userInfoClaims struct { + Subject string `json:"sub"` + Email string `json:"email"` + EmailVerified bool `json:"email_verified"` + Name string `json:"name"` + PreferredUsername string `json:"preferred_username"` + GivenName string `json:"given_name"` + FamilyName string `json:"family_name"` } - if err := idToken.Claims(&userInfo); err == nil { - session.Values["email"] = userInfo.Email - session.Values["name"] = userInfo.Name - session.Values["username"] = userInfo.Username - } - - if err := session.Save(r, w); err != nil { - log.Printf("Error saving session: %v", err) - http.Error(w, "Server error", http.StatusInternalServerError) + if err := idToken.Claims(&userInfoClaims); err != nil { + log.Printf("CallbackHandler: Error parsing user info claims: %v", err) + http.Error(w, "Could not parse user info claims", http.StatusInternalServerError) return } + user, err := c.DB.GetUserByOIDCSubject(ctx, userInfoClaims.Subject) + if err != nil { + if err == sql.ErrNoRows { // User not found, create a new user + newUserParams := db.CreateUserParams{ + OidcSubject: userInfoClaims.Subject, + Username: userInfoClaims.PreferredUsername, + Email: userInfoClaims.Email, + } + createdUser, errCreate := c.DB.CreateUser(ctx, newUserParams) + if errCreate != nil { + log.Printf("CallbackHandler: Error creating user in DB: %v", errCreate) + http.Error(w, "Failed to create user record", http.StatusInternalServerError) + return + } + user = createdUser + log.Printf("CallbackHandler: New user created: %s (%s), DB ID: %d", user.Username, user.Email, user.ID) + } else { + log.Printf("CallbackHandler: Error fetching user by OIDC Subject ID '%s': %v", userInfoClaims.Subject, err) + http.Error(w, "Database error while fetching user", http.StatusInternalServerError) + return + } + } else { // User found in DB + if user.Username != userInfoClaims.PreferredUsername || user.Email != userInfoClaims.Email { + updatedUser, errUpdate := c.DB.UpdateUser(ctx, db.UpdateUserParams{ + ID: user.ID, + Username: userInfoClaims.PreferredUsername, + Email: userInfoClaims.Email, + }) + if errUpdate != nil { + log.Printf("CallbackHandler: Error updating user %d: %v", user.ID, errUpdate) + } else { + user = updatedUser + log.Printf("CallbackHandler: User details updated for %s, DB ID: %d", user.Email, user.ID) + } + } else { + log.Printf("CallbackHandler: User found: %s (%s), DB ID: %d", user.Username, user.Email, user.ID) + } + } + + session.Values["authenticated"] = true + session.Values["id_token"] = rawIDToken + session.Values["user_db_id"] = user.ID + session.Values["oidc_subject"] = user.OidcSubject + session.Values["email"] = user.Email + session.Values["name"] = userInfoClaims.Name + session.Values["username"] = user.Username + + if err := session.Save(r, w); err != nil { + log.Printf("CallbackHandler: Error saving session: %v", err) + http.Error(w, "Server error saving session", http.StatusInternalServerError) + return + } + + log.Printf("CallbackHandler: User %s authenticated successfully. Redirecting to /", user.Email) http.Redirect(w, r, "/", http.StatusFound) } @@ -310,35 +372,35 @@ func (c *Config) LogoutCallbackHandler(w http.ResponseWriter, r *http.Request) { // RegistrationHandler redirects to the OIDC registration page func (c *Config) RegistrationHandler(w http.ResponseWriter, r *http.Request) { - // Generate random state, nonce, and code verifier for security state := generateRandomString(32) nonce := generateRandomString(32) codeVerifier := generateRandomString(32) - hash := sha256.Sum256([]byte(codeVerifier)) - codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:]) + hashVal := sha256.Sum256([]byte(codeVerifier)) // Renamed from hash to hashVal + codeChallenge := base64.RawURLEncoding.EncodeToString(hashVal[:]) - // Store state, nonce, and code verifier in session for verification session, _ := c.Store.Get(r, c.SessionName) session.Values["state"] = state session.Values["nonce"] = nonce session.Values["code_verifier"] = codeVerifier if err := session.Save(r, w); err != nil { - log.Printf("Error saving session: %v", err) + log.Printf("RegistrationHandler: Error saving session: %v", err) http.Error(w, "Server error", http.StatusInternalServerError) return } - // Build the registration URL using the specified registrations endpoint - baseURL := viper.GetString("oidc-idp-issuer-url") - registrationURL, err := url.Parse(baseURL + "/protocol/openid-connect/registrations") + registrationEndpoint := viper.GetString("oidc-idp-issuer-url") + "/protocol/openid-connect/registrations" + // For some providers, you might need to use the authorization endpoint with a specific parameter like `kc_action=register` + // e.g., registrationEndpoint := c.OAuthConfig.Endpoint.AuthURL + // And then add q.Set("kc_action", "register") or similar. This depends on the OIDC provider. + + parsedRegistrationURL, err := url.Parse(registrationEndpoint) if err != nil { - log.Printf("Error parsing registration URL: %v", err) + log.Printf("RegistrationHandler: Error parsing registration URL '%s': %v", registrationEndpoint, err) http.Error(w, "Internal Server Error", http.StatusInternalServerError) return } - // Add query parameters - q := registrationURL.Query() + q := parsedRegistrationURL.Query() q.Set("client_id", viper.GetString("oidc-sp-client-id")) q.Set("response_type", "code") q.Set("scope", "openid email profile") @@ -348,19 +410,17 @@ func (c *Config) RegistrationHandler(w http.ResponseWriter, r *http.Request) { q.Set("code_challenge", codeChallenge) q.Set("code_challenge_method", "S256") - registrationURL.RawQuery = q.Encode() + parsedRegistrationURL.RawQuery = q.Encode() - // Log for debugging - log.Printf("Redirecting to registration URL: %s", registrationURL.String()) - - // Redirect to registration page - http.Redirect(w, r, registrationURL.String(), http.StatusFound) + log.Printf("RegistrationHandler: Redirecting to registration URL: %s", parsedRegistrationURL.String()) + http.Redirect(w, r, parsedRegistrationURL.String(), http.StatusFound) } // Helper function to generate random strings func generateRandomString(n int) string { b := make([]byte, n) if _, err := rand.Read(b); err != nil { + log.Fatalf("Failed to generate random string: %v", err) panic(err) } return base64.RawURLEncoding.EncodeToString(b) diff --git a/internal/db/database.go b/internal/db/database.go new file mode 100644 index 0000000..aebc31c --- /dev/null +++ b/internal/db/database.go @@ -0,0 +1,49 @@ +package db + +import ( + "context" + "database/sql" + _ "embed" // Required for go:embed + "fmt" + "log/slog" + "os" + "path/filepath" + + _ "github.com/mattn/go-sqlite3" // SQLite driver +) + +//go:embed schema.sql +var ddl string + +// DBConfig holds database configuration. +type DBConfig struct { + DSN string // Data Source Name for SQLite +} + +// NewDB initializes and returns a new database connection pool and runs migrations. +func NewDB(ctx context.Context, logger *slog.Logger, dsn string) (*sql.DB, error) { + // Ensure the directory for the SQLite file exists + dbDir := filepath.Dir(dsn) + if err := os.MkdirAll(dbDir, 0755); err != nil { + return nil, fmt.Errorf("failed to create database directory %s: %w", dbDir, err) + } + + db, err := sql.Open("sqlite3", dsn+"?_foreign_keys=on") // Enable foreign key constraints + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + + if err = db.PingContext(ctx); err != nil { + return nil, fmt.Errorf("failed to ping database: %w", err) + } + + logger.Info("database connection established", slog.String("dsn", dsn)) + + // Execute schema. + if _, err := db.ExecContext(ctx, ddl); err != nil { + return nil, fmt.Errorf("failed to execute DDL: %w", err) + } + logger.Info("database schema applied") + + return db, nil +} diff --git a/internal/db/models.go b/internal/db/models.go index 678ec50..f9ae0f4 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -32,7 +32,7 @@ type Site struct { type User struct { ID int64 `json:"id"` OidcSubject string `json:"oidc_subject"` - Name string `json:"name"` + Username string `json:"username"` Email string `json:"email"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` diff --git a/internal/db/queries/users.sql b/internal/db/queries/users.sql index b44a10d..8c0a649 100644 --- a/internal/db/queries/users.sql +++ b/internal/db/queries/users.sql @@ -1,5 +1,5 @@ -- name: CreateUser :one -INSERT INTO users (oidc_subject, name, email) +INSERT INTO users (oidc_subject, username, email) VALUES (?, ?, ?) RETURNING *; @@ -13,7 +13,7 @@ WHERE oidc_subject = ?; -- name: UpdateUser :one UPDATE users -SET name = ?, email = ? +SET username = ?, email = ? WHERE id = ? RETURNING *; @@ -23,4 +23,4 @@ WHERE id = ?; -- name: ListUsers :many SELECT * FROM users -ORDER BY name; \ No newline at end of file +ORDER BY username; \ No newline at end of file diff --git a/internal/db/schema.sql b/internal/db/schema.sql index 08fb129..d14598d 100644 --- a/internal/db/schema.sql +++ b/internal/db/schema.sql @@ -3,7 +3,7 @@ CREATE TABLE users ( id INTEGER PRIMARY KEY AUTOINCREMENT, oidc_subject TEXT UNIQUE NOT NULL, - name TEXT NOT NULL, + username TEXT NOT NULL, email TEXT UNIQUE NOT NULL, created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP diff --git a/internal/db/users.sql.go b/internal/db/users.sql.go index 4ce7d2e..5dfdfad 100644 --- a/internal/db/users.sql.go +++ b/internal/db/users.sql.go @@ -10,24 +10,24 @@ import ( ) const createUser = `-- name: CreateUser :one -INSERT INTO users (oidc_subject, name, email) +INSERT INTO users (oidc_subject, username, email) VALUES (?, ?, ?) -RETURNING id, oidc_subject, name, email, created_at, updated_at +RETURNING id, oidc_subject, username, email, created_at, updated_at ` type CreateUserParams struct { OidcSubject string `json:"oidc_subject"` - Name string `json:"name"` + Username string `json:"username"` Email string `json:"email"` } func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) { - row := q.db.QueryRowContext(ctx, createUser, arg.OidcSubject, arg.Name, arg.Email) + row := q.db.QueryRowContext(ctx, createUser, arg.OidcSubject, arg.Username, arg.Email) var i User err := row.Scan( &i.ID, &i.OidcSubject, - &i.Name, + &i.Username, &i.Email, &i.CreatedAt, &i.UpdatedAt, @@ -46,7 +46,7 @@ func (q *Queries) DeleteUser(ctx context.Context, id int64) error { } const getUserByID = `-- name: GetUserByID :one -SELECT id, oidc_subject, name, email, created_at, updated_at FROM users +SELECT id, oidc_subject, username, email, created_at, updated_at FROM users WHERE id = ? ` @@ -56,7 +56,7 @@ func (q *Queries) GetUserByID(ctx context.Context, id int64) (User, error) { err := row.Scan( &i.ID, &i.OidcSubject, - &i.Name, + &i.Username, &i.Email, &i.CreatedAt, &i.UpdatedAt, @@ -65,7 +65,7 @@ func (q *Queries) GetUserByID(ctx context.Context, id int64) (User, error) { } const getUserByOIDCSubject = `-- name: GetUserByOIDCSubject :one -SELECT id, oidc_subject, name, email, created_at, updated_at FROM users +SELECT id, oidc_subject, username, email, created_at, updated_at FROM users WHERE oidc_subject = ? ` @@ -75,7 +75,7 @@ func (q *Queries) GetUserByOIDCSubject(ctx context.Context, oidcSubject string) err := row.Scan( &i.ID, &i.OidcSubject, - &i.Name, + &i.Username, &i.Email, &i.CreatedAt, &i.UpdatedAt, @@ -84,8 +84,8 @@ func (q *Queries) GetUserByOIDCSubject(ctx context.Context, oidcSubject string) } const listUsers = `-- name: ListUsers :many -SELECT id, oidc_subject, name, email, created_at, updated_at FROM users -ORDER BY name +SELECT id, oidc_subject, username, email, created_at, updated_at FROM users +ORDER BY username ` func (q *Queries) ListUsers(ctx context.Context) ([]User, error) { @@ -100,7 +100,7 @@ func (q *Queries) ListUsers(ctx context.Context) ([]User, error) { if err := rows.Scan( &i.ID, &i.OidcSubject, - &i.Name, + &i.Username, &i.Email, &i.CreatedAt, &i.UpdatedAt, @@ -120,24 +120,24 @@ func (q *Queries) ListUsers(ctx context.Context) ([]User, error) { const updateUser = `-- name: UpdateUser :one UPDATE users -SET name = ?, email = ? +SET username = ?, email = ? WHERE id = ? -RETURNING id, oidc_subject, name, email, created_at, updated_at +RETURNING id, oidc_subject, username, email, created_at, updated_at ` type UpdateUserParams struct { - Name string `json:"name"` - Email string `json:"email"` - ID int64 `json:"id"` + Username string `json:"username"` + Email string `json:"email"` + ID int64 `json:"id"` } func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) (User, error) { - row := q.db.QueryRowContext(ctx, updateUser, arg.Name, arg.Email, arg.ID) + row := q.db.QueryRowContext(ctx, updateUser, arg.Username, arg.Email, arg.ID) var i User err := row.Scan( &i.ID, &i.OidcSubject, - &i.Name, + &i.Username, &i.Email, &i.CreatedAt, &i.UpdatedAt, diff --git a/internal/server/server.go b/internal/server/server.go index ec9e5ea..b5ea446 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -10,6 +10,7 @@ import ( "time" "git.coopcloud.tech/wiki-cafe/member-console/internal/auth" + "git.coopcloud.tech/wiki-cafe/member-console/internal/db" "git.coopcloud.tech/wiki-cafe/member-console/internal/embeds" "git.coopcloud.tech/wiki-cafe/member-console/internal/middleware" "github.com/rs/cors" @@ -22,6 +23,7 @@ type Config struct { Env string CSRFSecret string Logger *slog.Logger + DB db.Querier } // Start initializes and starts the HTTP server. @@ -29,8 +31,8 @@ func Start(ctx context.Context, cfg Config) error { // Create a new HTTP request router httpRequestRouter := http.NewServeMux() - // Set up authentication - authConfig, err := auth.Setup() + // Set up authentication. Pass the database connection to the auth package. + authConfig, err := auth.Setup(cfg.DB) if err != nil { cfg.Logger.Error("failed to set up authentication", slog.Any("error", err)) return err