"""pillar_tool application entry point.

On import this module loads the configuration, runs pending database
migrations, makes sure a default ``admin`` user exists and builds the FastAPI
application (middleware, error handlers, routers and a few basic endpoints).
"""
import logging

from sqlalchemy import text

from pillar_tool.schemas import HealthCheckError, HealthCheckSuccess
from pillar_tool.util import load_config, config

# Load the configuration before importing the remaining project modules.
# NOTE(review): presumably some of the modules imported below read the
# configuration at import time, which is why this call sits between imports.
# Keep it here.
load_config()

from pillar_tool.middleware.logging import request_logging_middleware
from starlette.middleware.base import BaseHTTPMiddleware
from pillar_tool.db.database import get_connection
from pillar_tool.db.queries.auth_queries import create_user
from fastapi import FastAPI
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import Request
from fastapi.responses import HTMLResponse, PlainTextResponse
from starlette.responses import JSONResponse
from pillar_tool.middleware.basicauth_backend import BasicAuthBackend
from pillar_tool.middleware.db_connection import db_connection_middleware
from pillar_tool.db.database import run_db_migrations

# import all the routers
from pillar_tool.routers.host import router as host_router
from pillar_tool.routers.hostgroup import router as hostgroup_router
from pillar_tool.routers.environment import router as environment_router
from pillar_tool.routers.state import router as state_router

logger = logging.getLogger(__name__)

# run any pending migrations
run_db_migrations()

# Create the default user if it does not exist yet.
# NOTE(review): this seeds well-known admin/admin credentials; make sure they
# are changed (or the seeding disabled) in real deployments.
db = get_connection()
try:
    try:
        create_user(db, "admin", "admin")
    except Exception as e:
        # Best effort: failure here presumably means the user already exists.
        # Roll back so the failed statement does not poison the transaction,
        # and never commit half-done work.
        db.rollback()
        logger.info("Default user 'admin' was not created: %s", e)
    else:
        db.commit()
finally:
    # Always release the bootstrap connection, even if commit() raises.
    db.close()


def on_auth_error(request: Request, exc: Exception):
    """Turn an authentication failure into a 401 Basic-auth challenge.

    :param request: the request that failed authentication (unused).
    :param exc: the error raised by the authentication backend; its text
        becomes the response body.
    :return: plain-text 401 response carrying ``WWW-Authenticate: Basic``.
    """
    response = PlainTextResponse(str(exc), status_code=401)
    # Prompts clients (e.g. browsers) to retry with Basic credentials.
    response.headers["WWW-Authenticate"] = "Basic"
    return response


def on_db_error(request: Request, exc: Exception):
    """Turn a database error into a plain-text 500 response.

    NOTE(review): this handler is not registered anywhere in this module.

    :param request: the failing request (unused).
    :param exc: the database error; its text becomes the response body.
    :return: plain-text 500 response.
    """
    response = PlainTextResponse(str(exc), status_code=500)
    return response


def on_general_error(request: Request, exc: Exception):
    """Log an unhandled exception and answer with a plain-text 500 response.

    :param request: the request whose handling raised ``exc``.
    :param exc: the unhandled exception; its text becomes the response body.
    :return: plain-text 500 response.
    """
    # Log with the traceback attached instead of printing a debug marker.
    logger.error(
        "Unhandled exception while processing %s %s",
        request.method,
        request.url.path,
        exc_info=exc,
    )
    response = PlainTextResponse(str(exc), status_code=500)
    return response


app = FastAPI()

# Starlette wraps each newly added middleware around the previously added
# ones. The resulting order (outermost first) is therefore:
# request logging -> DB connection -> authentication.
# Keep the registration order unchanged.
app.add_middleware(AuthenticationMiddleware, backend=BasicAuthBackend(), on_error=on_auth_error)
app.add_middleware(BaseHTTPMiddleware, dispatch=db_connection_middleware)
app.add_middleware(BaseHTTPMiddleware, dispatch=request_logging_middleware)
app.add_exception_handler(Exception, on_general_error)

# Setup the api router
app.include_router(host_router)
app.include_router(hostgroup_router)
app.include_router(environment_router)
app.include_router(state_router)


@app.get("/")
async def root():
    """Return a static greeting payload."""
    return {"message": "Hello World"}


@app.get("/health")
async def health():
    """Report service health by running ``SELECT 1`` against the database.

    :return: the ``HealthCheckSuccess`` response when the query succeeds,
        otherwise a ``HealthCheckError`` response with status 500 and the
        error text.
    """
    # NOTE(review): get_connection()/execute() are synchronous calls inside an
    # async endpoint, so they block the event loop while they run.
    try:
        db = get_connection()
        try:
            db.execute(text("SELECT 1"))
        finally:
            # Close the connection even when the probe query fails.
            db.close()
    except Exception as e:
        return HealthCheckError(500, f"Database connection error:\n{e}").response()
    return HealthCheckSuccess().response()


# NOTE: the string below is disabled, unfinished legacy endpoint code kept for
# reference only. It is a bare string expression and is never executed.
"""
@app.get("/pillar/{host}")
async def pillar_get(req: Request, host: str):
    print(req.headers)
    #return JSONResponse(content=collect_pillar_data(host))
    return JSONResponse({})


@app.post("/pillar/{host}")
async def pillar_set(request: Request, host: str, value: str):
    return JSONResponse({
        "captain.linvogel.internal": {
            "states": ["state1", "state2"],
            "test": {
                "pillar": "value"
            }
        }
    })


# TODO: list, create update and delete hosts
@app.get("/hosts")
async def host_list(request: Request):
    all_hosts = list_all_hosts(request.state.db)
    return JSONResponse([x.name for x in all_hosts if not x.is_hostgroup])


# TODO: list, create, update and delete hostgroups
@app.get("/hostgroups")
async def hostgroup_list(request: Request):
    all_hosts = list_all_hosts(request.state.db)
    return JSONResponse([x.name for x in all_hosts if x.is_hostgroup])


# TODO: list, create, update and delete states
# TODO: list, create, update and delete environments
# TODO: top files generated on a per host basis
@app.get("/top/{fqdn}")
async def host_top(req: Request, fqdn: str):
    db: Session = req.state.db

    if not validate_fqdn(fqdn):
        return JSONResponse(status_code=400, content={
            'message': f"Invalid FQDN: {fqdn}"
        })

    environment_stmt = select(Environment)
    result = db.execute(environment_stmt).fetchall()
    if len(result) == 0:
        return JSONResponse(status_code=400, content={
            'message': "There are no environments defined"
        })
    environments: list[Environment] = list(map(lambda x: x[0], result))

    stmt_host = select(Host).where(Host.name == fqdn)
    result = db.execute(stmt_host).fetchall()

    if len(result) < 1:
        return JSONResponse(status_code=404, content={
            'message': f"No such Host is known: {fqdn}"
        })

    # this should be enforced by the database
    assert len(result) == 1

    host: Host = result[0][0]

    stmt_top = select(Environment, Host, State).where(Environment).join

    # TODO: implement

    return JSONResponse({})
"""