Files
member-console/internal/workflows/stripe/webhook_subscription.go
Christian Galo 667e9ffe24 Add plan ladders and pool provision transitions
Introduce DB migrations for ladder and pool-attachment tables and an
audit log for provision transitions. Make product_type nullable and add
lifecycle_status plus a product_kinds view. Implement Transition and
ReapplyDefaultsForPool primitives, SQLC queries/models, webhook and
Temporal workflow integration, and accompanying unit/integration tests.
2026-04-19 20:45:56 -05:00

604 lines
21 KiB
Go

package stripe
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"time"

	"git.coopcloud.tech/wiki-cafe/member-console/internal/billing"
	"git.coopcloud.tech/wiki-cafe/member-console/internal/entitlements"
	internalstripe "git.coopcloud.tech/wiki-cafe/member-console/internal/stripe"
	"github.com/google/uuid"
)
// Payload shapes decoded from the scrubbed checkout.session.completed webhook
// body. JSON fields not listed here are ignored by encoding/json.
type (
	// webhookCheckoutSessionPayload is the scrubbed checkout session object.
	// Subscription holds the Stripe subscription ID when Mode is
	// "subscription"; Metadata may carry a "billing_account_id" override.
	webhookCheckoutSessionPayload struct {
		ID           string                    `json:"id"`
		Mode         string                    `json:"mode"`
		Customer     string                    `json:"customer"`
		Subscription string                    `json:"subscription"`
		Metadata     map[string]string         `json:"metadata"`
		LineItems    []checkoutSessionLineItem `json:"line_items"`
	}

	// checkoutSessionLineItem is one purchased line of a checkout session,
	// identified by its Stripe price ID.
	checkoutSessionLineItem struct {
		PriceID  string `json:"price_id"`
		Quantity int64  `json:"quantity"`
	}
)
// Payload shapes decoded from scrubbed customer.subscription.* webhook bodies.
// JSON fields not listed here are ignored by encoding/json.
type (
	// webhookSubscriptionPayload is the scrubbed Stripe subscription object.
	// Timestamps are Unix seconds (the handlers convert them with time.Unix);
	// CanceledAt and EndedAt are nil when the field is absent or null.
	webhookSubscriptionPayload struct {
		ID                 string                    `json:"id"`
		Customer           string                    `json:"customer"`
		Status             string                    `json:"status"`
		CurrentPeriodStart int64                     `json:"current_period_start"`
		CurrentPeriodEnd   int64                     `json:"current_period_end"`
		CancelAtPeriodEnd  bool                      `json:"cancel_at_period_end"`
		CanceledAt         *int64                    `json:"canceled_at"`
		EndedAt            *int64                    `json:"ended_at"`
		Items              []subscriptionPayloadItem `json:"items"`
		Metadata           map[string]string         `json:"metadata"`
	}

	// subscriptionPayloadItem is one item of a Stripe subscription. ID is the
	// Stripe subscription item ID and may be empty (e.g. when synthesized from
	// a checkout session).
	subscriptionPayloadItem struct {
		ID       string `json:"id"`
		PriceID  string `json:"price_id"`
		Quantity int64  `json:"quantity"`
	}
)
// handleCheckoutSessionEvent processes checkout.session.completed events.
//
// Other checkout sub-events return "skipped". Non-subscription-mode sessions
// are acknowledged ("completed") without side effects. For subscription-mode
// sessions the subscription is fulfilled unless a subscription mapping for the
// Stripe subscription ID already exists (idempotent redelivery).
//
// The billing account comes from the session's "billing_account_id" metadata
// when present, otherwise from the Stripe customer mapping. The subscription
// payload handed to fulfillSubscription is synthesized from the session: its
// status is "active", it has no period timestamps, and its items carry no
// Stripe subscription item IDs.
func (a *WebhookActivities) handleCheckoutSessionEvent(ctx context.Context, evt WebhookEvent) (string, error) {
	if evt.EventType != "checkout.session.completed" {
		a.Logger.Info("checkout sub-event skipped", slog.String("event_type", evt.EventType))
		return "skipped", nil
	}

	payload, err := a.readPayload(ctx, evt.ID)
	if err != nil {
		return "", err
	}
	var session webhookCheckoutSessionPayload
	if err := json.Unmarshal(payload, &session); err != nil {
		return "", fmt.Errorf("parse checkout session payload: %w", err)
	}

	if session.Mode != "subscription" {
		a.Logger.Info("checkout.session.completed skipped (not subscription mode)",
			slog.String("mode", session.Mode))
		return "completed", nil
	}
	if session.Subscription == "" {
		return "", fmt.Errorf("checkout.session.completed missing subscription id")
	}

	// Idempotency: an existing mapping means this subscription was already
	// fulfilled (by this event or by customer.subscription.created).
	stripeQ := internalstripe.New(a.DB)
	_, err = stripeQ.GetSubscriptionMappingByStripeID(ctx, sql.NullString{String: session.Subscription, Valid: true})
	switch {
	case err == nil:
		a.Logger.Info("checkout.session.completed: subscription mapping already exists, skipping",
			slog.String("stripe_subscription_id", session.Subscription))
		return "completed", nil
	case !errors.Is(err, sql.ErrNoRows):
		return "", fmt.Errorf("check existing subscription mapping: %w", err)
	}

	// Resolve billing_account_id from metadata or, failing that, the customer mapping.
	billingAccountID := session.Metadata["billing_account_id"]
	if billingAccountID == "" {
		// Without a customer ID the mapping lookup cannot succeed; fail with a
		// descriptive error instead of querying for an empty customer ID.
		if session.Customer == "" {
			return "", fmt.Errorf("checkout.session.completed %s: no billing_account_id metadata and no customer", session.ID)
		}
		custMapping, err := stripeQ.GetCustomerMappingByStripeCustomerID(ctx, sql.NullString{String: session.Customer, Valid: true})
		if err != nil {
			return "", fmt.Errorf("resolve billing account from customer %s: %w", session.Customer, err)
		}
		billingAccountID = custMapping.BillingAccountID
	}

	// Build a subscription payload from the session for fulfillment.
	subPayload := webhookSubscriptionPayload{
		ID:       session.Subscription,
		Customer: session.Customer,
		Status:   "active",
		Metadata: session.Metadata,
	}
	for _, li := range session.LineItems {
		subPayload.Items = append(subPayload.Items, subscriptionPayloadItem{
			PriceID:  li.PriceID,
			Quantity: li.Quantity,
		})
	}

	if err := a.fulfillSubscription(ctx, billingAccountID, subPayload, evt.ProviderEventID); err != nil {
		return "", fmt.Errorf("fulfill subscription: %w", err)
	}

	a.Logger.Info("checkout.session.completed processed",
		slog.String("stripe_subscription_id", session.Subscription),
		slog.String("billing_account_id", billingAccountID))
	return "completed", nil
}
// handleSubscriptionEvent decodes a customer.subscription.* webhook payload and
// routes it to the created/updated/deleted handler matching the event type.
//
// A payload without a subscription ID is rejected before routing. Any other
// subscription sub-event is logged and reported as "skipped".
func (a *WebhookActivities) handleSubscriptionEvent(ctx context.Context, evt WebhookEvent) (string, error) {
	raw, err := a.readPayload(ctx, evt.ID)
	if err != nil {
		return "", err
	}

	var sub webhookSubscriptionPayload
	if err = json.Unmarshal(raw, &sub); err != nil {
		return "", fmt.Errorf("parse subscription payload: %w", err)
	}
	if len(sub.ID) == 0 {
		return "", fmt.Errorf("subscription payload missing id")
	}

	handlers := map[string]func(context.Context, WebhookEvent, webhookSubscriptionPayload, *internalstripe.Queries) (string, error){
		"customer.subscription.created": a.handleSubscriptionCreated,
		"customer.subscription.updated": a.handleSubscriptionUpdated,
		"customer.subscription.deleted": a.handleSubscriptionDeleted,
	}
	handle, known := handlers[evt.EventType]
	if !known {
		a.Logger.Info("subscription sub-event skipped", slog.String("event_type", evt.EventType))
		return "skipped", nil
	}
	return handle(ctx, evt, sub, internalstripe.New(a.DB))
}
// handleSubscriptionCreated fulfills a customer.subscription.created event.
//
// It is idempotent: when a subscription mapping for sub.ID already exists
// (e.g. checkout.session.completed already fulfilled it) the event is
// acknowledged without changes. Otherwise the billing account is resolved from
// the Stripe customer mapping and the subscription is fulfilled.
func (a *WebhookActivities) handleSubscriptionCreated(ctx context.Context, evt WebhookEvent, sub webhookSubscriptionPayload, stripeQ *internalstripe.Queries) (string, error) {
	_, err := stripeQ.GetSubscriptionMappingByStripeID(ctx, sql.NullString{String: sub.ID, Valid: true})
	switch {
	case err == nil:
		a.Logger.Info("customer.subscription.created: mapping exists, skipping",
			slog.String("stripe_subscription_id", sub.ID))
		return "completed", nil
	case !errors.Is(err, sql.ErrNoRows):
		return "", fmt.Errorf("check subscription mapping: %w", err)
	}

	// The billing account can only be found through the customer mapping, so
	// an empty customer ID can never resolve; report it explicitly.
	if sub.Customer == "" {
		return "", fmt.Errorf("customer.subscription.created %s missing customer id", sub.ID)
	}
	custMapping, err := stripeQ.GetCustomerMappingByStripeCustomerID(ctx, sql.NullString{String: sub.Customer, Valid: true})
	if err != nil {
		return "", fmt.Errorf("resolve billing account from customer %s: %w", sub.Customer, err)
	}

	if err := a.fulfillSubscription(ctx, custMapping.BillingAccountID, sub, evt.ProviderEventID); err != nil {
		return "", fmt.Errorf("fulfill subscription: %w", err)
	}

	a.Logger.Info("customer.subscription.created processed",
		slog.String("stripe_subscription_id", sub.ID))
	return "completed", nil
}
func (a *WebhookActivities) handleSubscriptionUpdated(ctx context.Context, evt WebhookEvent, sub webhookSubscriptionPayload, stripeQ *internalstripe.Queries) (string, error) {
// Look up existing mapping
mapping, err := stripeQ.GetSubscriptionMappingByStripeID(ctx, sql.NullString{String: sub.ID, Valid: true})
if err == sql.ErrNoRows {
a.Logger.Warn("customer.subscription.updated: no mapping found",
slog.String("stripe_subscription_id", sub.ID))
return "completed", nil
}
if err != nil {
return "", fmt.Errorf("lookup subscription mapping: %w", err)
}
billingQ := billing.New(a.DB)
existing, err := billingQ.GetSubscriptionByID(ctx, mapping.SubscriptionID)
if err != nil {
return "", fmt.Errorf("get subscription: %w", err)
}
// Update subscription fields
periodStart := sql.NullTime{}
if sub.CurrentPeriodStart > 0 {
periodStart = sql.NullTime{Time: time.Unix(sub.CurrentPeriodStart, 0), Valid: true}
}
periodEnd := sql.NullTime{}
if sub.CurrentPeriodEnd > 0 {
periodEnd = sql.NullTime{Time: time.Unix(sub.CurrentPeriodEnd, 0), Valid: true}
}
canceledAt := sql.NullTime{}
if sub.CanceledAt != nil {
canceledAt = sql.NullTime{Time: time.Unix(*sub.CanceledAt, 0), Valid: true}
}
endedAt := sql.NullTime{}
if sub.EndedAt != nil {
endedAt = sql.NullTime{Time: time.Unix(*sub.EndedAt, 0), Valid: true}
}
_, err = billingQ.UpdateSubscriptionStatus(ctx, billing.UpdateSubscriptionStatusParams{
SubscriptionID: existing.SubscriptionID,
Status: sub.Status,
CurrentPeriodStart: periodStart,
CurrentPeriodEnd: periodEnd,
CancelAtPeriodEnd: sub.CancelAtPeriodEnd,
CanceledAt: canceledAt,
EndedAt: endedAt,
})
if err != nil {
return "", fmt.Errorf("update subscription status: %w", err)
}
// Record status change if different
if sub.Status != existing.Status {
_, err = billingQ.CreateSubscriptionChange(ctx, billing.CreateSubscriptionChangeParams{
SubscriptionID: existing.SubscriptionID,
PreviousStatus: sql.NullString{String: existing.Status, Valid: true},
NewStatus: sub.Status,
StripeEventID: evt.ProviderEventID,
})
if err != nil {
return "", fmt.Errorf("record subscription change: %w", err)
}
// Adjust provision status based on subscription status
if err := a.adjustProvisionStatus(ctx, existing.SubscriptionID, sub.Status, evt.ProviderEventID); err != nil {
return "", fmt.Errorf("adjust provision status: %w", err)
}
}
a.Logger.Info("customer.subscription.updated processed",
slog.String("stripe_subscription_id", sub.ID),
slog.String("new_status", sub.Status))
return "completed", nil
}
// handleSubscriptionDeleted applies a customer.subscription.deleted event.
//
// It marks the mapped core subscription "canceled" and records the status
// change. These two writes share a transaction. Then it ends the pool
// provision (rematerializing entitlements) and finally marks the Stripe
// mapping "deleted". The mapping is marked last so a failed provision step is
// retried with the mapping still resolvable.
//
// Timestamps:
//   - canceled_at comes from the payload when present, otherwise the stored value.
//   - ended_at comes from the payload when present, otherwise the current time.
//
// Redelivery is safe: no canceled→canceled change row is written when the
// subscription is already canceled, and the remaining steps are repeatable.
// Events for subscriptions with no mapping are logged and acknowledged.
func (a *WebhookActivities) handleSubscriptionDeleted(ctx context.Context, evt WebhookEvent, sub webhookSubscriptionPayload, stripeQ *internalstripe.Queries) (string, error) {
	mapping, err := stripeQ.GetSubscriptionMappingByStripeID(ctx, sql.NullString{String: sub.ID, Valid: true})
	if errors.Is(err, sql.ErrNoRows) {
		a.Logger.Warn("customer.subscription.deleted: no mapping found",
			slog.String("stripe_subscription_id", sub.ID))
		return "completed", nil
	}
	if err != nil {
		return "", fmt.Errorf("lookup subscription mapping: %w", err)
	}

	tx, err := a.DB.BeginTx(ctx, nil)
	if err != nil {
		return "", fmt.Errorf("begin tx: %w", err)
	}
	defer tx.Rollback() // no-op once Commit has succeeded
	billingQ := billing.New(tx)

	existing, err := billingQ.GetSubscriptionByID(ctx, mapping.SubscriptionID)
	if err != nil {
		return "", fmt.Errorf("get subscription: %w", err)
	}

	// Prefer Stripe's cancellation time. The stored value may be NULL when
	// the subscription was canceled immediately without a prior update event.
	canceledAt := existing.CanceledAt
	if sub.CanceledAt != nil {
		canceledAt = sql.NullTime{Time: time.Unix(*sub.CanceledAt, 0), Valid: true}
	}
	endedAt := sql.NullTime{Time: time.Now(), Valid: true}
	if sub.EndedAt != nil {
		endedAt = sql.NullTime{Time: time.Unix(*sub.EndedAt, 0), Valid: true}
	}

	_, err = billingQ.UpdateSubscriptionStatus(ctx, billing.UpdateSubscriptionStatusParams{
		SubscriptionID:     existing.SubscriptionID,
		Status:             "canceled",
		CurrentPeriodStart: existing.CurrentPeriodStart,
		CurrentPeriodEnd:   existing.CurrentPeriodEnd,
		CancelAtPeriodEnd:  existing.CancelAtPeriodEnd,
		CanceledAt:         canceledAt,
		EndedAt:            endedAt,
	})
	if err != nil {
		return "", fmt.Errorf("update subscription: %w", err)
	}

	// Record the transition only if the status actually changes. This keeps
	// redelivered events from appending canceled→canceled rows.
	if existing.Status != "canceled" {
		_, err = billingQ.CreateSubscriptionChange(ctx, billing.CreateSubscriptionChangeParams{
			SubscriptionID: existing.SubscriptionID,
			PreviousStatus: sql.NullString{String: existing.Status, Valid: true},
			NewStatus:      "canceled",
			StripeEventID:  evt.ProviderEventID,
		})
		if err != nil {
			return "", fmt.Errorf("record subscription change: %w", err)
		}
	}

	if err := tx.Commit(); err != nil {
		return "", fmt.Errorf("commit subscription cancellation: %w", err)
	}

	// End provision and rematerialize.
	if err := a.adjustProvisionStatus(ctx, existing.SubscriptionID, "canceled", evt.ProviderEventID); err != nil {
		return "", fmt.Errorf("end provision: %w", err)
	}

	if err := stripeQ.UpdateSubscriptionMappingSyncStatus(ctx, internalstripe.UpdateSubscriptionMappingSyncStatusParams{
		SubscriptionID: existing.SubscriptionID,
		SyncStatus:     "deleted",
	}); err != nil {
		return "", fmt.Errorf("mark mapping deleted: %w", err)
	}

	a.Logger.Info("customer.subscription.deleted processed",
		slog.String("stripe_subscription_id", sub.ID))
	return "completed", nil
}
// fulfillSubscription creates core billing records, mappings, and provisions
// within a single transaction.
func (a *WebhookActivities) fulfillSubscription(ctx context.Context, billingAccountID string, sub webhookSubscriptionPayload, stripeEventID string) error {
tx, err := a.DB.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin tx: %w", err)
}
defer tx.Rollback()
billingQ := billing.New(tx)
stripeQ := internalstripe.New(tx)
entQ := entitlements.New(tx)
// Resolve period timestamps
periodStart := sql.NullTime{}
if sub.CurrentPeriodStart > 0 {
periodStart = sql.NullTime{Time: time.Unix(sub.CurrentPeriodStart, 0), Valid: true}
}
periodEnd := sql.NullTime{}
if sub.CurrentPeriodEnd > 0 {
periodEnd = sql.NullTime{Time: time.Unix(sub.CurrentPeriodEnd, 0), Valid: true}
}
// 1. Create billing.subscriptions record
subscription, err := billingQ.CreateSubscription(ctx, billing.CreateSubscriptionParams{
BillingAccountID: billingAccountID,
Status: sub.Status,
CurrentPeriodStart: periodStart,
CurrentPeriodEnd: periodEnd,
})
if err != nil {
return fmt.Errorf("create subscription: %w", err)
}
// 2. Create stripe.subscription_mappings record
_, err = stripeQ.UpsertSubscriptionMapping(ctx, internalstripe.UpsertSubscriptionMappingParams{
SubscriptionID: subscription.SubscriptionID,
StripeSubscriptionID: sql.NullString{String: sub.ID, Valid: true},
SyncStatus: "synced",
})
if err != nil {
return fmt.Errorf("create subscription mapping: %w", err)
}
// 3. Create subscription items and their mappings
for _, item := range sub.Items {
// Resolve core price_id and product_id from stripe price mapping
priceMapping, err := stripeQ.GetPriceMappingByStripeID(ctx, sql.NullString{String: item.PriceID, Valid: true})
if err != nil {
return fmt.Errorf("resolve price mapping for stripe price %s: %w", item.PriceID, err)
}
price, err := billingQ.GetPrice(ctx, priceMapping.PriceID)
if err != nil {
return fmt.Errorf("get price %s: %w", priceMapping.PriceID, err)
}
quantity := int32(item.Quantity)
if quantity <= 0 {
quantity = 1
}
subItem, err := billingQ.CreateSubscriptionItem(ctx, billing.CreateSubscriptionItemParams{
SubscriptionID: subscription.SubscriptionID,
ProductID: price.ProductID,
PriceID: price.PriceID,
Quantity: quantity,
})
if err != nil {
return fmt.Errorf("create subscription item: %w", err)
}
// Create subscription item mapping if we have a stripe item ID
if item.ID != "" {
_, err = stripeQ.UpsertSubscriptionItemMapping(ctx, internalstripe.UpsertSubscriptionItemMappingParams{
SubscriptionItemID: subItem.SubscriptionItemID,
StripeSubscriptionItemID: sql.NullString{String: item.ID, Valid: true},
SyncStatus: "synced",
})
if err != nil {
return fmt.Errorf("create subscription item mapping: %w", err)
}
}
// 4. Create pool provision for this subscription item
// Resolve: billing_account → org → default pool
account, err := billingQ.GetBillingAccountByID(ctx, billingAccountID)
if err != nil {
return fmt.Errorf("get billing account: %w", err)
}
pool, err := entQ.GetDefaultPoolByOrgID(ctx, account.OrgID)
if err != nil {
return fmt.Errorf("get default pool for org %s: %w", account.OrgID, err)
}
// Resolve entitlement_set_id from product
product, err := billingQ.GetProductByID(ctx, price.ProductID)
if err != nil {
return fmt.Errorf("get product %s: %w", price.ProductID, err)
}
if !product.EntitlementSetID.Valid {
a.Logger.Warn("product has no entitlement set, skipping provision",
slog.String("product_id", product.ProductID))
continue
}
subUUID, err := uuid.Parse(subscription.SubscriptionID)
if err != nil {
return fmt.Errorf("invalid subscription UUID %s: %w", subscription.SubscriptionID, err)
}
baUUID, err := uuid.Parse(billingAccountID)
if err != nil {
return fmt.Errorf("invalid billing account UUID %s: %w", billingAccountID, err)
}
// Route plan products through the transition primitive so ladder
// attachment and audit rows are produced atomically. Non-plan
// products stay on the direct provision path.
tiers, err := billingQ.ListLaddersByProduct(ctx, product.ProductID)
if err != nil {
return fmt.Errorf("lookup ladders for product %s: %w", product.ProductID, err)
}
if len(tiers) > 0 {
tier := tiers[0]
if _, err := entitlements.Transition(ctx, tx, pool.PoolID, entitlements.TransitionTarget{
LadderID: tier.PlanLadderID,
ProductID: product.ProductID,
Source: entitlements.TransitionSource{
SubscriptionID: uuid.NullUUID{UUID: subUUID, Valid: true},
BillingAccountID: uuid.NullUUID{UUID: baUUID, Valid: true},
EntitlementSetID: product.EntitlementSetID.UUID.String(),
Quantity: quantity,
},
}, entitlements.TransitionActor{
ActorType: "webhook",
Reason: "stripe:subscription.created",
}); err != nil {
return fmt.Errorf("transition for plan product: %w", err)
}
continue
}
_, err = entQ.CreateSubscriptionPoolProvision(ctx, entitlements.CreateSubscriptionPoolProvisionParams{
PoolID: pool.PoolID,
SubscriptionID: uuid.NullUUID{UUID: subUUID, Valid: true},
BillingAccountID: uuid.NullUUID{UUID: baUUID, Valid: true},
EntitlementSetID: product.EntitlementSetID.UUID.String(),
Quantity: quantity,
})
if err != nil {
return fmt.Errorf("create pool provision: %w", err)
}
if err := entitlements.MaterializePoolEntitlements(ctx, entQ, pool.PoolID); err != nil {
return fmt.Errorf("materialize entitlements: %w", err)
}
}
// 6. Record initial subscription change
_, err = billingQ.CreateSubscriptionChange(ctx, billing.CreateSubscriptionChangeParams{
SubscriptionID: subscription.SubscriptionID,
NewStatus: sub.Status,
StripeEventID: stripeEventID,
})
if err != nil {
return fmt.Errorf("record subscription change: %w", err)
}
return tx.Commit()
}
// adjustProvisionStatus reconciles the pool provision of a subscription with
// the subscription's Stripe status and rematerializes the pool's entitlements.
//
// Mapping: active/trialing → "active", past_due/unpaid → "suspended",
// canceled/incomplete_expired → "ended". Any other status is a no-op, as is a
// provision already in the target status, which makes repeated calls safe.
//
// Ending a provision that has an active plan-ladder attachment goes through
// endPlanProvisionViaTransition, so ladder attachments and audit rows stay
// consistent. Every other change updates the provision status and
// rematerializes in one transaction. Without that, a failed materialization
// would never be retried, because the status would already match. A
// subscription with no provision is logged and ignored.
//
// NOTE(review): fulfillSubscription can create one provision per subscription
// item, but only the provision returned by GetPoolProvisionBySubscriptionID is
// adjusted here — confirm this is intended.
func (a *WebhookActivities) adjustProvisionStatus(ctx context.Context, subscriptionID string, subscriptionStatus string, stripeEventID string) error {
	entQ := entitlements.New(a.DB)

	subUUID, err := uuid.Parse(subscriptionID)
	if err != nil {
		return fmt.Errorf("invalid subscription UUID %s: %w", subscriptionID, err)
	}

	provision, err := entQ.GetPoolProvisionBySubscriptionID(ctx, uuid.NullUUID{UUID: subUUID, Valid: true})
	if errors.Is(err, sql.ErrNoRows) {
		a.Logger.Warn("no provision found for subscription", slog.String("subscription_id", subscriptionID))
		return nil
	}
	if err != nil {
		return fmt.Errorf("get provision: %w", err)
	}

	var newStatus string
	switch subscriptionStatus {
	case "active", "trialing":
		newStatus = "active"
	case "past_due", "unpaid":
		newStatus = "suspended"
	case "canceled", "incomplete_expired":
		newStatus = "ended"
	default:
		return nil
	}
	if newStatus == provision.Status {
		return nil
	}

	if newStatus == "ended" {
		attachments, err := entQ.GetLadderAttachmentsByProvision(ctx, provision.ProvisionID)
		if err != nil {
			return fmt.Errorf("lookup ladder attachments for provision %s: %w", provision.ProvisionID, err)
		}
		// Named att (not a) to avoid shadowing the receiver.
		for _, att := range attachments {
			if att.Status == "active" {
				return a.endPlanProvisionViaTransition(ctx, provision.PoolID, subscriptionStatus, stripeEventID)
			}
		}
	}

	tx, err := a.DB.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin tx: %w", err)
	}
	defer tx.Rollback() // no-op once Commit has succeeded
	txQ := entitlements.New(tx)

	if _, err := txQ.UpdatePoolProvisionStatus(ctx, entitlements.UpdatePoolProvisionStatusParams{
		Status:      newStatus,
		ProvisionID: provision.ProvisionID,
	}); err != nil {
		return fmt.Errorf("update provision status: %w", err)
	}
	if err := entitlements.MaterializePoolEntitlements(ctx, txQ, provision.PoolID); err != nil {
		return fmt.Errorf("rematerialize entitlements: %w", err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit provision status: %w", err)
	}
	return nil
}
// endPlanProvisionViaTransition ends the plan provision on poolID by running
// entitlements.Transition with an End target. It then rematerializes the
// pool's entitlements, all in a single dedicated transaction.
//
// The audit reason is "stripe:subscription.<status>", with " (<event id>)"
// appended when stripeEventID is non-empty. Callers use this when a
// subscription backing a plan provision is canceled.
func (a *WebhookActivities) endPlanProvisionViaTransition(ctx context.Context, poolID, subscriptionStatus, stripeEventID string) error {
	reason := "stripe:subscription." + subscriptionStatus
	if stripeEventID != "" {
		reason = fmt.Sprintf("%s (%s)", reason, stripeEventID)
	}

	tx, err := a.DB.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin tx: %w", err)
	}
	defer tx.Rollback()

	target := entitlements.TransitionTarget{End: true}
	actor := entitlements.TransitionActor{ActorType: "webhook", Reason: reason}
	if _, err = entitlements.Transition(ctx, tx, poolID, target, actor); err != nil {
		return fmt.Errorf("transition end: %w", err)
	}

	if err = entitlements.MaterializePoolEntitlements(ctx, entitlements.New(tx), poolID); err != nil {
		return fmt.Errorf("rematerialize after transition: %w", err)
	}
	return tx.Commit()
}
// readPayload loads the stored (scrubbed) payload bytes for the webhook event
// row with the given ID. The query also matches on that row's received_at.
// Any lookup or scan error, including no matching row, is returned wrapped.
func (a *WebhookActivities) readPayload(ctx context.Context, eventID int64) ([]byte, error) {
	const query = `SELECT payload FROM integration.webhook_events
WHERE id = $1 AND received_at = (SELECT received_at FROM integration.webhook_events WHERE id = $1)`

	var payload []byte
	if err := a.DB.QueryRowContext(ctx, query, eventID).Scan(&payload); err != nil {
		return nil, fmt.Errorf("read webhook payload: %w", err)
	}
	return payload, nil
}