From 300f41630fbf6142fba467398cf8105545ec6029 Mon Sep 17 00:00:00 2001 From: Christian Galo Date: Thu, 1 May 2025 01:34:25 -0500 Subject: [PATCH] Add decompression middleware and tests for gzip handling --- .gitignore | 3 +- internal/middleware/decompress.go | 77 ++++++++++++ internal/middleware/tests/decompress_test.go | 124 +++++++++++++++++++ 3 files changed, 203 insertions(+), 1 deletion(-) create mode 100644 internal/middleware/decompress.go create mode 100644 internal/middleware/tests/decompress_test.go diff --git a/.gitignore b/.gitignore index 0c85462..dcb355d 100644 --- a/.gitignore +++ b/.gitignore @@ -28,4 +28,5 @@ tmp/* member-console # Ignore config files -site/member-console.yaml \ No newline at end of file +site/member-console.yaml +**/.claude/settings.local.json diff --git a/internal/middleware/decompress.go b/internal/middleware/decompress.go new file mode 100644 index 0000000..3b04cb0 --- /dev/null +++ b/internal/middleware/decompress.go @@ -0,0 +1,77 @@ +package middleware + +import ( + "compress/gzip" + "io" + "net/http" + "strings" +) + +// DecompressOptions configures the decompression middleware +type DecompressOptions struct { + // MaxSize is the maximum size of the decompressed body in bytes + // Zero means no limit + MaxSize int64 +} + +// DefaultDecompressOptions provides sensible default options +func DefaultDecompressOptions() *DecompressOptions { + return &DecompressOptions{ + MaxSize: 10 << 20, // 10MB + } +} + +// Decompress creates a middleware that decompresses HTTP requests with gzip content encoding +func Decompress(opts *DecompressOptions) Middleware { + // Use default options if none provided + if opts == nil { + opts = DefaultDecompressOptions() + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check if the request has Content-Encoding header + contentEncoding := r.Header.Get("Content-Encoding") + + // If not compressed, pass through + if contentEncoding == "" { + next.ServeHTTP(w, r) + return + } + + // Check if gzip encoded + if strings.Contains(contentEncoding, "gzip") { + // Create a gzip reader + gz, err := gzip.NewReader(r.Body) + if err != nil { + http.Error(w, "Invalid gzip body", http.StatusBadRequest) + return + } + defer gz.Close() + + // Add limit reader if max size is specified + var bodyReader io.Reader = gz + if opts.MaxSize > 0 { + bodyReader = io.LimitReader(gz, opts.MaxSize) + } + + // Replace the body with a decompressed reader + r.Body = io.NopCloser(bodyReader) + + // Remove the content-encoding header to signal that the body is now decompressed + r.Header.Del("Content-Encoding") + + // Adjust content length because the body has been decompressed + r.Header.Del("Content-Length") + } + + // Call the next handler with the decompressed body + next.ServeHTTP(w, r) + }) + } +} + +// DecompressDefault creates a middleware that decompresses HTTP requests with default options +func DecompressDefault() Middleware { + return Decompress(nil) +} \ No newline at end of file diff --git a/internal/middleware/tests/decompress_test.go b/internal/middleware/tests/decompress_test.go new file mode 100644 index 0000000..52b0f85 --- /dev/null +++ b/internal/middleware/tests/decompress_test.go @@ -0,0 +1,124 @@ +package tests + +import ( + "bytes" + "compress/gzip" + "io" + "net/http" + "net/http/httptest" + "testing" + + "git.coopcloud.tech/wiki-cafe/member-console/internal/middleware" +) + +func TestDecompressMiddleware(t *testing.T) { + // Create a simple handler that reads the request body and returns it + testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Error reading body", http.StatusInternalServerError) + return + } + w.Write(body) + }) + + // Apply decompression middleware + handler := middleware.DecompressDefault()(testHandler) + + t.Run("No compression", func(t *testing.T) { + // Create a request with no compression + testData := []byte("test data with no compression") + req := httptest.NewRequest("POST", "/", bytes.NewReader(testData)) + rr := httptest.NewRecorder() + + // Call the handler + handler.ServeHTTP(rr, req) + + // Check the response + if rr.Code != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code) + } + if !bytes.Equal(testData, rr.Body.Bytes()) { + t.Errorf("Response body does not match original data") + } + }) + + t.Run("Gzip compression", func(t *testing.T) { + // Create gzip compressed data + testData := []byte("test data with gzip compression") + var buf bytes.Buffer + gzWriter := gzip.NewWriter(&buf) + _, err := gzWriter.Write(testData) + if err != nil { + t.Fatal(err) + } + if err := gzWriter.Close(); err != nil { + t.Fatal(err) + } + + // Create a request with gzip compression + req := httptest.NewRequest("POST", "/", &buf) + req.Header.Set("Content-Encoding", "gzip") + rr := httptest.NewRecorder() + + // Call the handler + handler.ServeHTTP(rr, req) + + // Check the response + if rr.Code != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code) + } + if !bytes.Equal(testData, rr.Body.Bytes()) { + t.Errorf("Response body does not match original data") + } + }) + + t.Run("Invalid gzip data", func(t *testing.T) { + // Create invalid gzip data + testData := []byte("this is not valid gzip data") + req := httptest.NewRequest("POST", "/", bytes.NewReader(testData)) + req.Header.Set("Content-Encoding", "gzip") + rr := httptest.NewRecorder() + + // Call the handler + handler.ServeHTTP(rr, req) + + // Check that we get a bad request + if rr.Code != http.StatusBadRequest { + t.Errorf("Expected status code %d, got %d", http.StatusBadRequest, rr.Code) + } + }) + + t.Run("Size limit", func(t *testing.T) { + // Create oversized data + testData := bytes.Repeat([]byte("a"), 11<<20) // 11MB + var buf bytes.Buffer + gzWriter := gzip.NewWriter(&buf) + _, err := gzWriter.Write(testData) + if err != nil { + t.Fatal(err) + } + if err := gzWriter.Close(); err != nil { + t.Fatal(err) + } + + // Create a limited decompression middleware (10MB limit) + limitedHandler := middleware.Decompress(nil)(testHandler) + + // Create a request with gzip compression + req := httptest.NewRequest("POST", "/", &buf) + req.Header.Set("Content-Encoding", "gzip") + rr := httptest.NewRecorder() + + // Call the handler + limitedHandler.ServeHTTP(rr, req) + + // The request should process but the body will be truncated + if rr.Code != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, rr.Code) + } + if len(rr.Body.Bytes()) >= 11<<20 { + t.Errorf("Expected response to be truncated to less than 11MB, got %d bytes", len(rr.Body.Bytes())) + } + }) +} \ No newline at end of file