160 lines
4.5 KiB
Python

from pillar_tool.schemas import HealthCheckError
from pillar_tool.util import load_config, config
load_config()
from pillar_tool.middleware.logging import request_logging_middleware
from starlette.middleware.base import BaseHTTPMiddleware
from pillar_tool.db.database import get_connection
from pillar_tool.db.queries.auth_queries import create_user
from fastapi import FastAPI
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.requests import Request
from fastapi.responses import HTMLResponse, PlainTextResponse
from starlette.responses import JSONResponse
from pillar_tool.middleware.basicauth_backend import BasicAuthBackend
from pillar_tool.middleware.db_connection import db_connection_middleware
from pillar_tool.db.database import run_db_migrations
# import all the routers
from pillar_tool.routers.host import router as host_router
# Bring the database schema up to date before anything touches it.
run_db_migrations()

# Bootstrap: make sure the default "admin" account exists.
# NOTE(review): the default credentials are hard-coded ("admin"/"admin");
# they should be rotated after first start or sourced from configuration.
import logging

db = get_connection()
try:
    # noinspection PyBroadException
    try:
        create_user(db, "admin", "admin")
    except Exception as exc:
        # Best effort: creation is expected to fail when the user already
        # exists. Catch Exception (not a bare except) so KeyboardInterrupt
        # and SystemExit still propagate, and leave a trace of the reason.
        logging.getLogger(__name__).info(
            "Default admin user was not created: %s", exc
        )
    db.commit()
finally:
    # Release the connection even if the commit fails.
    db.close()
def on_auth_error(request: Request, exc: Exception):
    """Build the reply sent when authentication fails.

    The exception's string form becomes the plain-text body, the status is
    401 and a ``WWW-Authenticate: Basic`` header is attached. ``request`` is
    accepted but not used.
    """
    reason = str(exc)
    unauthorized = PlainTextResponse(content=reason, status_code=401)
    unauthorized.headers["WWW-Authenticate"] = "Basic"
    return unauthorized
def on_db_error(request: Request, exc: Exception):
    """Build a plain-text 500 reply whose body is the exception's string form.

    ``request`` is accepted but not used. This handler is not registered
    anywhere in the visible part of this file.
    """
    return PlainTextResponse(content=str(exc), status_code=500)
def on_general_error(request: Request, exc: Exception):
    """Fallback handler for exceptions no other handler dealt with.

    Logs the exception (including its traceback) and answers with a
    plain-text 500 response whose body is the exception's string form.

    Args:
        request: The request that was being processed when ``exc`` was raised.
        exc: The unhandled exception.

    Returns:
        A ``PlainTextResponse`` with status code 500.
    """
    import logging

    # Pass the exception explicitly: this function is not guaranteed to run
    # inside the ``except`` block that caught it, so the implicit
    # "current exception" may not be available.
    logging.getLogger(__name__).error(
        "Unhandled exception while processing %s %s",
        request.method,
        request.url.path,
        exc_info=exc,
    )
    # NOTE(review): echoing str(exc) to the client can leak internal details;
    # kept as-is because callers may rely on the body.
    return PlainTextResponse(str(exc), status_code=500)
# Application object and its wiring.
app = FastAPI()

# Middleware registration. The order of these calls is significant and must
# be kept: authentication first, then the DB connection, then request logging.
# NOTE(review): Starlette is understood to wrap later-added middleware around
# earlier ones, so the last one added sees the request first - confirm.
app.add_middleware(
    AuthenticationMiddleware,
    backend=BasicAuthBackend(),
    on_error=on_auth_error,
)
app.add_middleware(BaseHTTPMiddleware, dispatch=db_connection_middleware)
app.add_middleware(BaseHTTPMiddleware, dispatch=request_logging_middleware)

# Catch-all handler for otherwise unhandled exceptions.
app.add_exception_handler(Exception, on_general_error)

# Mount the API routers.
app.include_router(host_router)
# Landing endpoint: answers GET / with a fixed greeting payload.
@app.get("/")
async def root():
    greeting = {"message": "Hello World"}
    return greeting
@app.get("/health")
async def health():
    """Report whether the service can reach its database.

    Opens a fresh connection, runs ``SELECT 1`` and closes the connection
    again. Returns the response built by ``HealthCheckSuccess`` when that
    works, or the one built by ``HealthCheckError`` (status 500, message
    including the exception text) when any step fails.
    """
    # Imported here because the module-level imports only bring in
    # HealthCheckError.
    # NOTE(review): assumes HealthCheckSuccess is defined next to
    # HealthCheckError in pillar_tool.schemas - confirm.
    from pillar_tool.schemas import HealthCheckSuccess

    try:
        conn = get_connection()
        try:
            # NOTE(review): if get_connection() returns a SQLAlchemy 2.x
            # session/connection, the raw string may need wrapping in
            # sqlalchemy.text() - confirm against the database module.
            conn.execute("SELECT 1")
        finally:
            # Always release the connection, even when the probe query fails.
            conn.close()
    except Exception as e:
        return HealthCheckError(500, f"Database connection error:\n{e}").response()
    return HealthCheckSuccess().response()
"""
@app.get("/pillar/{host}")
async def pillar_get(req: Request, host: str):
print(req.headers)
#return JSONResponse(content=collect_pillar_data(host))
return JSONResponse({})
@app.post("/pillar/{host}")
async def pillar_set(request: Request, host: str, value: str):
return JSONResponse({
"captain.linvogel.internal": {
"states": ["state1", "state2"],
"test": {
"pillar": "value"
}
}
})
# TODO: list, create, update and delete hosts
@app.get("/hosts")
async def host_list(request: Request):
all_hosts = list_all_hosts(request.state.db)
return JSONResponse([x.name for x in all_hosts if not x.is_hostgroup])
# TODO: list, create, update and delete hostgroups
@app.get("/hostgroups")
async def hostgroup_list(request: Request):
all_hosts = list_all_hosts(request.state.db)
return JSONResponse([x.name for x in all_hosts if x.is_hostgroup])
# TODO: list, create, update and delete states
# TODO: list, create, update and delete environments
# TODO: top files generated on a per host basis
@app.get("/top/{fqdn}")
async def host_top(req: Request, fqdn: str):
db: Session = req.state.db
if not validate_fqdn(fqdn):
return JSONResponse(status_code=400, content={
'message': f"Invalid FQDN: {fqdn}"
})
environment_stmt = select(Environment)
result = db.execute(environment_stmt).fetchall()
if len(result) == 0:
return JSONResponse(status_code=400, content={
'message': "There are no environments defined"
})
environments: list[Environment] = list(map(lambda x: x[0], result))
stmt_host = select(Host).where(Host.name == fqdn)
result = db.execute(stmt_host).fetchall()
if len(result) < 1:
return JSONResponse(status_code=404, content={
'message': f"No such Host is known: {fqdn}"
})
# this should be enforced by the database
assert len(result) == 1
host: Host = result[0][0]
stmt_top = select(Environment, Host, State).where(Environment).join
# TODO: implement
return JSONResponse({})
"""