diff --git a/src/backend/core/tools.py b/src/backend/core/tools.py index 6dee7d9..0e6fabe 100644 --- a/src/backend/core/tools.py +++ b/src/backend/core/tools.py @@ -25,7 +25,11 @@ async def fetch_website_content( :return: str """ logger.info(f"Fetching website content for: {url}") - async with httpx.AsyncClient(follow_redirects=True) as http_client: + # Set reasonable timeouts: 10s connect, 30s total + timeout = httpx.Timeout(30.0, connect=10.0) + async with httpx.AsyncClient( + follow_redirects=True, timeout=timeout + ) as http_client: try: response = await http_client.get(url) response.raise_for_status() diff --git a/src/backend/site/router.py b/src/backend/site/router.py index 7550d25..9683c63 100644 --- a/src/backend/site/router.py +++ b/src/backend/site/router.py @@ -1,10 +1,10 @@ import asyncio import os -from fastapi import APIRouter, Form, Request -from fastapi.staticfiles import StaticFiles +from fastapi import APIRouter, Form, HTTPException, Request from fastapi.templating import Jinja2Templates from jinjax import Catalog, JinjaX +from pydantic import BaseModel, Field, field_validator from starlette.responses import HTMLResponse, Response, StreamingResponse from backend.core.pdf_cache import pdf_cache @@ -36,11 +36,37 @@ list( ), ) -user_frontend.mount( - "/static", - StaticFiles(directory=os.path.join(frontend, "static")), - name="static", -) +# Note: Static files are mounted in app.py, not here (avoid duplicate mounts) + + +class AnalysisInput(BaseModel): + """Validation model for analysis form inputs.""" + + primary_entity: str = Field( + ..., + min_length=1, + max_length=500, + description="Primary entity (company name or URL)", + ) + comparison_entities: str = Field( + default="", + max_length=2000, + description="Comma-separated comparison entities (optional)", + ) + + @field_validator("primary_entity") + @classmethod + def validate_primary_entity(cls, v: str) -> str: + """Ensure primary entity is not just whitespace.""" + if not v.strip(): + raise ValueError("Primary entity cannot be empty") + return v.strip() + + @field_validator("comparison_entities") + @classmethod + def validate_comparison_entities(cls, v: str) -> str: + """Strip whitespace from comparison entities.""" + return v.strip() @user_frontend.post("/analyze", response_class=HTMLResponse) @@ -51,11 +77,22 @@ async def analyze_url( ) -> HTMLResponse: """ Kick off a SWOT analysis for one or more entities. - :param request: + + :param request: FastAPI request object :param primary_entity: main subject (URL or company name) :param comparison_entities: comma-separated competitors (optional) - :return: + :return: Empty HTML response (HTMX polling handles rendering) """ + # Validate inputs using Pydantic model + try: + validated = AnalysisInput( + primary_entity=primary_entity, + comparison_entities=comparison_entities, + ) + except ValueError as e: + logger.error(f"Validation error: {e}") + raise HTTPException(status_code=422, detail=str(e)) from e + session_id = str(id(request)) request.session["analysis_id"] = session_id request.session["start_time"] = asyncio.get_event_loop().time() @@ -66,17 +103,20 @@ async def analyze_url( status_store[session_id].append(ANALYZING_MESSAGE) + # Use validated inputs comp_entities = [ - e.strip() for e in comparison_entities.split(",") if e.strip() + e.strip() for e in validated.comparison_entities.split(",") if e.strip() ] logger.info( f"Starting analysis — session: {session_id}, " - f"primary: {primary_entity}, comparing: {comp_entities}" + f"primary: {validated.primary_entity}, comparing: {comp_entities}" ) task = asyncio.create_task( - run_agent_with_progress(session_id, primary_entity, comp_entities) + run_agent_with_progress( + session_id, validated.primary_entity, comp_entities + ) ) running_tasks.add(task) task.add_done_callback(running_tasks.discard) @@ -211,12 +251,20 @@ async def download_pdf(request: Request) -> StreamingResponse: logger.info(f"Serving cached PDF for session: {session_id}") pdf_buffer = cached_pdf else: - # Generate new PDF - logger.info(f"Generating new PDF for session: {session_id}") - pdf_buffer = generate_swot_pdf(result) + # Generate new PDF with error handling + try: + logger.info(f"Generating new PDF for session: {session_id}") + pdf_buffer = generate_swot_pdf(result) - # Cache the generated PDF - pdf_cache.set(session_id, result, pdf_buffer) + # Cache the generated PDF + pdf_cache.set(session_id, result, pdf_buffer) + except Exception as e: + logger.error(f"PDF generation failed for session {session_id}: {e}") + return Response( + content=b"Failed to generate PDF report. Please try again or contact support.", + media_type="text/plain", + status_code=500, + ) # Prepare filename with company names and date import re diff --git a/src/backend/site/utils.py b/src/backend/site/utils.py index 1909d03..2e60c85 100644 --- a/src/backend/site/utils.py +++ b/src/backend/site/utils.py @@ -1,6 +1,5 @@ import asyncio import random -import time from pprint import pformat from typing import Any @@ -15,11 +14,14 @@ from backend.site.consts import ( ) -def emulate_tool_completion(session_id: str, message: str) -> None: - """Pydantic AI doesn't provide a post-processing hook, so we need to emulate one.""" +async def emulate_tool_completion(session_id: str, message: str) -> None: + """ + Emulate tool completion with random delay. - # Sleep a random amount of time between 0 and 5 seconds - time.sleep(random.randint(0, 5)) + Uses asyncio.sleep to avoid blocking the event loop. + """ + # Sleep a random amount of time between 0 and 5 seconds (async) + await asyncio.sleep(random.randint(0, 5)) status_store[session_id].append(message) @@ -49,13 +51,8 @@ async def update_status(session_id: str, message: Any) -> None: if message == ANALYSIS_COMPLETE_MESSAGE: status_store[session_id].append(message) else: - loop = asyncio.get_running_loop() - await loop.run_in_executor( - None, - emulate_tool_completion, - session_id, - message, - ) + # Call async function directly (no need for run_in_executor) + await emulate_tool_completion(session_id, message) logger.info( f"Status messages for session {session_id}: {status_store[session_id]}",