Add boilerplate for DB integrations, environment-based settings, and logging; update pre-commit hooks

This commit is contained in:
Francis Secada 2025-01-16 13:36:24 -05:00
parent 529f21f123
commit a975e2b7ec
14 changed files with 730 additions and 3 deletions

View File

@@ -32,7 +32,7 @@ repos:
 - repo: https://github.com/charliermarsh/ruff-pre-commit
   # Ruff version.
-  rev: "v0.8.3"
+  rev: "v0.9.2"
   hooks:
   - id: ruff

View File

@@ -508,7 +508,6 @@ Build Pygentic-AI from the source and install dependencies:
 <!-- [docker-link]: https://www.docker.com/ -->
 **Using [docker](https://www.docker.com/):**
 ```sh
 docker build -t fsecada01/Pygentic-AI .
 ```
@@ -521,7 +520,7 @@ Build Pygentic-AI from the source and install dependencies:
 **Using [pip](https://pypi.org/project/pip/):**
 ```sh
-pip install -r core_requirements.in, core_requirements.txt, dev_requirements.in, dev_requirements.txt
+pip install -r core_requirements.in dev_requirements.in
 ```
 <!-- SHIELDS BADGE CURRENTLY DISABLED -->
 <!-- [![uv][uv-shield]][uv-link] -->

View File

@@ -0,0 +1,51 @@
import os

from fastapi import Request
from fastapi.exceptions import RequestValidationError
from starlette import status
from starlette.responses import JSONResponse
from starlette.staticfiles import StaticFiles

from backend import create_app
from backend.logger import logger
from backend.settings import app_settings, debug_arg

app = create_app(debug=debug_arg, settings_obj=app_settings)

# app.logger = CustomizeLogger.make_logger(config_path)


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
    request: Request, exc: RequestValidationError
):
    # Collapse newlines and runs of spaces into a single-line message.
    exc_str = f"{exc}".replace("\n", "; ").replace("   ", " ")
    logger.error(f"{request}: {exc_str}")
    content = {"status_code": 10422, "message": exc_str, "data": None}
    return JSONResponse(
        content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
    )


class UnicornException(Exception):
    def __init__(self, name: str):
        self.name = name


@app.exception_handler(UnicornException)
async def unicorn_exception_handler(request: Request, exc: UnicornException):
    return JSONResponse(
        status_code=418,
        content={
            "message": f"Oops! {exc.name} did something. "
            "There goes a rainbow..."
        },
    )


app.mount(
    "/static",
    StaticFiles(directory=os.path.join(app_settings.frontend_dir, "static")),
    name="static",
)
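As a usage sketch, a route (hypothetical, not part of this commit) that raises `UnicornException` would receive the 418 payload from the handler above:

```python
# Hypothetical example route, shown only to illustrate the handler wiring.
@app.get("/unicorns/{name}")
async def read_unicorn(name: str):
    if name == "yolo":
        raise UnicornException(name=name)
    return {"unicorn_name": name}
```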

View File

@@ -0,0 +1,18 @@
from fastapi import FastAPI


def create_app(debug: bool = False, settings_obj=None):
    """
    Application factory for the FastAPI service.

    :param debug: enable FastAPI debug mode
    :param settings_obj: optional settings object to attach to the app
    :return: a configured FastAPI instance
    """
    app = FastAPI(
        title="API Service for 525 Ocean Parkway's Cooperative Board",
        debug=debug,
    )
    if settings_obj:
        app.settings = settings_obj

    return app
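A usage sketch for serving the factory-built app locally (entrypoint and port are hypothetical; the project may wire this differently):

```python
# Hypothetical local entrypoint for the application factory.
import uvicorn

from backend import create_app
from backend.settings import app_settings

app = create_app(debug=True, settings_obj=app_settings)

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)
```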

src/backend/db/base.py Normal file
View File

@@ -0,0 +1,38 @@
from typing import Any

from backend.db.db import meta


def get_base_class():
    """
    Generate the base SQLModel class for the project.

    :return: a SQLModel subclass bound to the project metadata
    """
    from sqlmodel import SQLModel

    class Base(SQLModel):
        def __new__(cls, *args: Any, **kwargs: Any) -> Any:
            """
            Updating to Pydantic v2.5 results in an error with the
            `__pydantic_extra__` attribute not being found. This is a
            workaround taken from the GitHub issues page:
            https://github.com/tiangolo/sqlmodel/pull/632#discussion_r1280895115

            Args:
                *args: positional args as tuple; explodes into Any type
                **kwargs: keyword arguments as dict; explodes into Any type

            Returns:
                Any
            """
            new_obj = super().__new__(cls)
            object.__setattr__(new_obj, "__pydantic_fields_set__", set())
            object.__setattr__(new_obj, "__pydantic_extra__", {})
            return new_obj

    Base.metadata = meta
    return Base


Base = get_base_class()
Base.__mapper_args__ = {"eager_defaults": True}
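For illustration, a table model built on this `Base` might look like the following sketch (the `Agent` model is hypothetical, not part of this commit):

```python
# Hypothetical SQLModel table bound to the shared metadata/schema via Base.
from sqlmodel import Field

class Agent(Base, table=True):
    id: int | None = Field(default=None, primary_key=True)
    name: str = Field(index=True)
```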

src/backend/db/db.py Normal file
View File

@@ -0,0 +1,212 @@
import asyncio
import inspect
from collections.abc import AsyncGenerator
from typing import Annotated

from fastapi import Form
from pydantic import BaseModel
from sqlalchemy import MetaData, event
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlalchemy.orm import merge_frozen_result, sessionmaker
from sqlalchemy.pool import NullPool
from sqlmodel import Session, create_engine
from sqlmodel.ext.asyncio.session import AsyncSession

from backend.logger import logger
from backend.utils import get_val

# Toggle SQL echo and forced table creation for local debugging:
# echo, create = True, True
echo, create = False, False
# echo, create = False, True
# echo, create = True, False

db_url = get_val("PROJECT_DB_URL")

schema = "pygentic_ai"
meta = MetaData(schema=schema)


def as_form(cls: type[BaseModel]):
    """
    Class decorator that attaches an `as_form` dependency to a Pydantic
    model so the model can be populated from HTML form fields.

    :param cls: the Pydantic model class to decorate
    :return: the decorated class, with an `as_form` attribute
    """
    new_params = [
        inspect.Parameter(
            field_name,
            inspect.Parameter.POSITIONAL_ONLY,
            default=model_field.default,
            annotation=Annotated[
                model_field.annotation, *model_field.metadata, Form()
            ],
        )
        for field_name, model_field in cls.model_fields.items()
    ]

    async def as_form_func(**data):
        # logger.debug(pformat(data))
        return cls(**data)

    sig = inspect.signature(as_form_func)
    sig = sig.replace(parameters=new_params)
    as_form_func.__signature__ = sig  # type: ignore
    cls.as_form = as_form_func
    return cls
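A sketch of how the `as_form` decorator would be consumed (the `ContactForm` model and route are hypothetical):

```python
# Hypothetical usage: populate a Pydantic model from HTML form fields.
from fastapi import Depends, FastAPI
from pydantic import BaseModel

@as_form
class ContactForm(BaseModel):
    email: str
    message: str = ""

app = FastAPI()

@app.post("/contact")
async def submit_contact(form: ContactForm = Depends(ContactForm.as_form)):
    return {"email": form.email}
```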
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
    """
    Get an async SQLAlchemy session via a generator function.
    """
    async with async_session_maker() as session_obj:
        yield session_obj


def get_sync_session():
    """
    Get a synchronous SQLAlchemy session via a generator function.
    """
    # `sessionmaker` expects an engine (not a URL string), so open a
    # session against the module-level sync engine instead.
    with Session(sync_engine, expire_on_commit=False) as session:
        yield session
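For reference, a minimal sketch of wiring the async session into an endpoint via `Depends` (the router and health-check query are illustrative):

```python
# Hypothetical endpoint using the async session dependency.
from fastapi import APIRouter, Depends
from sqlalchemy import text

router = APIRouter()

@router.get("/healthz/db")
async def db_health(session: AsyncSession = Depends(get_async_session)):
    result = await session.execute(text("SELECT 1"))
    return {"db": result.scalar_one() == 1}
```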
async def run_out_of_band(
    async_sessionmaker: sessionmaker,
    session_inst: AsyncSession,
    statement,
    merge_results: bool = True,
):
    """
    Execute a statement on a separate AUTOCOMMIT connection, outside the
    current session's transaction.

    :param async_sessionmaker: session factory for the out-of-band session
    :param session_inst: session into which frozen results are merged
    :param statement: the statement to execute
    :param merge_results: merge results into `session_inst` when True
    :return: the merged result when `merge_results` is True
    """
    async with async_sessionmaker() as oob_session:
        await oob_session.connection(
            execution_options={"isolation_level": "AUTOCOMMIT"}
        )
        result = await oob_session.execute(statement)

        if merge_results:
            return (
                await session_inst.run_sync(
                    merge_frozen_result, statement, result.freeze(), load=False
                )
            )()
        else:
            await result.close()
async def check_db_exists(engine_inst):
    """
    Check whether the target database already has tables.

    :param engine_inst: async engine bound to the target database
    :return: True when tables exist, otherwise False
    """
    # async with Session(engine_inst) as conn:
    async with engine_inst.connect() as conn:
        from sqlalchemy import inspect  # noqa

        tables = await conn.run_sync(
            lambda sync_conn: inspect(sync_conn).get_table_names()
        )
        logger.info(tables)
        if not tables:
            return False
        return True


async def create_db(
    engine_inst: AsyncEngine,
    create_bool: bool = False,
):
    """
    Create all project tables when none exist (or when forced).

    :param engine_inst: async engine bound to the target database
    :param create_bool: force table creation when True
    :return: the engine instance
    """
    url = engine_inst.url
    exists = await check_db_exists(engine_inst)
    if not exists or create_bool:
        from backend.db.base import Base  # noqa

        async with engine_inst.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        logger.info(f"Database {url.database} created")

    return engine_inst


def create_db_engine(
    db_url: str, async_bool: bool = False, echo_bool: bool = False
):
    """
    Create a sync or async engine for the given database URL.

    :param db_url: the database URL
    :param async_bool: create an async engine when True
    :param echo_bool: enable SQL echo logging
    :return: the engine instance
    """
    engine_inst = (
        create_engine(db_url, echo=echo_bool)
        if async_bool is False
        else create_async_engine(db_url, echo=echo_bool, poolclass=NullPool)
    )

    return engine_inst
engine = create_db_engine(db_url, async_bool=True, echo_bool=False)


@event.listens_for(engine.sync_engine, "connect", insert=True)
def set_current_schema(dbapi_connection, connection_record):
    """
    This is a helper event listener to ensure that the current
    schema is set when bootstrapping the database tables.

    Taken from:
    https://docs.sqlalchemy.org/en/20/dialects/postgresql.html#setting-alternate-search-paths-on-connect

    :param dbapi_connection: raw DBAPI connection
    :param connection_record: connection pool record
    :return:
    """
    existing_autocommit = dbapi_connection.autocommit
    dbapi_connection.autocommit = True
    cursor_obj = dbapi_connection.cursor()
    cursor_obj.execute("SET SESSION search_path='%s'" % schema)
    cursor_obj.close()
    dbapi_connection.autocommit = existing_autocommit


sync_engine = create_db_engine(db_url, echo_bool=echo)
session_inst = Session(sync_engine, expire_on_commit=False)

async_session_maker = sessionmaker(
    bind=engine, class_=AsyncSession, expire_on_commit=False
)
# The synchronous factory must bind to the sync engine, not the async one.
session_maker = sessionmaker(
    bind=sync_engine, class_=Session, expire_on_commit=False
)

if __name__ == "__main__":
    import sys

    from backend.db.models import *  # noqa F403

    if sys.platform == "win32":
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

    asyncio.run(create_db(engine_inst=engine, create_bool=True))

View File

@@ -0,0 +1,114 @@
"""
Custom Logger Using Loguru
"""

import json
import logging
import sys
from datetime import date
from pathlib import Path

from loguru import logger

from backend.settings.consts import BACKEND_DIR

dir_path = Path(BACKEND_DIR).joinpath("logs")
log_file = Path(dir_path).joinpath(f"log_{date.today()}.log")

# Register one sink per level so each level writes to its own file.
for level_name in ("info", "debug", "error", "success", "warning", "critical"):
    logger.add(
        Path(dir_path).joinpath(f"{level_name}_{log_file.name}"),
        filter=lambda record, lvl=level_name: record["level"].name
        == lvl.upper(),
        rotation="1 day",
        retention="1 week",
        enqueue=True,
    )


class InterceptHandler(logging.Handler):
    loglevel_mapping = {
        50: "CRITICAL",
        40: "ERROR",
        30: "WARNING",
        20: "INFO",
        10: "DEBUG",
        0: "NOTSET",
    }

    def emit(self, record):
        try:
            level = logger.level(record.levelname).name
        except (AttributeError, ValueError):
            # loguru raises ValueError for level names it does not know;
            # fall back to the numeric mapping.
            level = self.loglevel_mapping[record.levelno]

        frame, depth = logging.currentframe(), 2
        while frame.f_code.co_filename == logging.__file__:
            frame = frame.f_back
            depth += 1

        log = logger.bind(request_id="app")
        log.opt(depth=depth, exception=record.exc_info).log(
            level, record.getMessage()
        )


class CustomizeLogger:
    @classmethod
    def make_logger(cls, config_path: Path):
        config = cls.load_logging_config(config_path)
        logging_config = config.get("logger")

        logger = cls.customize_logging(
            logging_config.get("path"),
            level=logging_config.get("level"),
            retention=logging_config.get("retention"),
            rotation=logging_config.get("rotation"),
            format=logging_config.get("format"),
        )
        return logger

    @classmethod
    def customize_logging(
        cls,
        filepath: Path,
        level: str,
        rotation: str,
        retention: str,
        format: str,
    ):
        logger.remove()
        logger.add(
            sys.stdout,
            enqueue=True,
            backtrace=True,
            level=level.upper(),
            format=format,
        )
        logger.add(
            str(filepath),
            rotation=rotation,
            retention=retention,
            enqueue=True,
            backtrace=True,
            level=level.upper(),
            format=format,
        )

        # Route stdlib and uvicorn/fastapi logging through loguru.
        logging.basicConfig(handlers=[InterceptHandler()], level=0)
        logging.getLogger("uvicorn.access").handlers = [InterceptHandler()]
        for _log in ["uvicorn", "uvicorn.error", "fastapi"]:
            _logger = logging.getLogger(_log)
            _logger.handlers = [InterceptHandler()]

        return logger.bind(request_id=None, method=None)

    @classmethod
    def load_logging_config(cls, config_path):
        config = None
        with open(config_path) as config_file:
            config = json.load(config_file)
        return config
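`make_logger` assumes a JSON file with a top-level `logger` key; a minimal sketch of that shape and its use (file name and values hypothetical, keys inferred from the accessors above):

```python
# Hypothetical logging_config.json consumed by CustomizeLogger.make_logger:
# {
#     "logger": {
#         "path": "./logs/app.log",
#         "level": "info",
#         "rotation": "1 day",
#         "retention": "1 week",
#         "format": "{time} | {level} | {message}"
#     }
# }
from pathlib import Path

app_logger = CustomizeLogger.make_logger(Path("logging_config.json"))
app_logger.info("logger configured")
```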

View File

@@ -0,0 +1,31 @@
import enum
import os
from functools import lru_cache

from backend.settings.dev import Settings as DevSettings
from backend.settings.prod import Settings as ProdSettings
from backend.utils import get_val

server_types = enum.StrEnum(
    "ServerTypes", {x.upper(): x for x in ("dev", "uat", "staging", "prod")}
)


@lru_cache
def get_settings(server: server_types = "dev", debug: bool = False):
    """
    Return the settings object for the given server environment, caching
    the result per (server, debug) combination.
    """
    if server == "dev":
        settings = DevSettings()
    else:
        settings = ProdSettings()

    if debug:
        settings.DEBUG = debug

    return settings


debug_arg = get_val(val="DEBUG", default=False, cast=bool)
env_arg = get_val("SERVER_ENV", default="dev")

app_settings = get_settings(server=env_arg, debug=debug_arg)

os.environ["PROJECT_DB_URL"] = app_settings.DB_URL
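Because `get_settings` is cached and this module reads `SERVER_ENV`/`DEBUG` at import time, the environment must be set before the first import; a sketch (assumes the `CLOUD_DB_*`/`LOCAL_DB_*` variables used by `backend_options` are present in the environment or `.env`):

```python
# Hypothetical: select prod settings and force debug via env vars
# before the settings package is imported anywhere else.
import os

os.environ["SERVER_ENV"] = "prod"
os.environ["DEBUG"] = "True"

from backend.settings import app_settings  # noqa: E402

assert app_settings.DEBUG is True
```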

View File

@@ -0,0 +1,113 @@
"""
Backend configuration variables for database settings. Import these into \
the appropriate settings module and use them to populate its database \
URL values.
"""

import os
from dataclasses import dataclass
from pathlib import Path

from backend.settings.consts import BASE_DIR, all_dialects
from backend.utils import get_val


@dataclass
class DBConfs:
    name: str | None
    user: str | None
    password: str | None
    host: str | None
    port: str | int | None


def get_db_val(
    db_configs: DBConfs,
    dialect: all_dialects = "sqlite",
    ssl_dict: dict | None = None,
    **kwargs,
):
    """
    This utility function generates an SQL DB url based on provided parameters.
    Dialects are checked against a whitelist composed of accepted entries from
    SQLA 2.0's documentation. Further dialects to be added from MSSQL and
    Oracle DBs.

    Parameters
    ----------
    db_configs: DBConfs: dataclass
    dialect: Literal[all_dialects]: defaults to `sqlite`
    ssl_dict: Optional[dict]: defaults to None
    kwargs: keyword args appended to the URL as query parameters

    Returns
    -------
    url: str
    """
    # Paths are acceptable for sqlite hosts; normalize them to strings.
    if isinstance(db_configs.host, Path):
        db_configs.host = str(db_configs.host)

    schema_sep = "://" if "sqlite" not in dialect else ":///"

    # sqlite URLs only need a host (file path); other dialects need full
    # credentials. Check the attribute *values*, not their mere presence.
    if "sqlite" in dialect and not any(
        getattr(db_configs, x) for x in ("name", "user", "password", "port")
    ):
        url = f"{dialect}{schema_sep}{db_configs.host}"
    else:
        url = (
            f"{dialect}{schema_sep}{db_configs.user}:{db_configs.password}"
            f"@{db_configs.host}:{db_configs.port}"
            f"/{db_configs.name}"
        )

    # Combine SSL settings and extra kwargs into a single query string.
    params = {**(ssl_dict or {}), **kwargs}
    if params:
        query = "&".join(f"{k}={v}" for k, v in params.items())
        url = f"{url}?{query}"

    return url


cloud_backend = {
    "name": get_val("CLOUD_DB_DB"),
    "user": get_val("CLOUD_DB_UN"),
    "password": get_val("CLOUD_DB_PW"),
    "host": get_val("CLOUD_DB_HOST"),
    "port": get_val("CLOUD_DB_PORT"),
}

local_db = {
    "name": get_val("LOCAL_DB_DB"),
    "user": get_val("LOCAL_DB_UN"),
    "password": get_val("LOCAL_DB_PW"),
    "host": get_val("LOCAL_DB_HOST"),
    "port": get_val("LOCAL_DB_PORT"),
}

dialect = "postgresql+asyncpg"

local_db_config = get_db_val(db_configs=DBConfs(**local_db), dialect=dialect)
# local_db_config = get_db_val(**local_db, dialect="postgresql")
cloud_db_config = get_db_val(
    db_configs=DBConfs(**cloud_backend),
    dialect=dialect,
)
sqlite_db_config = get_db_val(
    db_configs=DBConfs(
        name=None,
        user=None,
        password=None,
        port=None,
        host=os.path.join(BASE_DIR, "backend", "db.sqlite"),
    ),
    dialect="sqlite",
)
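For reference, the URL shapes `get_db_val` builds (all values hypothetical):

```python
# Hypothetical inputs showing the two URL branches.
demo = DBConfs(
    name="appdb", user="app", password="s3cret", host="localhost", port=5432
)
get_db_val(db_configs=demo, dialect="postgresql+asyncpg")
# -> "postgresql+asyncpg://app:s3cret@localhost:5432/appdb"

get_db_val(
    db_configs=DBConfs(
        name=None, user=None, password=None, port=None, host="db.sqlite"
    ),
    dialect="sqlite",
)
# -> "sqlite:///db.sqlite"
```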

View File

@@ -0,0 +1,11 @@
from pydantic_settings import BaseSettings

from backend.settings.consts import BACKEND_DIR, FRONTEND_DIR


class Settings(BaseSettings):
    backend_dir: str = BACKEND_DIR
    frontend_dir: str = FRONTEND_DIR
    DB_URL: str = "sqlite+aiosqlite:///backend/pygentic_ai.sqlite3"
    DEBUG: bool = False

View File

@@ -0,0 +1,56 @@
import enum
import os

from decouple import config

# Resolve paths relative to the project source root (this file lives in
# src/backend/settings/).
BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
BACKEND_DIR = os.path.join(BASE_DIR, "backend")
FRONTEND_DIR = os.path.join(BASE_DIR, "frontend")

pg_dialects = [
    "postgres",
    "postgresql",
    "postgresql+asyncpg",
    "postgresql+pg8000",
    "postgresql+psycopg",
    "postgresql+psycopg2",
    "postgresql+psycopg2cffi",
    "postgresql+py-postgresql",
    "postgresql+pygresql",
]
mysql_dialects = [
    "mysql+mysqldb",
    "mysql+pymysql",
    "mariadb+mariadbconnector",
    "mysql+mysqlconnector",
    "mysql+asyncmy",
    "mysql+aiomysql",
    "mysql+cymysql",
    "mysql+pyodbc",
]
sqlite_dialects = [
    "sqlite",
    "sqlite+pysqlite",
    "sqlite+aiosqlite",
    "sqlite+pysqlcipher",
]

all_dialects = enum.Enum(
    "DatabaseDialect",
    {
        dialect.upper(): dialect
        for dialect in (*pg_dialects, *mysql_dialects, *sqlite_dialects)
    },
)
# TODO: include Oracle and MSSQL dialects

SECRET_KEY: str = config("SECRET_KEY")

View File

@@ -0,0 +1,8 @@
from backend.settings.backend_options import local_db_config
from backend.settings.base import Settings as BaseSettings


class Settings(BaseSettings):
    DB_URL: str = local_db_config
    SQLALCHEMY_DATABASE_URI: str = local_db_config
    DEBUG: bool = True

View File

@@ -0,0 +1,8 @@
from backend.settings.backend_options import cloud_db_config
from backend.settings.base import Settings as BaseSettings


class Settings(BaseSettings):
    DB_URL: str = cloud_db_config
    SQLALCHEMY_DATABASE_URI: str = cloud_db_config
    DEBUG: bool = False

View File

@@ -0,0 +1,68 @@
import os

from decouple import UndefinedValueError, config


def get_db_url(env: str = "dev"):
    """
    Build a database URL for the given environment.

    :param env: `dev` or `prod`
    :return: a database URL string, or None for unknown environments
    """
    if env == "dev":
        un, pw, db, host, port = (
            config(x)
            for x in (
                "LOCAL_DB_UN",
                "LOCAL_DB_PW",
                "LOCAL_DB_DB",
                "LOCAL_DB_HOST",
                "LOCAL_DB_PORT",
            )
        )
        # url = f"mysql://{un}:{pw}@{host}:{port}/{db}"
        # url = f"mariadb+pymysql://{un}:{pw}@{host}:{port}/{db}"
    elif env == "prod":
        un, pw, db = (
            config(x)
            for x in (
                "POSTGRES_USER",
                "POSTGRES_PASSWORD",
                "POSTGRES_DB",
            )
        )
        # In prod the containerized database host and port are fixed
        # literals, not env var names to look up.
        host, port = "db", "5432"
    else:
        return None

    url = f"postgresql://{un}:{pw}@{host}:{port}/{db}"

    return url


def get_val(val: str, default: str | int | bool | None = None, **kwargs):
    """
    A utility function that checks the platform environment or an .env file to
    pull a key:value pair and return the value. If the value cannot be found
    and no default is supplied, a ValueError is raised.

    Args:
        val: str
        default: str | int | bool | None: defaults to None
        kwargs: extra options passed through to `decouple.config` (e.g. cast)

    Returns:
        val: str
    """
    try:
        # decouple checks os.environ first, then .env/settings files, and
        # applies any `cast` passed through kwargs.
        return config(val, **kwargs)
    except (UndefinedValueError, ValueError):
        pass

    # Compare against None (not truthiness) so falsy defaults like False
    # or 0 are honored.
    if default is not None:
        return default

    raise ValueError(
        f"Env Var {val} is not populated in the environment "
        f"or within the configuration files"
    )
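A usage sketch for `get_val` (variable names hypothetical):

```python
# Hypothetical: environment/.env values win; otherwise the default is
# returned, and a missing value with no default raises ValueError.
import os

os.environ["API_TIMEOUT"] = "30"
assert get_val("API_TIMEOUT", cast=int) == 30

assert get_val("FEATURE_FLAG", default=False) is False
```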