package middleware import ( "context" "net/http" "time" ) // TimeoutMiddleware is necessary in addition to http.Server's ReadTimeout, // WriteTimeout, and IdleTimeout. http.Server's timeouts are network-level // timeouts, while this middleware's timeout is at the application level. // TODO: Verify this statement // TimeoutMiddleware sets a timeout for each request func TimeoutMiddleware(duration time.Duration) Middleware { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Create a context with a timeout ctx, cancel := context.WithTimeout(r.Context(), duration) defer cancel() // Create a channel to signal when the request is done done := make(chan struct{}) // Create a new request with the timeout context r = r.WithContext(ctx) // Use a goroutine to run the next handler go func() { next.ServeHTTP(w, r) close(done) }() // Wait for the handler to finish or the context to timeout select { case <-done: // Request finished within the timeout return case <-ctx.Done(): // Timeout occurred, respond with a timeout error http.Error(w, "Request timed out", http.StatusGatewayTimeout) return } }) } }