Implement database integration with SQLite and refactor user handling for OIDC authentication
This commit is contained in:
parent
77b2e6c24e
commit
a7d3822f94
@ -4,6 +4,7 @@ Member console application for users to create, acccess, and manage their accoun
|
||||
|
||||
## Development notes:
|
||||
|
||||
- [ ] Integrate a database migration library. `internal/db/database.go`.
|
||||
- [ ] Make sure viper's 'env' key will work correctly in production
|
||||
- [ ] Should session-secret and csrf-secret be generated on startup instead of in the config file? They should be persisted nonetheless. Do they need to be rotated?
|
||||
- [ ] Add remove trailing slash middleware if we start using more custom handlers that don't end with a slash
|
||||
|
15
cmd/start.go
15
cmd/start.go
@ -6,6 +6,7 @@ import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"git.coopcloud.tech/wiki-cafe/member-console/internal/db"
|
||||
"git.coopcloud.tech/wiki-cafe/member-console/internal/logging"
|
||||
"git.coopcloud.tech/wiki-cafe/member-console/internal/server"
|
||||
"github.com/spf13/cobra"
|
||||
@ -31,6 +32,17 @@ var startCmd = &cobra.Command{
|
||||
// Store logger in context
|
||||
ctx = logging.WithContext(ctx, logger)
|
||||
|
||||
// Database Setup
|
||||
dbDSN := viper.GetString("db-dsn")
|
||||
database, err := db.NewDB(ctx, logger, dbDSN)
|
||||
if err != nil {
|
||||
logger.Error("failed to initialize database", slog.Any("error", err))
|
||||
os.Exit(1)
|
||||
}
|
||||
defer database.Close()
|
||||
|
||||
// You'll pass 'database' (or your sqlc Querier) to your server/handlers.
|
||||
// For example, by adding it to the server.Config
|
||||
// Validate and load configurations from files if specified
|
||||
configPairs := []struct {
|
||||
value string
|
||||
@ -76,6 +88,7 @@ var startCmd = &cobra.Command{
|
||||
Env: env,
|
||||
CSRFSecret: csrfSecret,
|
||||
Logger: logger,
|
||||
DB: db.New(database), // Pass the sqlc Querier
|
||||
}
|
||||
|
||||
// Start the server
|
||||
@ -100,6 +113,7 @@ func init() {
|
||||
startCmd.Flags().String("session-secret-file", "", "Path to file containing session secret key")
|
||||
startCmd.Flags().String("csrf-secret", "", "Secret key for CSRF protection (must be exactly 32 bytes)")
|
||||
startCmd.Flags().String("csrf-secret-file", "", "Path to file containing CSRF secret key")
|
||||
startCmd.Flags().String("db-dsn", "", "Database DSN (e.g., ./member_console.db or file:/path/to/data.db?_foreign_keys=on)")
|
||||
|
||||
// Bind all flags to Viper
|
||||
viper.BindPFlags(startCmd.Flags())
|
||||
@ -107,6 +121,7 @@ func init() {
|
||||
// Set default values
|
||||
viper.SetDefault("port", "8080")
|
||||
viper.SetDefault("env", "development")
|
||||
viper.SetDefault("db-dsn", "./member_console.db")
|
||||
|
||||
// Add the command to the root command
|
||||
rootCmd.AddCommand(startCmd)
|
||||
|
1
go.mod
1
go.mod
@ -25,6 +25,7 @@ require (
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/magiconair/properties v1.8.7 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.28
|
||||
github.com/mitchellh/mapstructure v1.5.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||
github.com/sagikazarmark/locafero v0.4.0 // indirect
|
||||
|
2
go.sum
2
go.sum
@ -43,6 +43,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
|
||||
github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
|
||||
github.com/mattn/go-sqlite3 v1.14.28 h1:ThEiQrnbtumT+QMknw63Befp/ce/nUPgBPMlRFEum7A=
|
||||
github.com/mattn/go-sqlite3 v1.14.28/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
|
||||
|
@ -4,12 +4,14 @@ import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"git.coopcloud.tech/wiki-cafe/member-console/internal/db"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/spf13/viper"
|
||||
@ -23,10 +25,11 @@ type Config struct {
|
||||
Verifier *oidc.IDTokenVerifier
|
||||
Provider *oidc.Provider
|
||||
SessionName string
|
||||
DB db.Querier
|
||||
}
|
||||
|
||||
// Setup initializes the auth configuration
|
||||
func Setup() (*Config, error) {
|
||||
func Setup(database db.Querier) (*Config, error) {
|
||||
// Create cookie store
|
||||
store := sessions.NewCookieStore(
|
||||
[]byte(viper.GetString("session-secret")),
|
||||
@ -61,6 +64,7 @@ func Setup() (*Config, error) {
|
||||
Provider: provider,
|
||||
Verifier: provider.Verifier(&oidc.Config{ClientID: oauthConfig.ClientID}),
|
||||
SessionName: "auth-session",
|
||||
DB: database,
|
||||
}
|
||||
|
||||
return config, nil
|
||||
@ -80,9 +84,15 @@ func (c *Config) Middleware() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Skip authentication for public routes
|
||||
if r.URL.Path == "/login" || r.URL.Path == "/callback" ||
|
||||
r.URL.Path == "/logout" || r.URL.Path == "/logout-callback" ||
|
||||
r.URL.Path == "/register" {
|
||||
publicPaths := map[string]bool{
|
||||
"/login": true,
|
||||
"/callback": true,
|
||||
"/logout": true,
|
||||
"/logout-callback": true,
|
||||
"/register": true,
|
||||
}
|
||||
|
||||
if publicPaths[r.URL.Path] {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
@ -97,7 +107,6 @@ func (c *Config) Middleware() func(http.Handler) http.Handler {
|
||||
http.Redirect(w, r, "/login", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
@ -108,8 +117,8 @@ func (c *Config) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
state := generateRandomString(32)
|
||||
nonce := generateRandomString(32)
|
||||
codeVerifier := generateRandomString(32)
|
||||
hash := sha256.Sum256([]byte(codeVerifier))
|
||||
codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
hashVal := sha256.Sum256([]byte(codeVerifier)) // Renamed from hash to hashVal to avoid conflict
|
||||
codeChallenge := base64.RawURLEncoding.EncodeToString(hashVal[:])
|
||||
|
||||
session, _ := c.Store.Get(r, c.SessionName)
|
||||
session.Values["state"] = state
|
||||
@ -132,9 +141,10 @@ func (c *Config) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// CallbackHandler processes the OIDC callback
|
||||
func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := context.Background()
|
||||
ctx := r.Context() // Use request's context
|
||||
session, err := c.Store.Get(r, c.SessionName)
|
||||
if err != nil {
|
||||
log.Printf("CallbackHandler: Session error: %v", err)
|
||||
http.Error(w, "Session error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@ -142,7 +152,7 @@ func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Validate state
|
||||
state := r.URL.Query().Get("state")
|
||||
if savedState, ok := session.Values["state"].(string); !ok || state != savedState {
|
||||
log.Printf("State mismatch: got %s, expected %s", state, savedState)
|
||||
log.Printf("CallbackHandler: State mismatch: got %s, expected %s", state, savedState)
|
||||
http.Error(w, "Invalid state", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
@ -151,6 +161,7 @@ func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
code := r.URL.Query().Get("code")
|
||||
codeVerifier, ok := session.Values["code_verifier"].(string)
|
||||
if !ok {
|
||||
log.Printf("CallbackHandler: No code verifier found in session")
|
||||
http.Error(w, "No code verifier found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
@ -161,7 +172,7 @@ func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
oauth2.VerifierOption(codeVerifier),
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Token exchange failed: %v", err)
|
||||
log.Printf("CallbackHandler: Token exchange failed: %v", err)
|
||||
http.Error(w, "Token exchange failed", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@ -169,53 +180,104 @@ func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Verify ID token
|
||||
rawIDToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
log.Printf("CallbackHandler: Missing ID token from token extra")
|
||||
http.Error(w, "Missing ID token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
idToken, err := c.Verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
log.Printf("ID token verification failed: %v", err)
|
||||
log.Printf("CallbackHandler: ID token verification failed: %v", err)
|
||||
http.Error(w, "Invalid ID token", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify nonce
|
||||
var claims struct {
|
||||
var nonceClaims struct {
|
||||
Nonce string `json:"nonce"`
|
||||
}
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
http.Error(w, "Could not parse claims", http.StatusInternalServerError)
|
||||
if err := idToken.Claims(&nonceClaims); err != nil {
|
||||
log.Printf("CallbackHandler: Could not parse nonce claims: %v", err)
|
||||
http.Error(w, "Could not parse claims for nonce", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if claims.Nonce != session.Values["nonce"].(string) {
|
||||
if nonceClaims.Nonce != session.Values["nonce"].(string) {
|
||||
log.Printf("CallbackHandler: Invalid nonce")
|
||||
http.Error(w, "Invalid nonce", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Authenticate session
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["id_token"] = rawIDToken
|
||||
|
||||
// Store user info
|
||||
var userInfo struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Username string `json:"preferred_username"`
|
||||
// ---- Database Interaction & User Info Extraction ----
|
||||
var userInfoClaims struct {
|
||||
Subject string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
EmailVerified bool `json:"email_verified"`
|
||||
Name string `json:"name"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
GivenName string `json:"given_name"`
|
||||
FamilyName string `json:"family_name"`
|
||||
}
|
||||
if err := idToken.Claims(&userInfo); err == nil {
|
||||
session.Values["email"] = userInfo.Email
|
||||
session.Values["name"] = userInfo.Name
|
||||
session.Values["username"] = userInfo.Username
|
||||
}
|
||||
|
||||
if err := session.Save(r, w); err != nil {
|
||||
log.Printf("Error saving session: %v", err)
|
||||
http.Error(w, "Server error", http.StatusInternalServerError)
|
||||
if err := idToken.Claims(&userInfoClaims); err != nil {
|
||||
log.Printf("CallbackHandler: Error parsing user info claims: %v", err)
|
||||
http.Error(w, "Could not parse user info claims", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.DB.GetUserByOIDCSubject(ctx, userInfoClaims.Subject)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows { // User not found, create a new user
|
||||
newUserParams := db.CreateUserParams{
|
||||
OidcSubject: userInfoClaims.Subject,
|
||||
Username: userInfoClaims.PreferredUsername,
|
||||
Email: userInfoClaims.Email,
|
||||
}
|
||||
createdUser, errCreate := c.DB.CreateUser(ctx, newUserParams)
|
||||
if errCreate != nil {
|
||||
log.Printf("CallbackHandler: Error creating user in DB: %v", errCreate)
|
||||
http.Error(w, "Failed to create user record", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
user = createdUser
|
||||
log.Printf("CallbackHandler: New user created: %s (%s), DB ID: %d", user.Username, user.Email, user.ID)
|
||||
} else {
|
||||
log.Printf("CallbackHandler: Error fetching user by OIDC Subject ID '%s': %v", userInfoClaims.Subject, err)
|
||||
http.Error(w, "Database error while fetching user", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
} else { // User found in DB
|
||||
if user.Username != userInfoClaims.PreferredUsername || user.Email != userInfoClaims.Email {
|
||||
updatedUser, errUpdate := c.DB.UpdateUser(ctx, db.UpdateUserParams{
|
||||
ID: user.ID,
|
||||
Username: userInfoClaims.PreferredUsername,
|
||||
Email: userInfoClaims.Email,
|
||||
})
|
||||
if errUpdate != nil {
|
||||
log.Printf("CallbackHandler: Error updating user %d: %v", user.ID, errUpdate)
|
||||
} else {
|
||||
user = updatedUser
|
||||
log.Printf("CallbackHandler: User details updated for %s, DB ID: %d", user.Email, user.ID)
|
||||
}
|
||||
} else {
|
||||
log.Printf("CallbackHandler: User found: %s (%s), DB ID: %d", user.Username, user.Email, user.ID)
|
||||
}
|
||||
}
|
||||
|
||||
session.Values["authenticated"] = true
|
||||
session.Values["id_token"] = rawIDToken
|
||||
session.Values["user_db_id"] = user.ID
|
||||
session.Values["oidc_subject"] = user.OidcSubject
|
||||
session.Values["email"] = user.Email
|
||||
session.Values["name"] = userInfoClaims.Name
|
||||
session.Values["username"] = user.Username
|
||||
|
||||
if err := session.Save(r, w); err != nil {
|
||||
log.Printf("CallbackHandler: Error saving session: %v", err)
|
||||
http.Error(w, "Server error saving session", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("CallbackHandler: User %s authenticated successfully. Redirecting to /", user.Email)
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
}
|
||||
|
||||
@ -310,35 +372,35 @@ func (c *Config) LogoutCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// RegistrationHandler redirects to the OIDC registration page
|
||||
func (c *Config) RegistrationHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Generate random state, nonce, and code verifier for security
|
||||
state := generateRandomString(32)
|
||||
nonce := generateRandomString(32)
|
||||
codeVerifier := generateRandomString(32)
|
||||
hash := sha256.Sum256([]byte(codeVerifier))
|
||||
codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
hashVal := sha256.Sum256([]byte(codeVerifier)) // Renamed from hash to hashVal
|
||||
codeChallenge := base64.RawURLEncoding.EncodeToString(hashVal[:])
|
||||
|
||||
// Store state, nonce, and code verifier in session for verification
|
||||
session, _ := c.Store.Get(r, c.SessionName)
|
||||
session.Values["state"] = state
|
||||
session.Values["nonce"] = nonce
|
||||
session.Values["code_verifier"] = codeVerifier
|
||||
if err := session.Save(r, w); err != nil {
|
||||
log.Printf("Error saving session: %v", err)
|
||||
log.Printf("RegistrationHandler: Error saving session: %v", err)
|
||||
http.Error(w, "Server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Build the registration URL using the specified registrations endpoint
|
||||
baseURL := viper.GetString("oidc-idp-issuer-url")
|
||||
registrationURL, err := url.Parse(baseURL + "/protocol/openid-connect/registrations")
|
||||
registrationEndpoint := viper.GetString("oidc-idp-issuer-url") + "/protocol/openid-connect/registrations"
|
||||
// For some providers, you might need to use the authorization endpoint with a specific parameter like `kc_action=register`
|
||||
// e.g., registrationEndpoint := c.OAuthConfig.Endpoint.AuthURL
|
||||
// And then add q.Set("kc_action", "register") or similar. This depends on the OIDC provider.
|
||||
|
||||
parsedRegistrationURL, err := url.Parse(registrationEndpoint)
|
||||
if err != nil {
|
||||
log.Printf("Error parsing registration URL: %v", err)
|
||||
log.Printf("RegistrationHandler: Error parsing registration URL '%s': %v", registrationEndpoint, err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Add query parameters
|
||||
q := registrationURL.Query()
|
||||
q := parsedRegistrationURL.Query()
|
||||
q.Set("client_id", viper.GetString("oidc-sp-client-id"))
|
||||
q.Set("response_type", "code")
|
||||
q.Set("scope", "openid email profile")
|
||||
@ -348,19 +410,17 @@ func (c *Config) RegistrationHandler(w http.ResponseWriter, r *http.Request) {
|
||||
q.Set("code_challenge", codeChallenge)
|
||||
q.Set("code_challenge_method", "S256")
|
||||
|
||||
registrationURL.RawQuery = q.Encode()
|
||||
parsedRegistrationURL.RawQuery = q.Encode()
|
||||
|
||||
// Log for debugging
|
||||
log.Printf("Redirecting to registration URL: %s", registrationURL.String())
|
||||
|
||||
// Redirect to registration page
|
||||
http.Redirect(w, r, registrationURL.String(), http.StatusFound)
|
||||
log.Printf("RegistrationHandler: Redirecting to registration URL: %s", parsedRegistrationURL.String())
|
||||
http.Redirect(w, r, parsedRegistrationURL.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
// Helper function to generate random strings
|
||||
func generateRandomString(n int) string {
|
||||
b := make([]byte, n)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
log.Fatalf("Failed to generate random string: %v", err)
|
||||
panic(err)
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b)
|
||||
|
49
internal/db/database.go
Normal file
49
internal/db/database.go
Normal file
@ -0,0 +1,49 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
_ "embed" // Required for go:embed
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
)
|
||||
|
||||
//go:embed schema.sql
|
||||
var ddl string
|
||||
|
||||
// DBConfig holds database configuration.
|
||||
type DBConfig struct {
|
||||
DSN string // Data Source Name for SQLite
|
||||
}
|
||||
|
||||
// NewDB initializes and returns a new database connection pool and runs migrations.
|
||||
func NewDB(ctx context.Context, logger *slog.Logger, dsn string) (*sql.DB, error) {
|
||||
// Ensure the directory for the SQLite file exists
|
||||
dbDir := filepath.Dir(dsn)
|
||||
if err := os.MkdirAll(dbDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create database directory %s: %w", dbDir, err)
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite3", dsn+"?_foreign_keys=on") // Enable foreign key constraints
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
if err = db.PingContext(ctx); err != nil {
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
logger.Info("database connection established", slog.String("dsn", dsn))
|
||||
|
||||
// Execute schema.
|
||||
if _, err := db.ExecContext(ctx, ddl); err != nil {
|
||||
return nil, fmt.Errorf("failed to execute DDL: %w", err)
|
||||
}
|
||||
logger.Info("database schema applied")
|
||||
|
||||
return db, nil
|
||||
}
|
@ -32,7 +32,7 @@ type Site struct {
|
||||
type User struct {
|
||||
ID int64 `json:"id"`
|
||||
OidcSubject string `json:"oidc_subject"`
|
||||
Name string `json:"name"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
@ -1,5 +1,5 @@
|
||||
-- name: CreateUser :one
|
||||
INSERT INTO users (oidc_subject, name, email)
|
||||
INSERT INTO users (oidc_subject, username, email)
|
||||
VALUES (?, ?, ?)
|
||||
RETURNING *;
|
||||
|
||||
@ -13,7 +13,7 @@ WHERE oidc_subject = ?;
|
||||
|
||||
-- name: UpdateUser :one
|
||||
UPDATE users
|
||||
SET name = ?, email = ?
|
||||
SET username = ?, email = ?
|
||||
WHERE id = ?
|
||||
RETURNING *;
|
||||
|
||||
@ -23,4 +23,4 @@ WHERE id = ?;
|
||||
|
||||
-- name: ListUsers :many
|
||||
SELECT * FROM users
|
||||
ORDER BY name;
|
||||
ORDER BY username;
|
@ -3,7 +3,7 @@
|
||||
CREATE TABLE users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
oidc_subject TEXT UNIQUE NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
username TEXT NOT NULL,
|
||||
email TEXT UNIQUE NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
|
@ -10,24 +10,24 @@ import (
|
||||
)
|
||||
|
||||
const createUser = `-- name: CreateUser :one
|
||||
INSERT INTO users (oidc_subject, name, email)
|
||||
INSERT INTO users (oidc_subject, username, email)
|
||||
VALUES (?, ?, ?)
|
||||
RETURNING id, oidc_subject, name, email, created_at, updated_at
|
||||
RETURNING id, oidc_subject, username, email, created_at, updated_at
|
||||
`
|
||||
|
||||
type CreateUserParams struct {
|
||||
OidcSubject string `json:"oidc_subject"`
|
||||
Name string `json:"name"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func (q *Queries) CreateUser(ctx context.Context, arg CreateUserParams) (User, error) {
|
||||
row := q.db.QueryRowContext(ctx, createUser, arg.OidcSubject, arg.Name, arg.Email)
|
||||
row := q.db.QueryRowContext(ctx, createUser, arg.OidcSubject, arg.Username, arg.Email)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.OidcSubject,
|
||||
&i.Name,
|
||||
&i.Username,
|
||||
&i.Email,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
@ -46,7 +46,7 @@ func (q *Queries) DeleteUser(ctx context.Context, id int64) error {
|
||||
}
|
||||
|
||||
const getUserByID = `-- name: GetUserByID :one
|
||||
SELECT id, oidc_subject, name, email, created_at, updated_at FROM users
|
||||
SELECT id, oidc_subject, username, email, created_at, updated_at FROM users
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
@ -56,7 +56,7 @@ func (q *Queries) GetUserByID(ctx context.Context, id int64) (User, error) {
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.OidcSubject,
|
||||
&i.Name,
|
||||
&i.Username,
|
||||
&i.Email,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
@ -65,7 +65,7 @@ func (q *Queries) GetUserByID(ctx context.Context, id int64) (User, error) {
|
||||
}
|
||||
|
||||
const getUserByOIDCSubject = `-- name: GetUserByOIDCSubject :one
|
||||
SELECT id, oidc_subject, name, email, created_at, updated_at FROM users
|
||||
SELECT id, oidc_subject, username, email, created_at, updated_at FROM users
|
||||
WHERE oidc_subject = ?
|
||||
`
|
||||
|
||||
@ -75,7 +75,7 @@ func (q *Queries) GetUserByOIDCSubject(ctx context.Context, oidcSubject string)
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.OidcSubject,
|
||||
&i.Name,
|
||||
&i.Username,
|
||||
&i.Email,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
@ -84,8 +84,8 @@ func (q *Queries) GetUserByOIDCSubject(ctx context.Context, oidcSubject string)
|
||||
}
|
||||
|
||||
const listUsers = `-- name: ListUsers :many
|
||||
SELECT id, oidc_subject, name, email, created_at, updated_at FROM users
|
||||
ORDER BY name
|
||||
SELECT id, oidc_subject, username, email, created_at, updated_at FROM users
|
||||
ORDER BY username
|
||||
`
|
||||
|
||||
func (q *Queries) ListUsers(ctx context.Context) ([]User, error) {
|
||||
@ -100,7 +100,7 @@ func (q *Queries) ListUsers(ctx context.Context) ([]User, error) {
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.OidcSubject,
|
||||
&i.Name,
|
||||
&i.Username,
|
||||
&i.Email,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
@ -120,24 +120,24 @@ func (q *Queries) ListUsers(ctx context.Context) ([]User, error) {
|
||||
|
||||
const updateUser = `-- name: UpdateUser :one
|
||||
UPDATE users
|
||||
SET name = ?, email = ?
|
||||
SET username = ?, email = ?
|
||||
WHERE id = ?
|
||||
RETURNING id, oidc_subject, name, email, created_at, updated_at
|
||||
RETURNING id, oidc_subject, username, email, created_at, updated_at
|
||||
`
|
||||
|
||||
type UpdateUserParams struct {
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
ID int64 `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
ID int64 `json:"id"`
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateUser(ctx context.Context, arg UpdateUserParams) (User, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateUser, arg.Name, arg.Email, arg.ID)
|
||||
row := q.db.QueryRowContext(ctx, updateUser, arg.Username, arg.Email, arg.ID)
|
||||
var i User
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.OidcSubject,
|
||||
&i.Name,
|
||||
&i.Username,
|
||||
&i.Email,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"git.coopcloud.tech/wiki-cafe/member-console/internal/auth"
|
||||
"git.coopcloud.tech/wiki-cafe/member-console/internal/db"
|
||||
"git.coopcloud.tech/wiki-cafe/member-console/internal/embeds"
|
||||
"git.coopcloud.tech/wiki-cafe/member-console/internal/middleware"
|
||||
"github.com/rs/cors"
|
||||
@ -22,6 +23,7 @@ type Config struct {
|
||||
Env string
|
||||
CSRFSecret string
|
||||
Logger *slog.Logger
|
||||
DB db.Querier
|
||||
}
|
||||
|
||||
// Start initializes and starts the HTTP server.
|
||||
@ -29,8 +31,8 @@ func Start(ctx context.Context, cfg Config) error {
|
||||
// Create a new HTTP request router
|
||||
httpRequestRouter := http.NewServeMux()
|
||||
|
||||
// Set up authentication
|
||||
authConfig, err := auth.Setup()
|
||||
// Set up authentication. Pass the database connection to the auth package.
|
||||
authConfig, err := auth.Setup(cfg.DB)
|
||||
if err != nil {
|
||||
cfg.Logger.Error("failed to set up authentication", slog.Any("error", err))
|
||||
return err
|
||||
|
Loading…
x
Reference in New Issue
Block a user