Files
member-console/internal/workflows/stripe/webhook.go

470 lines
16 KiB
Go

package stripe
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"log/slog"
"strings"
"time"
internalstripe "git.coopcloud.tech/wiki-cafe/member-console/internal/stripe"
"git.coopcloud.tech/wiki-cafe/member-console/internal/workflows/common"
"go.temporal.io/sdk/workflow"
)
// WebhookEvent represents a row from integration.webhook_events.
//
// Only the columns needed to route and track an event are loaded here; the
// JSON payload is read separately by the per-type handlers when needed.
type WebhookEvent struct {
	// ID is the row's primary key; every status UPDATE filters on it.
	ID int64 `db:"id"`
	// Provider names the integration that delivered the event.
	Provider string `db:"provider"`
	// ProviderEventID is the event ID assigned by the provider; used in log fields.
	ProviderEventID string `db:"provider_event_id"`
	// EventType is the provider's event name (e.g. "customer.created");
	// dispatchEvent routes on its prefix.
	EventType string `db:"event_type"`
	// Status is the lifecycle state; this file writes 'processing',
	// 'completed', 'skipped' and 'failed', and polls for 'received'.
	Status string `db:"status"`
	// RetryCount is incremented each time MarkEventFailed runs.
	RetryCount int `db:"retry_count"`
}
// ProcessStripeWebhooksInput configures the polling workflow.
// Zero or negative values are replaced with defaults by ProcessStripeWebhooks.
type ProcessStripeWebhooksInput struct {
	// BatchSize is the maximum number of events fetched per poll (default 50).
	BatchSize int
	// PollInterval is how long the workflow sleeps between polls (default 15s).
	PollInterval time.Duration
}
// ProcessStripeWebhooks is a Temporal workflow that polls for unprocessed
// webhook events and processes them in batches. Each run handles one batch,
// sleeps for PollInterval, and then continues-as-new so the workflow history
// stays bounded while the polling loop runs indefinitely.
//
// Zero or negative BatchSize / PollInterval fall back to 50 events and 15s.
// The workflow returns an error (ending the loop) when polling fails after any
// retries configured by common.DefaultActivityOptions, or when it is cancelled.
// A failure on an individual event is recorded via MarkEventFailed and does
// not stop the batch.
func ProcessStripeWebhooks(ctx workflow.Context, input ProcessStripeWebhooksInput) error {
	logger := workflow.GetLogger(ctx)

	if input.BatchSize <= 0 {
		input.BatchSize = 50
	}
	if input.PollInterval <= 0 {
		input.PollInterval = 15 * time.Second
	}

	actCtx := workflow.WithActivityOptions(ctx, common.DefaultActivityOptions())

	// A nil receiver is enough here: it is only used so the activity methods
	// can be referenced; the worker registers the real instance.
	var acts *WebhookActivities

	// Poll for received events.
	var events []WebhookEvent
	if err := workflow.ExecuteActivity(actCtx, acts.PollReceivedEvents, input.BatchSize).Get(ctx, &events); err != nil {
		logger.Error("failed to poll webhook events", "error", err)
		return err
	}

	// Process each event.
	for _, evt := range events {
		// ProcessWebhookEvent returns only an error, so there is no result to decode.
		err := workflow.ExecuteActivity(actCtx, acts.ProcessWebhookEvent, evt).Get(ctx, nil)
		if err == nil {
			continue
		}
		logger.Error("activity execution failed",
			"event_id", evt.ProviderEventID,
			"error", err)

		// Once the workflow context is cancelled, stop instead of scheduling
		// more activities on it.
		if cancelErr := ctx.Err(); cancelErr != nil {
			return cancelErr
		}

		// Record the failure. If even that fails, log it and move on so one bad
		// event does not stall the rest of the batch.
		if markErr := workflow.ExecuteActivity(actCtx, acts.MarkEventFailed, evt.ID, err.Error()).Get(ctx, nil); markErr != nil {
			logger.Error("failed to mark webhook event failed",
				"event_id", evt.ProviderEventID,
				"error", markErr)
		}
	}

	// Sleep fails when the workflow is cancelled. Propagate that rather than
	// continuing-as-new, which would silently ignore the cancellation.
	if err := workflow.Sleep(ctx, input.PollInterval); err != nil {
		return err
	}
	return workflow.NewContinueAsNewError(ctx, ProcessStripeWebhooks, input)
}
// WebhookActivities bundles the Temporal activities used by the webhook
// workflow together with the dependencies they share.
type WebhookActivities struct {
	// DB is the database holding integration.webhook_events and the Stripe
	// mapping tables queried through internalstripe.
	DB *sql.DB
	// Logger receives structured logs emitted by the activities.
	Logger *slog.Logger
}

// NewWebhookActivities returns a fresh WebhookActivities wired to db and
// logger. Neither argument is validated.
func NewWebhookActivities(db *sql.DB, logger *slog.Logger) *WebhookActivities {
	acts := new(WebhookActivities)
	acts.DB = db
	acts.Logger = logger
	return acts
}
// PollReceivedEvents returns up to batchSize webhook events whose status is
// 'received', oldest received_at first. It only reads: the rows are not
// locked or claimed, so their status is unchanged until ProcessWebhookEvent
// runs.
//
// On any error the returned slice is nil, so callers never see a partially
// scanned batch.
func (a *WebhookActivities) PollReceivedEvents(ctx context.Context, batchSize int) ([]WebhookEvent, error) {
	rows, err := a.DB.QueryContext(ctx,
		`SELECT id, provider, provider_event_id, event_type, status, retry_count
FROM integration.webhook_events
WHERE status = 'received'
ORDER BY received_at ASC
LIMIT $1`,
		batchSize,
	)
	if err != nil {
		return nil, fmt.Errorf("poll webhook events: %w", err)
	}
	defer rows.Close()

	var events []WebhookEvent
	for rows.Next() {
		var e WebhookEvent
		if err := rows.Scan(&e.ID, &e.Provider, &e.ProviderEventID, &e.EventType, &e.Status, &e.RetryCount); err != nil {
			return nil, fmt.Errorf("scan webhook event: %w", err)
		}
		events = append(events, e)
	}
	// rows.Err reports failures that ended iteration early (e.g. a dropped
	// connection). Discard the partial batch instead of returning it with the error.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate webhook events: %w", err)
	}
	return events, nil
}
// ProcessWebhookEvent drives one webhook event through its lifecycle. It sets
// the row's status to 'processing', hands the event to dispatchEvent, and then
// stores the status the handler reported ('completed' or 'skipped') together
// with processed_at.
//
// If dispatch fails, the row is left in 'processing' and the wrapped error is
// returned. ProcessStripeWebhooks then marks the event failed.
func (a *WebhookActivities) ProcessWebhookEvent(ctx context.Context, evt WebhookEvent) error {
	const markProcessing = `UPDATE integration.webhook_events
SET status = 'processing', updated_at = NOW()
WHERE id = $1 AND received_at = (SELECT received_at FROM integration.webhook_events WHERE id = $1)`

	const markFinal = `UPDATE integration.webhook_events
SET status = $1, processed_at = NOW(), updated_at = NOW()
WHERE id = $2 AND received_at = (SELECT received_at FROM integration.webhook_events WHERE id = $2)`

	if _, execErr := a.DB.ExecContext(ctx, markProcessing, evt.ID); execErr != nil {
		return fmt.Errorf("mark processing: %w", execErr)
	}

	status, dispatchErr := a.dispatchEvent(ctx, evt)
	if dispatchErr != nil {
		return fmt.Errorf("dispatch %s: %w", evt.EventType, dispatchErr)
	}

	if _, execErr := a.DB.ExecContext(ctx, markFinal, status, evt.ID); execErr != nil {
		return fmt.Errorf("mark %s: %w", status, execErr)
	}
	return nil
}
// MarkEventFailed moves the event identified by eventID to status 'failed',
// stores errMsg in error_message, and bumps retry_count by one. It returns a
// wrapped error if the UPDATE cannot be executed.
func (a *WebhookActivities) MarkEventFailed(ctx context.Context, eventID int64, errMsg string) error {
	const query = `UPDATE integration.webhook_events
SET status = 'failed', error_message = $1, retry_count = retry_count + 1, updated_at = NOW()
WHERE id = $2 AND received_at = (SELECT received_at FROM integration.webhook_events WHERE id = $2)`

	if _, execErr := a.DB.ExecContext(ctx, query, errMsg, eventID); execErr != nil {
		return fmt.Errorf("mark failed: %w", execErr)
	}
	return nil
}
// dispatchEvent picks a handler for evt from the prefix of its event type and
// returns that handler's result. The final status is "completed" or "skipped".
// Event types that match no known prefix are logged and reported as
// "skipped" with a nil error.
func (a *WebhookActivities) dispatchEvent(ctx context.Context, evt WebhookEvent) (string, error) {
	// The first matching prefix wins, so the more specific
	// "customer.subscription." must stay ahead of the broader "customer.".
	routes := []struct {
		prefix string
		handle func(context.Context, WebhookEvent) (string, error)
	}{
		{"checkout.session.", a.handleCheckoutSessionEvent},
		{"customer.subscription.", a.handleSubscriptionEvent},
		{"customer.", a.handleCustomerEvent},
		{"product.", a.handleProductEvent},
		{"price.", a.handlePriceEvent},
		{"invoice.", a.handleInvoiceEvent},
		{"payment_method.", a.handlePaymentMethodEvent},
	}

	for _, route := range routes {
		if strings.HasPrefix(evt.EventType, route.prefix) {
			return route.handle(ctx, evt)
		}
	}

	a.Logger.Info("webhook event skipped (unhandled type)",
		slog.String("event_id", evt.ProviderEventID),
		slog.String("event_type", evt.EventType))
	return "skipped", nil
}
// webhookCustomerPayload is the subset of a scrubbed customer-event payload
// that handleCustomerEvent decodes; other JSON fields are ignored.
type webhookCustomerPayload struct {
	ID string `json:"id"` // Stripe Customer ID (cus_…)
}
// handleCustomerEvent processes customer.created, customer.updated and
// customer.deleted events. It updates the local customer mapping that matches
// the Stripe customer ID in the stored (scrubbed) payload.
//
// Other event types routed here by the "customer." prefix (for example
// customer.discount.* or customer.source.*) are skipped before the payload is
// read, so an unhandled sub-event cannot fail on payload shape. An event for a
// Stripe customer with no local mapping is logged and reported as "completed":
// a mapping cannot be created here without a billing_account_id.
func (a *WebhookActivities) handleCustomerEvent(ctx context.Context, evt WebhookEvent) (string, error) {
	switch evt.EventType {
	case "customer.created", "customer.updated", "customer.deleted":
		// Handled below.
	default:
		a.Logger.Info("customer sub-event skipped",
			slog.String("event_type", evt.EventType))
		return "skipped", nil
	}

	// Read the scrubbed payload from the webhook_events table.
	var payloadBytes []byte
	err := a.DB.QueryRowContext(ctx,
		`SELECT payload FROM integration.webhook_events
WHERE id = $1 AND received_at = (SELECT received_at FROM integration.webhook_events WHERE id = $1)`,
		evt.ID,
	).Scan(&payloadBytes)
	if err != nil {
		return "", fmt.Errorf("read webhook payload: %w", err)
	}
	var payload webhookCustomerPayload
	if err := json.Unmarshal(payloadBytes, &payload); err != nil {
		return "", fmt.Errorf("parse customer webhook payload: %w", err)
	}
	stripeCustomerID := payload.ID
	if stripeCustomerID == "" {
		return "", fmt.Errorf("customer webhook payload missing id")
	}

	q := internalstripe.New(a.DB)
	stripeID := sql.NullString{String: stripeCustomerID, Valid: true}

	switch evt.EventType {
	case "customer.created":
		// The mapping is expected to have been written by the outbox already;
		// confirm it as synced. A missing mapping is only logged, because no
		// billing_account_id is available here to create one.
		existing, err := q.GetCustomerMappingByStripeCustomerID(ctx, stripeID)
		switch {
		case err == sql.ErrNoRows:
			a.Logger.Warn("customer.created webhook: no mapping found for stripe customer (created externally?)",
				slog.String("stripe_customer_id", stripeCustomerID))
			return "completed", nil
		case err != nil:
			return "", fmt.Errorf("lookup mapping by stripe customer: %w", err)
		}
		if err := q.UpdateCustomerMappingSyncStatus(ctx, internalstripe.UpdateCustomerMappingSyncStatusParams{
			BillingAccountID: existing.BillingAccountID,
			SyncStatus:       "synced",
		}); err != nil {
			return "", fmt.Errorf("update existing mapping: %w", err)
		}

	case "customer.updated":
		// Re-write the current sync status unchanged so the mapping row is touched.
		existing, err := q.GetCustomerMappingByStripeCustomerID(ctx, stripeID)
		switch {
		case err == sql.ErrNoRows:
			a.Logger.Warn("customer.updated webhook: no mapping found",
				slog.String("stripe_customer_id", stripeCustomerID))
			return "completed", nil
		case err != nil:
			return "", fmt.Errorf("lookup mapping: %w", err)
		}
		if err := q.UpdateCustomerMappingSyncStatus(ctx, internalstripe.UpdateCustomerMappingSyncStatusParams{
			BillingAccountID: existing.BillingAccountID,
			SyncStatus:       existing.SyncStatus,
		}); err != nil {
			return "", fmt.Errorf("update mapping timestamp: %w", err)
		}

	case "customer.deleted":
		existing, err := q.GetCustomerMappingByStripeCustomerID(ctx, stripeID)
		switch {
		case err == sql.ErrNoRows:
			a.Logger.Warn("customer.deleted webhook: no mapping found",
				slog.String("stripe_customer_id", stripeCustomerID))
			return "completed", nil
		case err != nil:
			return "", fmt.Errorf("lookup mapping: %w", err)
		}
		if err := q.UpdateCustomerMappingSyncStatus(ctx, internalstripe.UpdateCustomerMappingSyncStatusParams{
			BillingAccountID: existing.BillingAccountID,
			SyncStatus:       "deleted",
		}); err != nil {
			return "", fmt.Errorf("mark mapping deleted: %w", err)
		}
	}

	a.Logger.Info("customer webhook processed",
		slog.String("event_type", evt.EventType),
		slog.String("stripe_customer_id", stripeCustomerID))
	return "completed", nil
}
// webhookProductPayload is the subset of a scrubbed product-event payload
// that handleProductEvent decodes; other JSON fields are ignored.
type webhookProductPayload struct {
	ID string `json:"id"` // Stripe Product ID (prod_…)
}
// handleProductEvent processes product.created, product.updated and
// product.deleted events. It updates the local product mapping that matches
// the Stripe product ID in the stored (scrubbed) payload.
//
// Other product.* event types are skipped before the payload is read. An event
// for a Stripe product with no local mapping is logged and reported as
// "completed": the mapping cannot be created here without a product_id.
func (a *WebhookActivities) handleProductEvent(ctx context.Context, evt WebhookEvent) (string, error) {
	switch evt.EventType {
	case "product.created", "product.updated", "product.deleted":
		// Handled below.
	default:
		a.Logger.Info("product sub-event skipped",
			slog.String("event_type", evt.EventType))
		return "skipped", nil
	}

	var payloadBytes []byte
	err := a.DB.QueryRowContext(ctx,
		`SELECT payload FROM integration.webhook_events
WHERE id = $1 AND received_at = (SELECT received_at FROM integration.webhook_events WHERE id = $1)`,
		evt.ID,
	).Scan(&payloadBytes)
	if err != nil {
		return "", fmt.Errorf("read webhook payload: %w", err)
	}
	var payload webhookProductPayload
	if err := json.Unmarshal(payloadBytes, &payload); err != nil {
		return "", fmt.Errorf("parse product webhook payload: %w", err)
	}
	stripeProductID := payload.ID
	if stripeProductID == "" {
		return "", fmt.Errorf("product webhook payload missing id")
	}

	q := internalstripe.New(a.DB)
	stripeID := sql.NullString{String: stripeProductID, Valid: true}

	switch evt.EventType {
	case "product.created":
		existing, err := q.GetProductMappingByStripeID(ctx, stripeID)
		switch {
		case err == sql.ErrNoRows:
			// No mapping yet: the webhook arrived before the outbox, or the
			// product was created externally. Can't create one without product_id.
			a.Logger.Warn("product.created webhook: no mapping found for stripe product (created externally?)",
				slog.String("stripe_product_id", stripeProductID))
			return "completed", nil
		case err != nil:
			return "", fmt.Errorf("lookup product mapping by stripe id: %w", err)
		}
		// Mapping exists: re-upsert it with sync_status 'synced'.
		if _, err := q.UpsertProductMapping(ctx, internalstripe.UpsertProductMappingParams{
			ProductID:       existing.ProductID,
			StripeProductID: stripeID,
			SyncStatus:      "synced",
		}); err != nil {
			return "", fmt.Errorf("update existing product mapping: %w", err)
		}

	case "product.updated":
		existing, err := q.GetProductMappingByStripeID(ctx, stripeID)
		switch {
		case err == sql.ErrNoRows:
			a.Logger.Warn("product.updated webhook: no mapping found",
				slog.String("stripe_product_id", stripeProductID))
			return "completed", nil
		case err != nil:
			return "", fmt.Errorf("lookup product mapping: %w", err)
		}
		// Re-upsert the mapping unchanged so its row is touched.
		if _, err := q.UpsertProductMapping(ctx, internalstripe.UpsertProductMappingParams{
			ProductID:       existing.ProductID,
			StripeProductID: existing.StripeProductID,
			SyncStatus:      existing.SyncStatus,
		}); err != nil {
			return "", fmt.Errorf("update product mapping timestamp: %w", err)
		}

	case "product.deleted":
		existing, err := q.GetProductMappingByStripeID(ctx, stripeID)
		switch {
		case err == sql.ErrNoRows:
			a.Logger.Warn("product.deleted webhook: no mapping found",
				slog.String("stripe_product_id", stripeProductID))
			return "completed", nil
		case err != nil:
			return "", fmt.Errorf("lookup product mapping: %w", err)
		}
		if err := q.MarkProductMappingDeleted(ctx, existing.ProductID); err != nil {
			return "", fmt.Errorf("mark product mapping deleted: %w", err)
		}
	}

	a.Logger.Info("product webhook processed",
		slog.String("event_type", evt.EventType),
		slog.String("stripe_product_id", stripeProductID))
	return "completed", nil
}
// webhookPricePayload is the subset of a scrubbed price-event payload that
// handlePriceEvent decodes; other JSON fields are ignored.
//
// NOTE(review): if the stored payload omits "active", Active decodes as false,
// which handlePriceEvent treats as an archived price. Confirm the payload
// scrubber keeps this field.
type webhookPricePayload struct {
	ID     string `json:"id"`     // Stripe Price ID (price_…)
	Active bool   `json:"active"` // Whether the price is active
}
// handlePriceEvent processes price.created, price.updated and price.deleted
// events. It updates the local price mapping that matches the Stripe price ID
// in the stored (scrubbed) payload.
//
// Both price.deleted and a price.updated with active=false retire the mapping
// via MarkPriceMappingDeleted. Stripe normally archives prices by setting
// active=false rather than deleting them. Other price.* event types are
// skipped before the payload is read. An event for a Stripe price with no
// local mapping is logged and reported as "completed".
func (a *WebhookActivities) handlePriceEvent(ctx context.Context, evt WebhookEvent) (string, error) {
	switch evt.EventType {
	case "price.created", "price.updated", "price.deleted":
		// Handled below.
	default:
		a.Logger.Info("price sub-event skipped",
			slog.String("event_type", evt.EventType))
		return "skipped", nil
	}

	var payloadBytes []byte
	err := a.DB.QueryRowContext(ctx,
		`SELECT payload FROM integration.webhook_events
WHERE id = $1 AND received_at = (SELECT received_at FROM integration.webhook_events WHERE id = $1)`,
		evt.ID,
	).Scan(&payloadBytes)
	if err != nil {
		return "", fmt.Errorf("read webhook payload: %w", err)
	}
	var payload webhookPricePayload
	if err := json.Unmarshal(payloadBytes, &payload); err != nil {
		return "", fmt.Errorf("parse price webhook payload: %w", err)
	}
	stripePriceID := payload.ID
	if stripePriceID == "" {
		return "", fmt.Errorf("price webhook payload missing id")
	}

	q := internalstripe.New(a.DB)
	stripeID := sql.NullString{String: stripePriceID, Valid: true}

	switch evt.EventType {
	case "price.created":
		existing, err := q.GetPriceMappingByStripeID(ctx, stripeID)
		switch {
		case err == sql.ErrNoRows:
			// No mapping yet: the webhook arrived before the outbox, or the
			// price was created externally. Can't create one without price_id.
			a.Logger.Warn("price.created webhook: no mapping found for stripe price (created externally?)",
				slog.String("stripe_price_id", stripePriceID))
			return "completed", nil
		case err != nil:
			return "", fmt.Errorf("lookup price mapping by stripe id: %w", err)
		}
		// Mapping exists: re-upsert it with sync_status 'synced'.
		if _, err := q.UpsertPriceMapping(ctx, internalstripe.UpsertPriceMappingParams{
			PriceID:       existing.PriceID,
			StripePriceID: stripeID,
			SyncStatus:    "synced",
		}); err != nil {
			return "", fmt.Errorf("update existing price mapping: %w", err)
		}

	case "price.updated", "price.deleted":
		existing, err := q.GetPriceMappingByStripeID(ctx, stripeID)
		switch {
		case err == sql.ErrNoRows:
			a.Logger.Warn(evt.EventType+" webhook: no mapping found",
				slog.String("stripe_price_id", stripePriceID))
			return "completed", nil
		case err != nil:
			return "", fmt.Errorf("lookup price mapping: %w", err)
		}
		// NOTE(review): an absent "active" field decodes as false and would
		// retire the mapping. Confirm the scrubbed payload always carries it.
		if evt.EventType == "price.deleted" || !payload.Active {
			if err := q.MarkPriceMappingDeleted(ctx, existing.PriceID); err != nil {
				return "", fmt.Errorf("mark price mapping deleted: %w", err)
			}
		} else {
			// Active price update: re-upsert the mapping unchanged so its row is touched.
			if _, err := q.UpsertPriceMapping(ctx, internalstripe.UpsertPriceMappingParams{
				PriceID:       existing.PriceID,
				StripePriceID: existing.StripePriceID,
				SyncStatus:    existing.SyncStatus,
			}); err != nil {
				return "", fmt.Errorf("update price mapping timestamp: %w", err)
			}
		}
	}

	a.Logger.Info("price webhook processed",
		slog.String("event_type", evt.EventType),
		slog.String("stripe_price_id", stripePriceID))
	return "completed", nil
}