import json
import uuid
from uuid import uuid4

from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy import select, delete, bindparam, and_
from sqlalchemy.orm import Session
from starlette.exceptions import HTTPException
from starlette.requests import Request
from fastapi import APIRouter, Depends
from starlette.responses import JSONResponse

from pillar_tool.db import Host, Pillar
from pillar_tool.db.models.top_data import State, StateAssignment
from pillar_tool.db.queries.pillar_queries import get_pillar_for_target
from pillar_tool.schemas import PillarParams, get_model_from_query
from pillar_tool.util.pillar_utilities import merge
from pillar_tool.util.validation import validate_state_name, validate_fqdn, validate_pillar_input_data, \
    split_and_validate_path

router = APIRouter(
    prefix="/pillar",
    tags=["pillar"],
)


# Note: there is no list of all pillars, as this would not be helpful

@router.get("/{fqdn}")
def pillar_get(req: Request, fqdn: str):
    """Return the merged pillar dictionary for the host named *fqdn*.

    Steps:
      - look up the host (hostgroups are excluded),
      - walk its ``parent_id`` chain up to the root,
      - fetch the pillar data of every node on that path,
      - merge them, passing the root first and the host itself last.

    Responds with 404 and an empty dictionary when no such host exists, so
    that the consumer always receives a JSON object it can parse.
    """
    db: Session = req.state.db

    # get the host hierarchy
    host_stmt = select(Host).where(and_(Host.name == fqdn, Host.is_hostgroup == False))  # noqa: E712
    result = db.execute(host_stmt).fetchall()

    if not result:
        # JSONResponse takes ``status_code``; the former ``status=`` keyword
        # raised a TypeError and turned every miss into a server error.
        return JSONResponse(status_code=404, content={})

    # NOTE: should be enforced by the database
    assert len(result) == 1

    host: Host = result[0][0]
    path: list[Host] = [host]

    # climb from the host towards the root, one parent per query
    parent_stmt = select(Host).where(Host.id == bindparam('parent'))
    while path[-1].parent_id is not None:
        result = db.execute(parent_stmt, {'parent': path[-1].parent_id}).fetchall()

        # NOTE: should be enforced by the database
        assert len(result) == 1

        path.append(result[0][0])

    # root first, host last
    path.reverse()

    out = merge([get_pillar_for_target(db, node.id) for node in path])

    return JSONResponse(status_code=200, content=out)


def _flatten_pillar(prefix: str, input_value: dict) -> list[dict[str, str]]:
    """Flatten a nested dict into one pillar row per leaf, with ':'-joined names."""
    rows = []
    for key, value in input_value.items():
        full_name = f"{prefix}:{key}"
        if isinstance(value, dict):
            rows.extend(_flatten_pillar(full_name, value))
        else:
            rows.append({
                'name': full_name,
                'type': type(value).__name__,
                'value': json.dumps(value)
            })
    return rows


@router.post("/{name}")
def pillar_create(req: Request, name: str, params: PillarParams):
    """Create or update the pillar *name* on a host or a hostgroup.

    The target is ``params.host`` when set, otherwise the hostgroup path in
    ``params.hostgroup``. Dictionary values are flattened into one entry per
    leaf (names joined with ':'); other values are stored as a single entry.
    Rows are upserted on the ``pillar_unique_pillar_name`` constraint.

    Responses:
      - 400 when type or value is missing, the hostgroup path is empty, or
        neither host nor hostgroup is given,
      - 404 when the host or a hostgroup on the path does not exist,
      - 200 with ``{'message': 'ok'}`` on success.
    """
    db: Session = req.state.db

    # ensure that value and type have been set in the request parameters
    if params.type is None or params.value is None:
        return JSONResponse(status_code=400, content={
            'message': "Both parameter type and value need to be set!"
        })

    # validate pillar data
    pillar_data = validate_pillar_input_data(params.value, params.type)

    if params.host is not None:
        target_stmt = select(Host).where(Host.name == params.host)
        result = db.execute(target_stmt).fetchall()

        if not result:
            return JSONResponse(status_code=404, content={})

        # this should be enforced by the database
        assert len(result) == 1

        target: Host = result[0][0]
    elif params.hostgroup is not None:
        path = split_and_validate_path(params.hostgroup)
        if not path:
            return JSONResponse(status_code=400, content={'message': "No target specified"})

        last = None
        current = None

        # Note: both statements need to be present, since '==' will not work for None
        # and 'is' will not work for a UUID
        group_stmt = select(Host).where(and_(Host.is_hostgroup == True,  # noqa: E712
                                             Host.parent_id == bindparam('parent'),
                                             Host.name == bindparam('name')))
        group_stmt_none = select(Host).where(and_(Host.is_hostgroup == True,  # noqa: E712
                                                  Host.parent_id.is_(None),
                                                  Host.name == bindparam('name')))

        # resolve the path label by label, starting at the root level
        for label in path:
            result = db.execute(group_stmt if last is not None else group_stmt_none,
                                {'name': label, 'parent': last}).fetchall()
            if not result:
                return JSONResponse(status_code=404,
                                    content={'message': f"No hostgroup named: {params.hostgroup}"})

            # Note: this should be enforced by the database
            assert len(result) == 1, f"Result: {[x[0].name for x in result]}"

            current: Host = result[0][0]
            last = current.id

        target: Host = current
    else:
        return JSONResponse(status_code=400, content={'message': "Neither host nor hostgroup set"})

    # if this is a dictionary value, parse it and create a separate entry for all the sub-pillars
    if isinstance(pillar_data, dict):
        pillars_to_store = _flatten_pillar(name, pillar_data)
    else:
        # build the pillar package
        # NOTE(review): this stores the raw request value/type, whereas the
        # dictionary branch stores json.dumps() of the validated leaves -
        # confirm the asymmetry is intended.
        pillars_to_store = [
            {
                'name': name,
                'value': params.value,
                'type': params.type
            }
        ]

    print(f"Pillars to store: {pillars_to_store}")

    # store pillar data
    insert_stmt = insert(Pillar).values(id=bindparam('new_id'),
                                        host_id=target.id,
                                        pillar_name=bindparam('name'),
                                        parameter_type=bindparam('type'),
                                        value=bindparam('value'))
    upsert_stmt = insert_stmt.on_conflict_do_update(
        constraint='pillar_unique_pillar_name',
        set_={'parameter_type': bindparam('type'), 'value': bindparam('value')}
    )

    for instance in pillars_to_store:
        # a fresh id per row; only used when the row is actually inserted
        instance['new_id'] = uuid4()
        db.execute(upsert_stmt, instance)

    return JSONResponse(status_code=200, content={'message': 'ok'})


@router.delete("/{name}")
def pillar_delete(req: Request, name: str, params: PillarParams):
    # TODO: implement
    db = req.state.db