diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b0173f6..2c15c15 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,7 +32,7 @@ repos: - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. - rev: "v0.8.3" + rev: "v0.9.2" hooks: - id: ruff diff --git a/README.md b/README.md index f901133..e7ebdeb 100644 --- a/README.md +++ b/README.md @@ -508,7 +508,6 @@ Build Pygentic-AI from the source and intsall dependencies: **Using [docker](https://www.docker.com/):** - ```sh ❯ docker build -t fsecada01/Pygentic-AI . ``` @@ -521,7 +520,7 @@ Build Pygentic-AI from the source and intsall dependencies: **Using [pip](https://pypi.org/project/pip/):** ```sh - ❯ pip install -r core_requirements.in, core_requirements.txt, dev_requirements.in, dev_requirements.txt + ❯ pip install -r core_requirements.in dev_requirements.in ``` diff --git a/src/app.py b/src/app.py index e69de29..c6180a5 100644 --- a/src/app.py +++ b/src/app.py @@ -0,0 +1,51 @@ +import os + +from fastapi import Request +from fastapi.exceptions import RequestValidationError +from starlette import status +from starlette.responses import JSONResponse +from starlette.staticfiles import StaticFiles + +from backend import create_app +from backend.logger import logger +from backend.settings import app_settings, debug_arg + +app = create_app(debug=debug_arg, settings_obj=app_settings) + +# app.logger = CustomizeLogger.make_logger(config_path) + + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler( + request: Request, exc: RequestValidationError +): + exc_str = f"{exc}".replace("\n", "; ").replace(" ", " ") + logger.error(f"{request}: {exc_str}") + content = {"status_code": 10422, "message": exc_str, "data": None} + + return JSONResponse( + content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY + ) + + +class UnicornException(Exception): + def __init__(self, name: str): + self.name = name + + +@app.exception_handler(UnicornException) +async def unicorn_exception_handler(request: Request, exc: UnicornException): + return JSONResponse( + status_code=418, + content={ + "message": f"Oops! {exc.name} did something. " + "There goes a rainbow..." + }, + ) + + +app.mount( + "/static", + StaticFiles(directory=os.path.join(app_settings.frontend_dir, "static")), + name="static", +) diff --git a/src/backend/__init__.py b/src/backend/__init__.py index e69de29..221c97c 100644 --- a/src/backend/__init__.py +++ b/src/backend/__init__.py @@ -0,0 +1,18 @@ +from fastapi import FastAPI + + +def create_app(debug: bool = False, settings_obj=None): + """ + + :param debug: + :param settings_obj: + :return: + """ + app = FastAPI( + title="API Service for 525 Ocean Parkway's Cooperative Board", + debug=debug, + ) + if settings_obj: + app.settings = settings_obj + + return app diff --git a/src/backend/db/base.py b/src/backend/db/base.py new file mode 100644 index 0000000..f492664 --- /dev/null +++ b/src/backend/db/base.py @@ -0,0 +1,38 @@ +from typing import Any + +from backend.db.db import meta + + +def get_base_class(): + """ + Generate base SQLModel class for project + :return: + """ + from sqlmodel import SQLModel + + class Base(SQLModel): + def __new__(cls, *args: Any, **kwargs: Any) -> Any: + """ + Updated to Pydantic V2.5 results in an error with + `__pydantic_extra__` attribute not being found. This is a workaround + taken from the GitHub issues page: + https://github.com/tiangolo/sqlmodel/pull/632#discussion_r1280895115 + Args: + *args: positional args as tuple; explodes into Any type + **kwargs: keyword arguments as dict; explodes into Any type + + Returns: + Any + """ + new_obj = super().__new__(cls) + object.__setattr__(new_obj, "__pydantic_fields_set__", set()) + object.__setattr__(new_obj, "__pydantic_extra__", {}) + return new_obj + + Base.metadata = meta + + return Base + + +Base = get_base_class() +Base.__mapper_args = {"eager_defaults": True} diff --git a/src/backend/db/db.py b/src/backend/db/db.py new file mode 100644 index 0000000..f833602 --- /dev/null +++ b/src/backend/db/db.py @@ -0,0 +1,212 @@ +import asyncio +import inspect +from collections.abc import AsyncGenerator +from typing import Annotated + +from fastapi import Form +from pydantic import BaseModel +from sqlalchemy import MetaData, event +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine +from sqlalchemy.orm import merge_frozen_result, sessionmaker +from sqlalchemy.pool import NullPool +from sqlmodel import Session, create_engine +from sqlmodel.ext.asyncio.session import AsyncSession + +from backend.logger import logger +from backend.utils import get_val + +# echo, create = True, True +echo, create = False, False +# echo, create = False, True +# echo, create = True, False + +db_url = get_val("PROJECT_DB_URL") + +schema = "pygentic_ai" +meta = MetaData(schema=schema) + + +def as_form(cls: type[BaseModel]): + """ + + :param cls: + :return: + """ + new_params = [ + inspect.Parameter( + field_name, + inspect.Parameter.POSITIONAL_ONLY, + default=model_field.default, + annotation=Annotated[ + model_field.annotation, *model_field.metadata, Form() + ], + ) + for field_name, model_field in cls.model_fields.items() + ] + + async def as_form_func(**data): + # logger.debug(pformat(data)) + return cls(**data) + + sig = inspect.signature(as_form_func) + sig = sig.replace(parameters=new_params) + as_form_func.__signature__ = sig # type: ignore + cls.as_form = as_form_func + + return cls + + +async def get_async_session() -> AsyncGenerator[AsyncSession, None]: + """ + Get async SQLA session via generator function + """ + async with async_session_maker() as session_obj: + yield session_obj + + +def get_sync_session(): + """ + Get SQLA session via generator function + """ + with sessionmaker(db_url, expire_on_commit=False) as session: + yield session + + +async def run_out_of_band( + async_sessionaker: sessionmaker, + session_inst: AsyncSession, + statement, + merge_results: bool = True, +): + """ + + :param async_sessionaker: + :param session_inst: + :param statement: + :param merge_results: + :return: + """ + async with async_sessionaker() as oob_session: + await oob_session.connection( + execution_options={"isolation_level": "AUTOCOMMIT"} + ) + + result = await oob_session.execute(statement) + + if merge_results: + return ( + await session_inst.run_sync( + merge_frozen_result, statement, result.freeze(), load=False + ) + )() + else: + await result.close() + + +async def check_db_exists(engine_inst): + """ + + :param engine_inst: + :return: + """ + # async with Session(engine_inst) as conn: + async with engine_inst.connect() as conn: + from sqlalchemy import inspect # noqa + + tables = await conn.run_sync( + lambda sync_conn: inspect(sync_conn).get_table_names() + ) + logger.info(tables) + if not tables: + return False + return True + + +async def create_db( + engine_inst: AsyncEngine, + create_bool: bool = False, +): + """ + + :param engine_inst: + :param create_bool: + :return: + """ + url = engine_inst.url + exists = await check_db_exists(engine_inst) + if not exists or create_bool: + from backend.db.base_model import Base # noqa + + async with engine_inst.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + logger.info(f"Database {url.database} created") + + return engine_inst + + +def create_db_engine( + db_url: str, async_bool: bool = False, echo_bool: bool = False +): + """ + + :param db_url: + :param async_bool: + :param echo_bool: + :return: + """ + engine_inst = ( + create_engine(db_url, echo=echo_bool) + if async_bool is False + else create_async_engine(db_url, echo=echo_bool, poolclass=NullPool) + ) + + return engine_inst + + +engine = create_db_engine(db_url, async_bool=True, echo_bool=False) + + +@event.listens_for(engine.sync_engine, "connect", insert=True) +def set_current_schema(dbapi_connection, connection_record): + """ + This is a helper event listener to ensure that the current + schema is set when bootstrapping the database tables. + Taken from: + + https://docs.sqlalchemy.org/en/20/dialects/postgresql.html#setting-alternate + -search-paths-on-connect + + :param dbapi_connection: + :param connection_record: + :return: + """ + existing_autocommit = dbapi_connection.autocommit + dbapi_connection.autocommit = True + cursor_obj = dbapi_connection.cursor() + cursor_obj.execute("SET SESSION search_path='%s'" % schema) + cursor_obj.close() + dbapi_connection.autocommit = existing_autocommit + + +sync_engine = create_db_engine(db_url, echo_bool=echo) + +session_inst = Session(sync_engine, expire_on_commit=False) + +async_session_maker = sessionmaker( + bind=engine, class_=AsyncSession, expire_on_commit=False +) + +session_maker = sessionmaker( + bind=engine, class_=Session, expire_on_commit=False +) + +if __name__ == "__main__": + import sys + + from backend.db.models import * # noqa F403 + + if sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + + asyncio.run(create_db(engine_inst=engine, create_bool=True)) diff --git a/src/backend/logger.py b/src/backend/logger.py index e69de29..5e38fd9 100644 --- a/src/backend/logger.py +++ b/src/backend/logger.py @@ -0,0 +1,114 @@ +""" +Custom Logger Using Loguru +""" + +import json +import logging +import sys +from datetime import date +from pathlib import Path + +from loguru import logger + +from backend.settings.consts import BACKEND_DIR + +dir_path = Path(BACKEND_DIR).joinpath("logs") + +log_file = Path(dir_path).joinpath(f"log_{date.today()}.log") + +list( + map( + lambda x: logger.add( + Path(dir_path).joinpath(f"{x}_{log_file.name}"), + filter=lambda record: record["level"].name == x.upper(), + rotation="1 day", + retention="1 week", + enqueue=True, + ), + ["info", "debug", "error", "success", "warning", "critical"], + ) +) + + +class InterceptHandler(logging.Handler): + loglevel_mapping = { + 50: "CRITICAL", + 40: "ERROR", + 30: "WARNING", + 20: "INFO", + 10: "DEBUG", + 0: "NOTSET", + } + + def emit(self, record): + try: + level = logger.level(record.levelname).name + except AttributeError: + level = self.loglevel_mapping[record.levelno] + + frame, depth = logging.currentframe(), 2 + while frame.f_code.co_filename == logging.__file__: + frame = frame.f_back + depth += 1 + + log = logger.bind(request_id="app") + log.opt(depth=depth, exception=record.exc_info).log( + level, record.getMessage() + ) + + +class CustomizeLogger: + @classmethod + def make_logger(cls, config_path: Path): + config = cls.load_logging_config(config_path) + logging_config = config.get("logger") + + logger = cls.customize_logging( + logging_config.get("path"), + level=logging_config.get("level"), + retention=logging_config.get("retention"), + rotation=logging_config.get("rotation"), + format=logging_config.get("format"), + ) + return logger + + @classmethod + def customize_logging( + cls, + filepath: Path, + level: str, + rotation: str, + retention: str, + format: str, + ): + logger.remove() + logger.add( + sys.stdout, + enqueue=True, + backtrace=True, + level=level.upper(), + format=format, + ) + logger.add( + str(filepath), + rotation=rotation, + retention=retention, + enqueue=True, + backtrace=True, + level=level.upper(), + format=format, + ) + logging.basicConfig(handlers=[InterceptHandler()], level=0) + logging.getLogger("uvicorn.access").handlers = [InterceptHandler()] + for _log in ["uvicorn", "uvicorn.error", "fastapi"]: + _logger = logging.getLogger(_log) + _logger.handlers = [InterceptHandler()] + + return logger.bind(request_id=None, method=None) + + @classmethod + def load_logging_config(cls, config_path): + config = None + with open(config_path) as config_file: + config = json.load(config_file) + return config diff --git a/src/backend/settings/__init__.py b/src/backend/settings/__init__.py index e69de29..440ba65 100644 --- a/src/backend/settings/__init__.py +++ b/src/backend/settings/__init__.py @@ -0,0 +1,31 @@ +import enum +import os +from functools import lru_cache + +from backend.settings.dev import Settings as DevSettings +from backend.settings.prod import Settings as ProdSettings +from backend.utils import get_val + +server_types = enum.StrEnum( + "ServerTypes", {x.upper(): x for x in ("dev", "uat", "staging", "prod")} +) + + +@lru_cache +def get_settings(server: server_types = "dev", debug: bool = False): + if server == "dev": + settings = DevSettings() + else: + settings = ProdSettings() + + if debug: + settings.DEBUG = debug + + return settings + + +debug_arg = get_val(val="DEBUG", default=False, cast=bool) +env_arg = get_val("SERVER_ENV", default="dev") + +app_settings = get_settings(server=env_arg, debug=debug_arg) +os.environ["PROJECT_DB_URL"] = app_settings.DB_URL diff --git a/src/backend/settings/backend_options.py b/src/backend/settings/backend_options.py new file mode 100644 index 0000000..9fe2c4e --- /dev/null +++ b/src/backend/settings/backend_options.py @@ -0,0 +1,113 @@ +""" +backend configuration variables for Django Database settings. Import these \ +into the appropriate settings file and then fit them into the value slot of \ +the DATABASES variable +""" + +import os +from dataclasses import dataclass +from pathlib import Path + +from backend.settings.consts import BASE_DIR, all_dialects +from backend.utils import get_val + + +@dataclass +class DBConfs: + name: str | None + user: str | None + password: str | None + host: str | None + port: str | int | None + + +# rebased for Flask and Tortoise ORM + + +def get_db_val( + db_configs: DBConfs, + dialect: all_dialects = "sqlite", + ssl_dict: dict | None = None, + **kwargs, +): + """ + This utility function generates an SQL DB url based on provided parameters. + Dialects are checked against a whitelist composed of accepted entries from + SQLA 2.0's documentation. Further dialects to be added from MSSQL and + Oracle DBs. + + Parameters + ---------- + db_configs: DBConfs: dataclass + dialect: Literal[all_dialects]: defaults to `sqlite` + ssl_dict: Optional[dict]: defaults to None + kwargs: keyword_args that are compacted + + Returns url: str + ------- + """ + + if isinstance(db_configs.host, Path): + pass + + schema_sep = "://" if "sqlite" not in dialect else ":///" + + if "sqlite" in dialect and not any( + [ + hasattr(db_configs, x) + for x in ("name", "user", "password", "host", "port", "kwargs") + ] + ): + url = f"{dialect}:{schema_sep}{db_configs.host}" + else: + url = ( + f"{dialect}{schema_sep}{db_configs.user}:{db_configs.password}" + f"@{db_configs.host}:{db_configs.port}" + f"/{db_configs.name}" + ) + if ssl_dict: + ssl_uri = "&".join([f"{k}={v}" for k, v in ssl_dict.items()]) + url = f"{url}?{ssl_uri}" + + if kwargs: + uri = "&".join([f"{k}={v}" for k, v in kwargs.items()]) + url = f"{url}?{uri}" + + return url + + +cloud_backend = { + "name": get_val("CLOUD_DB_DB"), + "user": get_val("CLOUD_DB_UN"), + "password": get_val("CLOUD_DB_PW"), + "host": get_val("CLOUD_DB_HOST"), + "port": get_val("CLOUD_DB_PORT"), +} + +local_db = { + "name": get_val("LOCAL_DB_DB"), + "user": get_val("LOCAL_DB_UN"), + "password": get_val("LOCAL_DB_PW"), + "host": get_val("LOCAL_DB_HOST"), + "port": get_val("LOCAL_DB_PORT"), +} + +dialect = "postgresql+asyncpg" + +local_db_config = get_db_val(db_configs=DBConfs(**local_db), dialect=dialect) +# local_db_config = get_db_val(**local_db, dialect="postgresql") +cloud_db_config = get_db_val( + db_configs=DBConfs(**cloud_backend), + dialect=dialect, +) + +sqlite_db_config = get_db_val( + db_configs=DBConfs( + name=None, + user=None, + password=None, + port=None, + host=os.path.join(BASE_DIR, "backend", "db.sqlite"), + ), + dialect="sqlite", +) diff --git a/src/backend/settings/base.py b/src/backend/settings/base.py new file mode 100644 index 0000000..023c1b2 --- /dev/null +++ b/src/backend/settings/base.py @@ -0,0 +1,11 @@ +from pydantic_settings import BaseSettings + +from backend.settings.consts import BACKEND_DIR, FRONTEND_DIR + + +class Settings(BaseSettings): + backend_dir: str = BACKEND_DIR + frontend_dir: str = FRONTEND_DIR + + DB_URL: str = "sqlite+aiosqlite:///backend/pygentic_ai.sqlite3" + DEBUG: bool = False diff --git a/src/backend/settings/consts.py b/src/backend/settings/consts.py index e69de29..b2b2a7b 100644 --- a/src/backend/settings/consts.py +++ b/src/backend/settings/consts.py @@ -0,0 +1,56 @@ +import enum +import os + +from decouple import config + +BASE_DIR = os.path.join(os.path.dirname(os.path.dirname("__name__"))) +BACKEND_DIR = os.path.join(BASE_DIR, "backend") +FRONTEND_DIR = os.path.join(BASE_DIR, "frontend") + + +pg_dialects = [ + "postgres", + "postgresql", + "postgresql+asyncpg", + "postgresql+pg8000", + "postgresql+psycopg", + "postgresql+psycopg2", + "postgresql+psycopg2cffi", + "postgresql+py-postgresql", + "postgresql+pygresql", +] + +mysql_dialects = [ + "mysql+mysqldb", + "mysql+pymysql", + "mariadb+mariadbconnector", + "mysql+mysqlconnector", + "mysql+asyncmy", + "mysql+aiomysql", + "mysql+cymysql", + "mysql+pyodbc", +] + +sqlite_dialects = [ + "sqlite", + "sqlite+pysqlite", + "sqlite+aiosqlite", + "sqlite+pysqlcipher", +] + +all_dialects = enum.Enum( + "DatabaseDialect", + { + x.upper(): x + for x in [ + *[ + y + for x in (pg_dialects, mysql_dialects, sqlite_dialects) + for y in x + ] + ] + }, +) + +# TODO: include Oracle and MSSQL dialects +SECRET_KEY: str = config("SECRET_KEY") diff --git a/src/backend/settings/dev.py b/src/backend/settings/dev.py new file mode 100644 index 0000000..9a859a7 --- /dev/null +++ b/src/backend/settings/dev.py @@ -0,0 +1,8 @@ +from backend.settings.backend_options import local_db_config +from backend.settings.base import Settings as BaseSettings + + +class Settings(BaseSettings): + DB_URL: str = local_db_config + SQLALCHEMY_DATABASE_URL: str = local_db_config + DEBUG: bool = True diff --git a/src/backend/settings/prod.py b/src/backend/settings/prod.py new file mode 100644 index 0000000..98f61d4 --- /dev/null +++ b/src/backend/settings/prod.py @@ -0,0 +1,8 @@ +from backend.settings.backend_options import cloud_db_config +from backend.settings.base import Settings as BaseSettings + + +class Settings(BaseSettings): + DB_URL: str = cloud_db_config + SQLALCHEMY_DATABASE_URI: str = cloud_db_config + DEBUG: bool = False diff --git a/src/backend/utils.py b/src/backend/utils.py index e69de29..52b8733 100644 --- a/src/backend/utils.py +++ b/src/backend/utils.py @@ -0,0 +1,68 @@ +import os + +from decouple import config + + +def get_db_url(env: str = "dev"): + """ + + :param env: + :return: + """ + if env == "dev": + un, pw, db, host, port = ( + config(x) + for x in ( + "LOCAL_DB_UN", + "LOCAL_DB_PW", + "LOCAL_DB_DB", + "LOCAL_DB_HOST", + "LOCAL_DB_PORT", + ) + ) + + # url = f"mysql://{un}:{pw}@{host}:{port}/{db}" + # url = f"mariadb+pymysql://{un}:{pw}@{host}:{port}/{db}" + elif env == "prod": + un, pw, db, host, port = ( + config(x) + for x in ( + "POSTGRES_USER", + "POSTGRES_PASSWORD", + "POSTGRES_DB", + "db", + "5432", + ) + ) + else: + return None + + url = f"postgresql://{un}:{pw}@{host}:{port}/{db}" + return url + + +def get_val(val: str, default: str | int | bool | None = None, **kwargs): + """ + A utility function that checks the platform environment or an .env file to + pull a key:value pair and return the value. If the value does not exist, a + ValueError will be raised + Args: + val: str + default: Union[str, None]: default is None + + Returns val: str + + """ + if os.environ.get(val, None) is not None: + val = os.environ.get(val) + elif config(val, None, **kwargs) is not None: + val = config(val) + elif default: + val = default + else: + raise ValueError( + f"Env Var {val} is not populated in the environment " + f"or within the configuration files" + ) + + return val