96 lines
2.8 KiB
Python
96 lines
2.8 KiB
Python
from sqlalchemy import text
|
|
|
|
from pillar_tool.schemas import HealthCheckError, HealthCheckSuccess
|
|
from pillar_tool.util import load_config, config
|
|
# Load the application configuration before the remaining pillar_tool imports
# below. NOTE(review): the call is deliberately placed between imports,
# presumably because later modules read `config` at import time — confirm
# before moving it.
load_config()
|
|
|
|
from pillar_tool.middleware.logging import request_logging_middleware
|
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
from pillar_tool.db.database import get_connection
|
|
|
|
from pillar_tool.db.queries.auth_queries import create_user
|
|
|
|
from fastapi import FastAPI
|
|
from starlette.middleware.authentication import AuthenticationMiddleware
|
|
from starlette.requests import Request
|
|
from fastapi.responses import HTMLResponse, PlainTextResponse
|
|
from starlette.responses import JSONResponse
|
|
|
|
from pillar_tool.middleware.basicauth_backend import BasicAuthBackend
|
|
from pillar_tool.middleware.db_connection import db_connection_middleware
|
|
from pillar_tool.db.database import run_db_migrations
|
|
|
|
# import all the routers
|
|
from pillar_tool.routers.host import router as host_router
|
|
from pillar_tool.routers.hostgroup import router as hostgroup_router
|
|
from pillar_tool.routers.environment import router as environment_router
|
|
from pillar_tool.routers.state import router as state_router
|
|
from pillar_tool.routers.pillar import router as pillar_router
|
|
|
|
# run any pending migrations
run_db_migrations()

# Bootstrap a default "admin" user on first start.
# NOTE(review): the default credentials are admin/admin; make sure they are
# changed after deployment (or make them configurable).
# Assumes get_connection() returns a SQLAlchemy Session/Connection-like object
# exposing rollback() next to the commit()/close() used here — confirm.
db = get_connection()
try:
    # noinspection PyBroadException
    try:
        create_user(db, "admin", "admin")
    except Exception:
        # Best effort: a failure here is treated as "the user already exists".
        # The exact exception create_user raises is not visible from this
        # module (presumably a unique-constraint violation — confirm). Roll back
        # so the failed statement does not leave the transaction unusable for
        # the commit below. Unlike a bare `except:`, this no longer swallows
        # KeyboardInterrupt/SystemExit.
        db.rollback()
    # commit whatever the bootstrap step left pending
    db.commit()
finally:
    # release the bootstrap connection even if the commit fails
    db.close()
|
|
|
|
|
|
def on_auth_error(request: Request, exc: Exception):
    """Turn an authentication failure into a 401 plain-text response.

    Passed as ``on_error`` to ``AuthenticationMiddleware``. The response body is
    the exception text; ``request`` is part of the callback signature but unused.

    Returns:
        A ``PlainTextResponse`` with status 401 and a ``WWW-Authenticate: Basic``
        header.
    """
    unauthorized = PlainTextResponse(str(exc), status_code=401)
    # Advertise HTTP Basic auth so clients know to (re)send credentials.
    unauthorized.headers["WWW-Authenticate"] = "Basic"
    return unauthorized
|
|
|
|
def on_db_error(request: Request, exc: Exception):
    """Render an error (by its name, intended for database errors) as a 500.

    The body is ``str(exc)``; ``request`` is accepted to match the handler
    signature used by the other error callbacks but is not used.

    NOTE(review): nothing in this module registers this handler — confirm
    whether it is wired up elsewhere or is dead code.
    """
    return PlainTextResponse(str(exc), status_code=500)
|
|
|
|
def on_general_error(request: Request, exc: Exception):
    """Fallback handler for exceptions that nothing else handled.

    Registered on the app as the handler for ``Exception``. Logs a one-line
    summary of the failing request and answers with a 500 plain-text response.

    Args:
        request: the request whose processing raised.
        exc: the unhandled exception.

    Returns:
        A ``PlainTextResponse`` with status 500 whose body is ``str(exc)``.
    """
    import logging

    # One-line summary only: Starlette's ServerErrorMiddleware, which runs
    # handlers registered for ``Exception``, re-raises the exception after the
    # response is sent, so the server logs the full traceback itself.
    logging.getLogger(__name__).error(
        "Unhandled %s while handling %s %s: %s",
        type(exc).__name__,
        request.method,
        request.url.path,
        exc,
    )
    # NOTE(review): the raw exception text is returned to the client, which may
    # expose internal details (e.g. database errors) — consider a generic body.
    response = PlainTextResponse(str(exc), status_code=500)
    return response
|
|
|
|
app = FastAPI()

# Middleware. Starlette wraps each newly added middleware around the ones added
# before it, so request_logging_middleware (added last) is the outermost layer
# and AuthenticationMiddleware (added first) sits closest to the routes.
app.add_middleware(AuthenticationMiddleware, backend=BasicAuthBackend(), on_error=on_auth_error)
app.add_middleware(BaseHTTPMiddleware, dispatch=db_connection_middleware)
app.add_middleware(BaseHTTPMiddleware, dispatch=request_logging_middleware)

# Catch-all handler for exceptions no other handler dealt with.
app.add_exception_handler(Exception, on_general_error)

# Set up the api routers; inclusion order is preserved since it decides
# precedence should two routers declare overlapping paths.
for _api_router in (host_router, hostgroup_router, environment_router, state_router, pillar_router):
    app.include_router(_api_router)
del _api_router  # keep the loop variable out of the module namespace
|
|
|
|
@app.get("/")
async def root():
    """Answer ``GET /`` with a fixed JSON greeting."""
    payload = {"message": "Hello World"}
    return payload
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Report service health by probing the database.

    Opens a fresh connection via ``get_connection()`` and runs ``SELECT 1``.

    Returns:
        ``HealthCheckSuccess().response()`` when the probe succeeds, otherwise
        ``HealthCheckError(500, ...).response()`` carrying the error text.
    """
    # NOTE(review): get_connection()/execute() are synchronous calls inside an
    # async endpoint, so a slow or unreachable database blocks the event loop
    # for as long as the probe takes — consider a sync `def` endpoint or a
    # threadpool offload.
    try:
        db = get_connection()
        try:
            db.execute(text("SELECT 1"))
        finally:
            # Release the connection even when the probe query fails, so a
            # failing health check does not leak a connection per request.
            db.close()
    except Exception as e:
        return HealthCheckError(500, f"Database connection error:\n{e}").response()

    return HealthCheckSuccess().response()
|