368 lines
11 KiB
Go
368 lines
11 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"net/url"
|
|
|
|
"github.com/coreos/go-oidc/v3/oidc"
|
|
"github.com/gorilla/sessions"
|
|
"github.com/spf13/viper"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
// Config holds all auth-related configuration
|
|
type Config struct {
|
|
Store *sessions.CookieStore
|
|
OAuthConfig *oauth2.Config
|
|
Verifier *oidc.IDTokenVerifier
|
|
Provider *oidc.Provider
|
|
SessionName string
|
|
}
|
|
|
|
// Setup initializes the auth configuration
|
|
func Setup() (*Config, error) {
|
|
// Create cookie store
|
|
store := sessions.NewCookieStore(
|
|
[]byte(viper.GetString("session-secret")),
|
|
)
|
|
store.Options = &sessions.Options{
|
|
HttpOnly: true,
|
|
Secure: viper.GetString("env") == "production",
|
|
SameSite: http.SameSiteLaxMode,
|
|
MaxAge: 86400 * 7, // 1 week
|
|
}
|
|
|
|
// Initialize OIDC provider
|
|
ctx := context.Background()
|
|
provider, err := oidc.NewProvider(ctx, viper.GetString("oidc-idp-issuer-url"))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err)
|
|
}
|
|
|
|
// Create OAuth2 config
|
|
oauthConfig := &oauth2.Config{
|
|
ClientID: viper.GetString("oidc-sp-client-id"),
|
|
ClientSecret: viper.GetString("oidc-sp-client-secret"),
|
|
RedirectURL: viper.GetString("hostname") + "/callback",
|
|
Endpoint: provider.Endpoint(),
|
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
|
}
|
|
|
|
// Create auth config
|
|
config := &Config{
|
|
Store: store,
|
|
OAuthConfig: oauthConfig,
|
|
Provider: provider,
|
|
Verifier: provider.Verifier(&oidc.Config{ClientID: oauthConfig.ClientID}),
|
|
SessionName: "auth-session",
|
|
}
|
|
|
|
return config, nil
|
|
}
|
|
|
|
// RegisterHandlers adds all auth-related handlers to the router
|
|
func (c *Config) RegisterHandlers(mux *http.ServeMux) {
|
|
mux.HandleFunc("/login", c.LoginHandler)
|
|
mux.HandleFunc("/callback", c.CallbackHandler)
|
|
mux.HandleFunc("/logout", c.LogoutHandler)
|
|
mux.HandleFunc("/logout-callback", c.LogoutCallbackHandler)
|
|
mux.HandleFunc("/register", c.RegistrationHandler)
|
|
}
|
|
|
|
// Middleware returns an auth middleware function
|
|
func (c *Config) Middleware() func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Skip authentication for public routes
|
|
if r.URL.Path == "/login" || r.URL.Path == "/callback" ||
|
|
r.URL.Path == "/logout" || r.URL.Path == "/logout-callback" ||
|
|
r.URL.Path == "/register" {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
session, err := c.Store.Get(r, c.SessionName)
|
|
if err != nil {
|
|
http.Redirect(w, r, "/login", http.StatusFound)
|
|
return
|
|
}
|
|
|
|
if auth, ok := session.Values["authenticated"].(bool); !ok || !auth {
|
|
http.Redirect(w, r, "/login", http.StatusFound)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// LoginHandler initiates the OIDC authentication flow
|
|
func (c *Config) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
|
state := generateRandomString(32)
|
|
nonce := generateRandomString(32)
|
|
codeVerifier := generateRandomString(32)
|
|
hash := sha256.Sum256([]byte(codeVerifier))
|
|
codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
|
|
|
session, _ := c.Store.Get(r, c.SessionName)
|
|
session.Values["state"] = state
|
|
session.Values["nonce"] = nonce
|
|
session.Values["code_verifier"] = codeVerifier
|
|
if err := session.Save(r, w); err != nil {
|
|
log.Printf("Error saving session: %v", err)
|
|
http.Error(w, "Server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
authURL := c.OAuthConfig.AuthCodeURL(
|
|
state,
|
|
oidc.Nonce(nonce),
|
|
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
|
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
|
)
|
|
http.Redirect(w, r, authURL, http.StatusFound)
|
|
}
|
|
|
|
// CallbackHandler processes the OIDC callback
|
|
func (c *Config) CallbackHandler(w http.ResponseWriter, r *http.Request) {
|
|
ctx := context.Background()
|
|
session, err := c.Store.Get(r, c.SessionName)
|
|
if err != nil {
|
|
http.Error(w, "Session error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Validate state
|
|
state := r.URL.Query().Get("state")
|
|
if savedState, ok := session.Values["state"].(string); !ok || state != savedState {
|
|
log.Printf("State mismatch: got %s, expected %s", state, savedState)
|
|
http.Error(w, "Invalid state", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Exchange code
|
|
code := r.URL.Query().Get("code")
|
|
codeVerifier, ok := session.Values["code_verifier"].(string)
|
|
if !ok {
|
|
http.Error(w, "No code verifier found", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
token, err := c.OAuthConfig.Exchange(
|
|
ctx,
|
|
code,
|
|
oauth2.VerifierOption(codeVerifier),
|
|
)
|
|
if err != nil {
|
|
log.Printf("Token exchange failed: %v", err)
|
|
http.Error(w, "Token exchange failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Verify ID token
|
|
rawIDToken, ok := token.Extra("id_token").(string)
|
|
if !ok {
|
|
http.Error(w, "Missing ID token", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
idToken, err := c.Verifier.Verify(ctx, rawIDToken)
|
|
if err != nil {
|
|
log.Printf("ID token verification failed: %v", err)
|
|
http.Error(w, "Invalid ID token", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Verify nonce
|
|
var claims struct {
|
|
Nonce string `json:"nonce"`
|
|
}
|
|
if err := idToken.Claims(&claims); err != nil {
|
|
http.Error(w, "Could not parse claims", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if claims.Nonce != session.Values["nonce"].(string) {
|
|
http.Error(w, "Invalid nonce", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Authenticate session
|
|
session.Values["authenticated"] = true
|
|
session.Values["id_token"] = rawIDToken
|
|
|
|
// Store user info
|
|
var userInfo struct {
|
|
Email string `json:"email"`
|
|
Name string `json:"name"`
|
|
Username string `json:"preferred_username"`
|
|
}
|
|
if err := idToken.Claims(&userInfo); err == nil {
|
|
session.Values["email"] = userInfo.Email
|
|
session.Values["name"] = userInfo.Name
|
|
session.Values["username"] = userInfo.Username
|
|
}
|
|
|
|
if err := session.Save(r, w); err != nil {
|
|
log.Printf("Error saving session: %v", err)
|
|
http.Error(w, "Server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
http.Redirect(w, r, "/", http.StatusFound)
|
|
}
|
|
|
|
// LogoutHandler handles user logout
|
|
func (c *Config) LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
|
// Generate logout state for verification
|
|
state := generateRandomString(32)
|
|
|
|
// Get session
|
|
session, err := c.Store.Get(r, c.SessionName)
|
|
if err != nil {
|
|
http.Error(w, "Session error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Store logout state for verification
|
|
session.Values["logout_state"] = state
|
|
if err := session.Save(r, w); err != nil {
|
|
log.Printf("Error saving logout state: %v", err)
|
|
http.Error(w, "Server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Build logout URL
|
|
keycloakLogoutURL, err := url.Parse(viper.GetString("oidc-idp-issuer-url") + "/protocol/openid-connect/logout")
|
|
if err != nil {
|
|
log.Printf("Error parsing logout URL: %v", err)
|
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Add query parameters
|
|
q := keycloakLogoutURL.Query()
|
|
// Use logout-callback for completing the logout flow
|
|
q.Set("post_logout_redirect_uri", viper.GetString("hostname")+"/logout-callback")
|
|
q.Set("client_id", viper.GetString("oidc-sp-client-id"))
|
|
q.Set("state", state)
|
|
|
|
// Add id_token_hint if available
|
|
if idToken, ok := session.Values["id_token"].(string); ok {
|
|
q.Set("id_token_hint", idToken)
|
|
}
|
|
|
|
keycloakLogoutURL.RawQuery = q.Encode()
|
|
|
|
// Log for debugging
|
|
log.Printf("Redirecting to logout URL: %s", keycloakLogoutURL.String())
|
|
|
|
// Redirect to Keycloak
|
|
http.Redirect(w, r, keycloakLogoutURL.String(), http.StatusFound)
|
|
}
|
|
|
|
// LogoutCallbackHandler handles the redirect after Keycloak logout
|
|
func (c *Config) LogoutCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
|
// Get session
|
|
session, err := c.Store.Get(r, c.SessionName)
|
|
if err != nil {
|
|
http.Error(w, "Session error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Verify state parameter
|
|
returnedState := r.URL.Query().Get("state")
|
|
log.Printf("Received logout state: %s", returnedState)
|
|
|
|
savedState, ok := session.Values["logout_state"].(string)
|
|
if !ok {
|
|
log.Printf("No logout state found in session")
|
|
http.Error(w, "Invalid session state", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
log.Printf("Saved logout state: %s", savedState)
|
|
|
|
if returnedState != savedState {
|
|
log.Printf("State mismatch: %s != %s", returnedState, savedState)
|
|
http.Error(w, "Invalid state", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Clear the session completely
|
|
session.Options.MaxAge = -1
|
|
if err := session.Save(r, w); err != nil {
|
|
log.Printf("Error clearing session: %v", err)
|
|
http.Error(w, "Server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Redirect to login page
|
|
http.Redirect(w, r, "/login", http.StatusFound)
|
|
}
|
|
|
|
// RegistrationHandler redirects to the OIDC registration page
|
|
func (c *Config) RegistrationHandler(w http.ResponseWriter, r *http.Request) {
|
|
// Generate random state, nonce, and code verifier for security
|
|
state := generateRandomString(32)
|
|
nonce := generateRandomString(32)
|
|
codeVerifier := generateRandomString(32)
|
|
hash := sha256.Sum256([]byte(codeVerifier))
|
|
codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
|
|
|
// Store state, nonce, and code verifier in session for verification
|
|
session, _ := c.Store.Get(r, c.SessionName)
|
|
session.Values["state"] = state
|
|
session.Values["nonce"] = nonce
|
|
session.Values["code_verifier"] = codeVerifier
|
|
if err := session.Save(r, w); err != nil {
|
|
log.Printf("Error saving session: %v", err)
|
|
http.Error(w, "Server error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Build the registration URL using the specified registrations endpoint
|
|
baseURL := viper.GetString("oidc-idp-issuer-url")
|
|
registrationURL, err := url.Parse(baseURL + "/protocol/openid-connect/registrations")
|
|
if err != nil {
|
|
log.Printf("Error parsing registration URL: %v", err)
|
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Add query parameters
|
|
q := registrationURL.Query()
|
|
q.Set("client_id", viper.GetString("oidc-sp-client-id"))
|
|
q.Set("response_type", "code")
|
|
q.Set("scope", "openid email profile")
|
|
q.Set("redirect_uri", viper.GetString("hostname")+"/callback")
|
|
q.Set("state", state)
|
|
q.Set("nonce", nonce)
|
|
q.Set("code_challenge", codeChallenge)
|
|
q.Set("code_challenge_method", "S256")
|
|
|
|
registrationURL.RawQuery = q.Encode()
|
|
|
|
// Log for debugging
|
|
log.Printf("Redirecting to registration URL: %s", registrationURL.String())
|
|
|
|
// Redirect to registration page
|
|
http.Redirect(w, r, registrationURL.String(), http.StatusFound)
|
|
}
|
|
|
|
// generateRandomString reads n bytes from crypto/rand and returns them
// encoded as unpadded base64url. The result is a URL-safe string of
// ceil(4n/3) characters, suitable for state, nonce and PKCE verifier values.
// It panics if the system random source fails, because security tokens have
// no safe fallback.
func generateRandomString(n int) string {
	raw := make([]byte, n)
	if _, readErr := rand.Read(raw); readErr != nil {
		panic(readErr)
	}
	encoded := base64.RawURLEncoding.EncodeToString(raw)
	return encoded
}
|