123 lines
3.2 KiB
Go
123 lines
3.2 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"net/http"
|
|
|
|
"github.com/coreos/go-oidc/v3/oidc"
|
|
"github.com/gorilla/sessions"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
type AuthConfig struct {
|
|
Store *sessions.CookieStore
|
|
OAuthConfig *oauth2.Config
|
|
Verifier *oidc.IDTokenVerifier
|
|
}
|
|
|
|
func AuthMiddleware(config *AuthConfig) Middleware {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Skip authentication for public routes
|
|
if r.URL.Path == "/login" || r.URL.Path == "/callback" {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
session, _ := config.Store.Get(r, "auth-session")
|
|
if auth, ok := session.Values["authenticated"].(bool); !ok || !auth {
|
|
http.Redirect(w, r, "/login", http.StatusFound)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// generateRandomString returns n bytes read from crypto/rand, encoded as
// unpadded base64url (RFC 4648 §5). The result only contains the characters
// [A-Za-z0-9_-], so it is safe in URLs and cookies without further escaping.
// Its length is ceil(4n/3) characters (43 for n = 32).
//
// It panics if the secure random source fails. The values produced here
// (OAuth state, OIDC nonce, PKCE verifier) protect the login flow, and
// continuing with an unfilled buffer would make them predictable. Since
// Go 1.24, crypto/rand treats such a failure as fatal itself.
func generateRandomString(n int) string {
	b := make([]byte, n)
	if _, err := rand.Read(b); err != nil {
		panic("middleware: crypto/rand failed: " + err.Error())
	}
	return base64.RawURLEncoding.EncodeToString(b)
}
|
|
|
|
func LoginHandler(config *AuthConfig) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
state := generateRandomString(32)
|
|
nonce := generateRandomString(32)
|
|
codeVerifier := generateRandomString(32)
|
|
hash := sha256.Sum256([]byte(codeVerifier))
|
|
codeChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
|
|
|
session, _ := config.Store.Get(r, "auth-session")
|
|
session.Values["state"] = state
|
|
session.Values["nonce"] = nonce
|
|
session.Values["code_verifier"] = codeVerifier
|
|
session.Save(r, w)
|
|
|
|
authURL := config.OAuthConfig.AuthCodeURL(
|
|
state,
|
|
oidc.Nonce(nonce),
|
|
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
|
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
|
)
|
|
http.Redirect(w, r, authURL, http.StatusFound)
|
|
}
|
|
}
|
|
|
|
func CallbackHandler(config *AuthConfig) http.HandlerFunc {
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
ctx := context.Background()
|
|
session, _ := config.Store.Get(r, "auth-session")
|
|
|
|
// Validate state
|
|
state := r.URL.Query().Get("state")
|
|
if savedState, ok := session.Values["state"].(string); !ok || state != savedState {
|
|
http.Error(w, "Invalid state", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Exchange code
|
|
codeVerifier := session.Values["code_verifier"].(string)
|
|
token, err := config.OAuthConfig.Exchange(
|
|
ctx,
|
|
r.URL.Query().Get("code"),
|
|
oauth2.VerifierOption(codeVerifier),
|
|
)
|
|
if err != nil {
|
|
http.Error(w, "Token exchange failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Verify ID token
|
|
rawIDToken, ok := token.Extra("id_token").(string)
|
|
if !ok {
|
|
http.Error(w, "Missing ID token", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
idToken, err := config.Verifier.Verify(ctx, rawIDToken)
|
|
if err != nil {
|
|
http.Error(w, "Invalid ID token", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Verify nonce
|
|
var claims struct {
|
|
Nonce string `json:"nonce"`
|
|
}
|
|
if err := idToken.Claims(&claims); err != nil || claims.Nonce != session.Values["nonce"].(string) {
|
|
http.Error(w, "Invalid nonce", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Authenticate session
|
|
session.Values["authenticated"] = true
|
|
session.Save(r, w)
|
|
http.Redirect(w, r, "/", http.StatusFound)
|
|
}
|
|
}
|