mirror of
https://github.com/fsecada01/Pygentic-AI.git
synced 2025-06-18 04:56:03 +00:00
adding boilerplate code for db integrations; settings based on environment, and logger; pre-commit updates
This commit is contained in:
parent
529f21f123
commit
a975e2b7ec
@ -32,7 +32,7 @@ repos:
|
||||
|
||||
- repo: https://github.com/charliermarsh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: "v0.8.3"
|
||||
rev: "v0.9.2"
|
||||
hooks:
|
||||
- id: ruff
|
||||
|
||||
|
@ -508,7 +508,6 @@ Build Pygentic-AI from the source and intsall dependencies:
|
||||
<!-- [docker-link]: https://www.docker.com/ -->
|
||||
|
||||
**Using [docker](https://www.docker.com/):**
|
||||
|
||||
```sh
|
||||
❯ docker build -t fsecada01/Pygentic-AI .
|
||||
```
|
||||
@ -521,7 +520,7 @@ Build Pygentic-AI from the source and intsall dependencies:
|
||||
**Using [pip](https://pypi.org/project/pip/):**
|
||||
|
||||
```sh
|
||||
❯ pip install -r core_requirements.in, core_requirements.txt, dev_requirements.in, dev_requirements.txt
|
||||
❯ pip install -r core_requirements.in dev_requirements.in
|
||||
```
|
||||
<!-- SHIELDS BADGE CURRENTLY DISABLED -->
|
||||
<!-- [![uv][uv-shield]][uv-link] -->
|
||||
|
51
src/app.py
51
src/app.py
@ -0,0 +1,51 @@
|
||||
import os
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from starlette import status
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.staticfiles import StaticFiles
|
||||
|
||||
from backend import create_app
|
||||
from backend.logger import logger
|
||||
from backend.settings import app_settings, debug_arg
|
||||
|
||||
app = create_app(debug=debug_arg, settings_obj=app_settings)
|
||||
|
||||
# app.logger = CustomizeLogger.make_logger(config_path)
|
||||
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(
|
||||
request: Request, exc: RequestValidationError
|
||||
):
|
||||
exc_str = f"{exc}".replace("\n", "; ").replace(" ", " ")
|
||||
logger.error(f"{request}: {exc_str}")
|
||||
content = {"status_code": 10422, "message": exc_str, "data": None}
|
||||
|
||||
return JSONResponse(
|
||||
content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
)
|
||||
|
||||
|
||||
class UnicornException(Exception):
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
|
||||
|
||||
@app.exception_handler(UnicornException)
|
||||
async def unicorn_exception_handler(request: Request, exc: UnicornException):
|
||||
return JSONResponse(
|
||||
status_code=418,
|
||||
content={
|
||||
"message": f"Oops! {exc.name} did something. "
|
||||
"There goes a rainbow..."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
app.mount(
|
||||
"/static",
|
||||
StaticFiles(directory=os.path.join(app_settings.frontend_dir, "static")),
|
||||
name="static",
|
||||
)
|
@ -0,0 +1,18 @@
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
||||
def create_app(debug: bool = False, settings_obj=None):
|
||||
"""
|
||||
|
||||
:param debug:
|
||||
:param settings_obj:
|
||||
:return:
|
||||
"""
|
||||
app = FastAPI(
|
||||
title="API Service for 525 Ocean Parkway's Cooperative Board",
|
||||
debug=debug,
|
||||
)
|
||||
if settings_obj:
|
||||
app.settings = settings_obj
|
||||
|
||||
return app
|
38
src/backend/db/base.py
Normal file
38
src/backend/db/base.py
Normal file
@ -0,0 +1,38 @@
|
||||
from typing import Any
|
||||
|
||||
from backend.db.db import meta
|
||||
|
||||
|
||||
def get_base_class():
|
||||
"""
|
||||
Generate base SQLModel class for project
|
||||
:return:
|
||||
"""
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
class Base(SQLModel):
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Updated to Pydantic V2.5 results in an error with
|
||||
`__pydantic_extra__` attribute not being found. This is a workaround
|
||||
taken from the GitHub issues page:
|
||||
https://github.com/tiangolo/sqlmodel/pull/632#discussion_r1280895115
|
||||
Args:
|
||||
*args: positional args as tuple; explodes into Any type
|
||||
**kwargs: keyword arguments as dict; explodes into Any type
|
||||
|
||||
Returns:
|
||||
Any
|
||||
"""
|
||||
new_obj = super().__new__(cls)
|
||||
object.__setattr__(new_obj, "__pydantic_fields_set__", set())
|
||||
object.__setattr__(new_obj, "__pydantic_extra__", {})
|
||||
return new_obj
|
||||
|
||||
Base.metadata = meta
|
||||
|
||||
return Base
|
||||
|
||||
|
||||
Base = get_base_class()
|
||||
Base.__mapper_args = {"eager_defaults": True}
|
212
src/backend/db/db.py
Normal file
212
src/backend/db/db.py
Normal file
@ -0,0 +1,212 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Form
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import MetaData, event
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||
from sqlalchemy.orm import merge_frozen_result, sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlmodel import Session, create_engine
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from backend.logger import logger
|
||||
from backend.utils import get_val
|
||||
|
||||
# echo, create = True, True
|
||||
echo, create = False, False
|
||||
# echo, create = False, True
|
||||
# echo, create = True, False
|
||||
|
||||
db_url = get_val("PROJECT_DB_URL")
|
||||
|
||||
schema = "pygentic_ai"
|
||||
meta = MetaData(schema=schema)
|
||||
|
||||
|
||||
def as_form(cls: type[BaseModel]):
|
||||
"""
|
||||
|
||||
:param cls:
|
||||
:return:
|
||||
"""
|
||||
new_params = [
|
||||
inspect.Parameter(
|
||||
field_name,
|
||||
inspect.Parameter.POSITIONAL_ONLY,
|
||||
default=model_field.default,
|
||||
annotation=Annotated[
|
||||
model_field.annotation, *model_field.metadata, Form()
|
||||
],
|
||||
)
|
||||
for field_name, model_field in cls.model_fields.items()
|
||||
]
|
||||
|
||||
async def as_form_func(**data):
|
||||
# logger.debug(pformat(data))
|
||||
return cls(**data)
|
||||
|
||||
sig = inspect.signature(as_form_func)
|
||||
sig = sig.replace(parameters=new_params)
|
||||
as_form_func.__signature__ = sig # type: ignore
|
||||
cls.as_form = as_form_func
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Get async SQLA session via generator function
|
||||
"""
|
||||
async with async_session_maker() as session_obj:
|
||||
yield session_obj
|
||||
|
||||
|
||||
def get_sync_session():
|
||||
"""
|
||||
Get SQLA session via generator function
|
||||
"""
|
||||
with sessionmaker(db_url, expire_on_commit=False) as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def run_out_of_band(
|
||||
async_sessionaker: sessionmaker,
|
||||
session_inst: AsyncSession,
|
||||
statement,
|
||||
merge_results: bool = True,
|
||||
):
|
||||
"""
|
||||
|
||||
:param async_sessionaker:
|
||||
:param session_inst:
|
||||
:param statement:
|
||||
:param merge_results:
|
||||
:return:
|
||||
"""
|
||||
async with async_sessionaker() as oob_session:
|
||||
await oob_session.connection(
|
||||
execution_options={"isolation_level": "AUTOCOMMIT"}
|
||||
)
|
||||
|
||||
result = await oob_session.execute(statement)
|
||||
|
||||
if merge_results:
|
||||
return (
|
||||
await session_inst.run_sync(
|
||||
merge_frozen_result, statement, result.freeze(), load=False
|
||||
)
|
||||
)()
|
||||
else:
|
||||
await result.close()
|
||||
|
||||
|
||||
async def check_db_exists(engine_inst):
|
||||
"""
|
||||
|
||||
:param engine_inst:
|
||||
:return:
|
||||
"""
|
||||
# async with Session(engine_inst) as conn:
|
||||
async with engine_inst.connect() as conn:
|
||||
from sqlalchemy import inspect # noqa
|
||||
|
||||
tables = await conn.run_sync(
|
||||
lambda sync_conn: inspect(sync_conn).get_table_names()
|
||||
)
|
||||
logger.info(tables)
|
||||
if not tables:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
async def create_db(
|
||||
engine_inst: AsyncEngine,
|
||||
create_bool: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
:param engine_inst:
|
||||
:param create_bool:
|
||||
:return:
|
||||
"""
|
||||
url = engine_inst.url
|
||||
exists = await check_db_exists(engine_inst)
|
||||
if not exists or create_bool:
|
||||
from backend.db.base_model import Base # noqa
|
||||
|
||||
async with engine_inst.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
logger.info(f"Database {url.database} created")
|
||||
|
||||
return engine_inst
|
||||
|
||||
|
||||
def create_db_engine(
|
||||
db_url: str, async_bool: bool = False, echo_bool: bool = False
|
||||
):
|
||||
"""
|
||||
|
||||
:param db_url:
|
||||
:param async_bool:
|
||||
:param echo_bool:
|
||||
:return:
|
||||
"""
|
||||
engine_inst = (
|
||||
create_engine(db_url, echo=echo_bool)
|
||||
if async_bool is False
|
||||
else create_async_engine(db_url, echo=echo_bool, poolclass=NullPool)
|
||||
)
|
||||
|
||||
return engine_inst
|
||||
|
||||
|
||||
engine = create_db_engine(db_url, async_bool=True, echo_bool=False)
|
||||
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect", insert=True)
|
||||
def set_current_schema(dbapi_connection, connection_record):
|
||||
"""
|
||||
This is a helper event listener to ensure that the current
|
||||
schema is set when bootstrapping the database tables.
|
||||
Taken from:
|
||||
|
||||
https://docs.sqlalchemy.org/en/20/dialects/postgresql.html#setting-alternate
|
||||
-search-paths-on-connect
|
||||
|
||||
:param dbapi_connection:
|
||||
:param connection_record:
|
||||
:return:
|
||||
"""
|
||||
existing_autocommit = dbapi_connection.autocommit
|
||||
dbapi_connection.autocommit = True
|
||||
cursor_obj = dbapi_connection.cursor()
|
||||
cursor_obj.execute("SET SESSION search_path='%s'" % schema)
|
||||
cursor_obj.close()
|
||||
dbapi_connection.autocommit = existing_autocommit
|
||||
|
||||
|
||||
sync_engine = create_db_engine(db_url, echo_bool=echo)
|
||||
|
||||
session_inst = Session(sync_engine, expire_on_commit=False)
|
||||
|
||||
async_session_maker = sessionmaker(
|
||||
bind=engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
session_maker = sessionmaker(
|
||||
bind=engine, class_=Session, expire_on_commit=False
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
from backend.db.models import * # noqa F403
|
||||
|
||||
if sys.platform == "win32":
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
asyncio.run(create_db(engine_inst=engine, create_bool=True))
|
@ -0,0 +1,114 @@
|
||||
"""
|
||||
Custom Logger Using Loguru
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from datetime import date
|
||||
from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from backend.settings.consts import BACKEND_DIR
|
||||
|
||||
dir_path = Path(BACKEND_DIR).joinpath("logs")
|
||||
|
||||
log_file = Path(dir_path).joinpath(f"log_{date.today()}.log")
|
||||
|
||||
list(
|
||||
map(
|
||||
lambda x: logger.add(
|
||||
Path(dir_path).joinpath(f"{x}_{log_file.name}"),
|
||||
filter=lambda record: record["level"].name == x.upper(),
|
||||
rotation="1 day",
|
||||
retention="1 week",
|
||||
enqueue=True,
|
||||
),
|
||||
["info", "debug", "error", "success", "warning", "critical"],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class InterceptHandler(logging.Handler):
|
||||
loglevel_mapping = {
|
||||
50: "CRITICAL",
|
||||
40: "ERROR",
|
||||
30: "WARNING",
|
||||
20: "INFO",
|
||||
10: "DEBUG",
|
||||
0: "NOTSET",
|
||||
}
|
||||
|
||||
def emit(self, record):
|
||||
try:
|
||||
level = logger.level(record.levelname).name
|
||||
except AttributeError:
|
||||
level = self.loglevel_mapping[record.levelno]
|
||||
|
||||
frame, depth = logging.currentframe(), 2
|
||||
while frame.f_code.co_filename == logging.__file__:
|
||||
frame = frame.f_back
|
||||
depth += 1
|
||||
|
||||
log = logger.bind(request_id="app")
|
||||
log.opt(depth=depth, exception=record.exc_info).log(
|
||||
level, record.getMessage()
|
||||
)
|
||||
|
||||
|
||||
class CustomizeLogger:
|
||||
@classmethod
|
||||
def make_logger(cls, config_path: Path):
|
||||
config = cls.load_logging_config(config_path)
|
||||
logging_config = config.get("logger")
|
||||
|
||||
logger = cls.customize_logging(
|
||||
logging_config.get("path"),
|
||||
level=logging_config.get("level"),
|
||||
retention=logging_config.get("retention"),
|
||||
rotation=logging_config.get("rotation"),
|
||||
format=logging_config.get("format"),
|
||||
)
|
||||
return logger
|
||||
|
||||
@classmethod
|
||||
def customize_logging(
|
||||
cls,
|
||||
filepath: Path,
|
||||
level: str,
|
||||
rotation: str,
|
||||
retention: str,
|
||||
format: str,
|
||||
):
|
||||
logger.remove()
|
||||
logger.add(
|
||||
sys.stdout,
|
||||
enqueue=True,
|
||||
backtrace=True,
|
||||
level=level.upper(),
|
||||
format=format,
|
||||
)
|
||||
logger.add(
|
||||
str(filepath),
|
||||
rotation=rotation,
|
||||
retention=retention,
|
||||
enqueue=True,
|
||||
backtrace=True,
|
||||
level=level.upper(),
|
||||
format=format,
|
||||
)
|
||||
logging.basicConfig(handlers=[InterceptHandler()], level=0)
|
||||
logging.getLogger("uvicorn.access").handlers = [InterceptHandler()]
|
||||
for _log in ["uvicorn", "uvicorn.error", "fastapi"]:
|
||||
_logger = logging.getLogger(_log)
|
||||
_logger.handlers = [InterceptHandler()]
|
||||
|
||||
return logger.bind(request_id=None, method=None)
|
||||
|
||||
@classmethod
|
||||
def load_logging_config(cls, config_path):
|
||||
config = None
|
||||
with open(config_path) as config_file:
|
||||
config = json.load(config_file)
|
||||
return config
|
@ -0,0 +1,31 @@
|
||||
import enum
|
||||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
from backend.settings.dev import Settings as DevSettings
|
||||
from backend.settings.prod import Settings as ProdSettings
|
||||
from backend.utils import get_val
|
||||
|
||||
server_types = enum.StrEnum(
|
||||
"ServerTypes", {x.upper(): x for x in ("dev", "uat", "staging", "prod")}
|
||||
)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings(server: server_types = "dev", debug: bool = False):
|
||||
if server == "dev":
|
||||
settings = DevSettings()
|
||||
else:
|
||||
settings = ProdSettings()
|
||||
|
||||
if debug:
|
||||
settings.DEBUG = debug
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
debug_arg = get_val(val="DEBUG", default=False, cast=bool)
|
||||
env_arg = get_val("SERVER_ENV", default="dev")
|
||||
|
||||
app_settings = get_settings(server=env_arg, debug=debug_arg)
|
||||
os.environ["PROJECT_DB_URL"] = app_settings.DB_URL
|
113
src/backend/settings/backend_options.py
Normal file
113
src/backend/settings/backend_options.py
Normal file
@ -0,0 +1,113 @@
|
||||
"""
|
||||
backend configuration variables for Django Database settings. Import these \
|
||||
into the appropriate settings file and then fit them into the value slot of \
|
||||
the DATABASES variable
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from backend.settings.consts import BASE_DIR, all_dialects
|
||||
from backend.utils import get_val
|
||||
|
||||
|
||||
@dataclass
|
||||
class DBConfs:
|
||||
name: str | None
|
||||
user: str | None
|
||||
password: str | None
|
||||
host: str | None
|
||||
port: str | int | None
|
||||
|
||||
|
||||
# rebased for Flask and Tortoise ORM
|
||||
|
||||
|
||||
def get_db_val(
|
||||
db_configs: DBConfs,
|
||||
dialect: all_dialects = "sqlite",
|
||||
ssl_dict: dict | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
This utility function generates an SQL DB url based on provided parameters.
|
||||
Dialects are checked against a whitelist composed of accepted entries from
|
||||
SQLA 2.0's documentation. Further dialects to be added from MSSQL and
|
||||
Oracle DBs.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
db_configs: DBConfs: dataclass
|
||||
dialect: Literal[all_dialects]: defaults to `sqlite`
|
||||
ssl_dict: Optional[dict]: defaults to None
|
||||
kwargs: keyword_args that are compacted
|
||||
|
||||
Returns url: str
|
||||
-------
|
||||
"""
|
||||
|
||||
if isinstance(db_configs.host, Path):
|
||||
pass
|
||||
|
||||
schema_sep = "://" if "sqlite" not in dialect else ":///"
|
||||
|
||||
if "sqlite" in dialect and not any(
|
||||
[
|
||||
hasattr(db_configs, x)
|
||||
for x in ("name", "user", "password", "host", "port", "kwargs")
|
||||
]
|
||||
):
|
||||
url = f"{dialect}:{schema_sep}{db_configs.host}"
|
||||
else:
|
||||
url = (
|
||||
f"{dialect}{schema_sep}{db_configs.user}:{db_configs.password}"
|
||||
f"@{db_configs.host}:{db_configs.port}"
|
||||
f"/{db_configs.name}"
|
||||
)
|
||||
if ssl_dict:
|
||||
ssl_uri = "&".join([f"{k}={v}" for k, v in ssl_dict.items()])
|
||||
url = f"{url}?{ssl_uri}"
|
||||
|
||||
if kwargs:
|
||||
uri = "&".join([f"{k}={v}" for k, v in kwargs.items()])
|
||||
url = f"{url}?{uri}"
|
||||
|
||||
return url
|
||||
|
||||
|
||||
cloud_backend = {
|
||||
"name": get_val("CLOUD_DB_DB"),
|
||||
"user": get_val("CLOUD_DB_UN"),
|
||||
"password": get_val("CLOUD_DB_PW"),
|
||||
"host": get_val("CLOUD_DB_HOST"),
|
||||
"port": get_val("CLOUD_DB_PORT"),
|
||||
}
|
||||
|
||||
local_db = {
|
||||
"name": get_val("LOCAL_DB_DB"),
|
||||
"user": get_val("LOCAL_DB_UN"),
|
||||
"password": get_val("LOCAL_DB_PW"),
|
||||
"host": get_val("LOCAL_DB_HOST"),
|
||||
"port": get_val("LOCAL_DB_PORT"),
|
||||
}
|
||||
|
||||
dialect = "postgresql+asyncpg"
|
||||
|
||||
local_db_config = get_db_val(db_configs=DBConfs(**local_db), dialect=dialect)
|
||||
# local_db_config = get_db_val(**local_db, dialect="postgresql")
|
||||
cloud_db_config = get_db_val(
|
||||
db_configs=DBConfs(**cloud_backend),
|
||||
dialect=dialect,
|
||||
)
|
||||
|
||||
sqlite_db_config = get_db_val(
|
||||
db_configs=DBConfs(
|
||||
name=None,
|
||||
user=None,
|
||||
password=None,
|
||||
port=None,
|
||||
host=os.path.join(BASE_DIR, "backend", "db.sqlite"),
|
||||
),
|
||||
dialect="sqlite",
|
||||
)
|
11
src/backend/settings/base.py
Normal file
11
src/backend/settings/base.py
Normal file
@ -0,0 +1,11 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from backend.settings.consts import BACKEND_DIR, FRONTEND_DIR
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
backend_dir: str = BACKEND_DIR
|
||||
frontend_dir: str = FRONTEND_DIR
|
||||
|
||||
DB_URL: str = "sqlite+aiosqlite:///backend/pygentic_ai.sqlite3"
|
||||
DEBUG: bool = False
|
@ -0,0 +1,56 @@
|
||||
import enum
|
||||
import os
|
||||
|
||||
from decouple import config
|
||||
|
||||
BASE_DIR = os.path.join(os.path.dirname(os.path.dirname("__name__")))
|
||||
BACKEND_DIR = os.path.join(BASE_DIR, "backend")
|
||||
FRONTEND_DIR = os.path.join(BASE_DIR, "frontend")
|
||||
|
||||
|
||||
pg_dialects = [
|
||||
"postgres",
|
||||
"postgresql",
|
||||
"postgresql+asyncpg",
|
||||
"postgresql+pg8000",
|
||||
"postgresql+psycopg",
|
||||
"postgresql+psycopg2",
|
||||
"postgresql+psycopg2cffi",
|
||||
"postgresql+py-postgresql",
|
||||
"postgresql+pygresql",
|
||||
]
|
||||
|
||||
mysql_dialects = [
|
||||
"mysql+mysqldb",
|
||||
"mysql+pymysql",
|
||||
"mariadb+mariadbconnector",
|
||||
"mysql+mysqlconnector",
|
||||
"mysql+asyncmy",
|
||||
"mysql+aiomysql",
|
||||
"mysql+cymysql",
|
||||
"mysql+pyodbc",
|
||||
]
|
||||
|
||||
sqlite_dialects = [
|
||||
"sqlite",
|
||||
"sqlite+pysqlite",
|
||||
"sqlite+aiosqlite",
|
||||
"sqlite+pysqlcipher",
|
||||
]
|
||||
|
||||
all_dialects = enum.Enum(
|
||||
"DatabaseDialect",
|
||||
{
|
||||
x.upper(): x
|
||||
for x in [
|
||||
*[
|
||||
y
|
||||
for x in (pg_dialects, mysql_dialects, sqlite_dialects)
|
||||
for y in x
|
||||
]
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
# TODO: include Oracle and MSSQL dialects
|
||||
SECRET_KEY: str = config("SECRET_KEY")
|
8
src/backend/settings/dev.py
Normal file
8
src/backend/settings/dev.py
Normal file
@ -0,0 +1,8 @@
|
||||
from backend.settings.backend_options import local_db_config
|
||||
from backend.settings.base import Settings as BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
DB_URL: str = local_db_config
|
||||
SQLALCHEMY_DATABASE_URL: str = local_db_config
|
||||
DEBUG: bool = True
|
8
src/backend/settings/prod.py
Normal file
8
src/backend/settings/prod.py
Normal file
@ -0,0 +1,8 @@
|
||||
from backend.settings.backend_options import cloud_db_config
|
||||
from backend.settings.base import Settings as BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
DB_URL: str = cloud_db_config
|
||||
SQLALCHEMY_DATABASE_URI: str = cloud_db_config
|
||||
DEBUG: bool = False
|
@ -0,0 +1,68 @@
|
||||
import os
|
||||
|
||||
from decouple import config
|
||||
|
||||
|
||||
def get_db_url(env: str = "dev"):
|
||||
"""
|
||||
|
||||
:param env:
|
||||
:return:
|
||||
"""
|
||||
if env == "dev":
|
||||
un, pw, db, host, port = (
|
||||
config(x)
|
||||
for x in (
|
||||
"LOCAL_DB_UN",
|
||||
"LOCAL_DB_PW",
|
||||
"LOCAL_DB_DB",
|
||||
"LOCAL_DB_HOST",
|
||||
"LOCAL_DB_PORT",
|
||||
)
|
||||
)
|
||||
|
||||
# url = f"mysql://{un}:{pw}@{host}:{port}/{db}"
|
||||
# url = f"mariadb+pymysql://{un}:{pw}@{host}:{port}/{db}"
|
||||
elif env == "prod":
|
||||
un, pw, db, host, port = (
|
||||
config(x)
|
||||
for x in (
|
||||
"POSTGRES_USER",
|
||||
"POSTGRES_PASSWORD",
|
||||
"POSTGRES_DB",
|
||||
"db",
|
||||
"5432",
|
||||
)
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
url = f"postgresql://{un}:{pw}@{host}:{port}/{db}"
|
||||
return url
|
||||
|
||||
|
||||
def get_val(val: str, default: str | int | bool | None = None, **kwargs):
|
||||
"""
|
||||
A utility function that checks the platform environment or an .env file to
|
||||
pull a key:value pair and return the value. If the value does not exist, a
|
||||
ValueError will be raised
|
||||
Args:
|
||||
val: str
|
||||
default: Union[str, None]: default is None
|
||||
|
||||
Returns val: str
|
||||
|
||||
"""
|
||||
if os.environ.get(val, None) is not None:
|
||||
val = os.environ.get(val)
|
||||
elif config(val, None, **kwargs) is not None:
|
||||
val = config(val)
|
||||
elif default:
|
||||
val = default
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Env Var {val} is not populated in the environment "
|
||||
f"or within the configuration files"
|
||||
)
|
||||
|
||||
return val
|
Loading…
x
Reference in New Issue
Block a user