191 lines
7.9 KiB
Python

import json
import logging
import uuid
from uuid import uuid4

from fastapi import APIRouter, Depends
from sqlalchemy import select, delete, bindparam, and_, or_
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse

from pillar_tool.db import Host, Pillar
from pillar_tool.db.models.top_data import State, StateAssignment
from pillar_tool.db.queries.pillar_queries import get_pillar_for_target
from pillar_tool.schemas import PillarParams, get_model_from_query
from pillar_tool.util.pillar_utilities import merge
from pillar_tool.util.validation import validate_state_name, validate_fqdn, validate_pillar_input_data, \
    split_and_validate_path
# Router for the pillar endpoints: every route below is mounted under the
# "/pillar" prefix and carries the "pillar" tag.
router = APIRouter(prefix="/pillar", tags=["pillar"])
# Note: there is no list of all pillars, as this would not be helpful
@router.get("/{fqdn}")
def pillar_get(req: Request, fqdn: str):
    """Return the merged pillar data for the host named *fqdn*.

    The host's ancestry is followed through ``parent_id`` up to a node without
    a parent. The pillar dictionary of every node on that path is fetched with
    ``get_pillar_for_target``. The resulting list (topmost ancestor first, the
    host itself last) is passed to ``merge``, and the merged result is returned.

    On failure a non-200 status with an empty JSON object is returned so that
    salt still receives a well-formed (empty) pillar:

    * 404 if no non-hostgroup host named *fqdn* exists;
    * 500 if the ``parent_id`` chain contains a cycle.
    """
    db: Session = req.state.db

    # look up the host itself; hostgroup rows are excluded
    host_stmt = select(Host).where(and_(Host.name == fqdn, Host.is_hostgroup == False))
    result = db.execute(host_stmt).fetchall()
    if len(result) == 0:
        # JSONResponse takes `status_code`; the former `status=` raised a TypeError
        return JSONResponse(status_code=404, content={})
    # NOTE: should be enforced by the database
    assert len(result) == 1
    host: Host = result[0][0]

    # walk up the hierarchy: path = [host, parent, grandparent, ...]
    path: list[Host] = [host]
    seen = {host.id}
    parent_stmt = select(Host).where(Host.id == bindparam('parent'))
    while path[-1].parent_id is not None:
        parent_id = path[-1].parent_id
        if parent_id in seen:
            # a cycle in parent_id would otherwise make this loop run forever
            return JSONResponse(status_code=500, content={})
        seen.add(parent_id)
        result = db.execute(parent_stmt, {'parent': parent_id}).fetchall()
        # NOTE: should be enforced by the database
        assert len(result) == 1
        path.append(result[0][0])

    # topmost ancestor first, host last -- presumably merge() lets later
    # (more specific) entries override earlier ones; confirm in pillar_utilities
    path.reverse()
    out = merge([get_pillar_for_target(db, node.id) for node in path])
    return JSONResponse(status_code=200, content=out)
def _flatten_pillar_dict(prefix: str, data: dict) -> list[dict[str, str]]:
    """Flatten a nested pillar dictionary into one row per leaf value.

    Nested keys are appended to *prefix* separated by ``:``. Each non-dict leaf
    becomes ``{'name': ..., 'type': ..., 'value': ...}``, where ``type`` is the
    Python type name of the leaf and ``value`` is its ``json.dumps`` encoding.
    An empty nested dictionary produces no rows.
    """
    rows: list[dict[str, str]] = []
    for key, value in data.items():
        full_name = f"{prefix}:{key}"
        if isinstance(value, dict):
            rows.extend(_flatten_pillar_dict(full_name, value))
        else:
            rows.append({
                'name': full_name,
                'type': type(value).__name__,
                'value': json.dumps(value),
            })
    return rows


@router.post("/{name}")
def pillar_create(req: Request, name: str, params: PillarParams):
    """Create or update the pillar *name* on a host or a hostgroup.

    The target is ``params.host`` (a host, never a hostgroup) if set, otherwise
    ``params.hostgroup``, a path resolved level by level starting from a
    hostgroup without a parent. ``params.value`` is validated against
    ``params.type`` with ``validate_pillar_input_data``.

    If the validated data is a dictionary, every leaf is stored as its own
    ``name:key:...`` pillar. Otherwise a single pillar ``name`` is stored with
    the raw ``params.value``. Rows are upserted on the
    ``pillar_unique_pillar_name`` constraint.

    Returns 200 on success. Returns 400 if type/value or the target is missing.
    Returns 404 if the target host or hostgroup does not exist.
    """
    db: Session = req.state.db

    # ensure that value and type have been set in the request parameters
    if params.type is None or params.value is None:
        return JSONResponse(status_code=400, content={
            'message': "Both parameter type and value need to be set!"
        })

    # validate pillar data
    pillar_data = validate_pillar_input_data(params.value, params.type)

    if params.host is not None:
        # exclude hostgroups: otherwise a hostgroup sharing the host's name would be
        # targeted, or a host/hostgroup name clash would trip the assertion below
        target_stmt = select(Host).where(and_(Host.name == params.host, Host.is_hostgroup == False))
        result = db.execute(target_stmt).fetchall()
        if len(result) == 0:
            return JSONResponse(status_code=404, content={})
        # this should be enforced by the database
        assert len(result) == 1
        target: Host = result[0][0]
    elif params.hostgroup is not None:
        path = split_and_validate_path(params.hostgroup)
        if not path:
            return JSONResponse(status_code=400, content={'message': "No target specified"})
        # Note: both statements need to be present, since '==' will not work for None and 'is' will not work for a UUID
        group_stmt = select(Host).where(and_(Host.is_hostgroup == True, Host.parent_id == bindparam('parent'), Host.name == bindparam('name')))
        group_stmt_none = select(Host).where(and_(Host.is_hostgroup == True, Host.parent_id.is_(None), Host.name == bindparam('name')))
        # resolve one path label per level; `last` is the id of the previous level
        last = None
        for label in path:
            stmt = group_stmt if last is not None else group_stmt_none
            result = db.execute(stmt, {'name': label, 'parent': last}).fetchall()
            if len(result) == 0:
                return JSONResponse(status_code=404, content={'message': f"No hostgroup named: {params.hostgroup}"})
            # Note: this should be enforced by the database
            assert len(result) == 1, f"Result: {[x[0].name for x in result]}"
            current: Host = result[0][0]
            last = current.id
        # `path` is non-empty (checked above), so `current` is always bound here
        target: Host = current
    else:
        return JSONResponse(status_code=400, content={'message': "Neither host nor hostgroup set"})

    if isinstance(pillar_data, dict):
        # store a separate entry for every leaf of the dictionary
        pillars_to_store = _flatten_pillar_dict(name, pillar_data)
    else:
        # build the pillar package
        pillars_to_store = [
            {'name': name, 'value': params.value, 'type': params.type}
        ]

    # log names only: pillar values frequently contain secrets
    logging.getLogger(__name__).debug(
        "Pillars to store: %s", [row['name'] for row in pillars_to_store]
    )

    # store pillar data (insert, or overwrite type/value on name conflict)
    insert_stmt = insert(Pillar).values(id=bindparam('new_id'), host_id=target.id, pillar_name=bindparam('name'), parameter_type=bindparam('type'), value=bindparam('value'))
    upsert_stmt = insert_stmt.on_conflict_do_update(constraint='pillar_unique_pillar_name', set_={'parameter_type': bindparam('type'), 'value': bindparam('value')})
    for instance in pillars_to_store:
        instance['new_id'] = uuid4()
        db.execute(upsert_stmt, instance)
    return JSONResponse(status_code=200, content={'message': 'ok'})
@router.delete("/{name}")
def pillar_delete(req: Request, name: str, params: PillarParams):
    """Delete the pillar *name* and all of its ``name:...`` sub-pillars from a target.

    The target is ``params.host`` (a host, never a hostgroup) if set, otherwise
    ``params.hostgroup``, a path resolved level by level starting from a
    hostgroup without a parent.

    Returns 200 ``{'message': 'ok'}`` whether or not any row matched. Returns
    400 if no target is given. Returns 404 if the target does not exist.
    """
    db: Session = req.state.db

    if params.host is not None:
        # delete a pillar at the host level
        target_stmt = select(Host).where(and_(Host.name == params.host, Host.is_hostgroup == False))
        result = db.execute(target_stmt).fetchall()
        if len(result) == 0:
            return JSONResponse(status_code=404, content={})
        # this should be enforced by the database
        assert len(result) == 1
        target: Host = result[0][0]
    elif params.hostgroup is not None:
        # delete a pillar at the hostgroup level
        path = split_and_validate_path(params.hostgroup)
        if not path:
            return JSONResponse(status_code=400, content={'message': "No target specified"})
        # Note: both statements need to be present, since '==' will not work for None and 'is' will not work for a UUID
        group_stmt = select(Host).where(and_(Host.is_hostgroup == True, Host.parent_id == bindparam('parent'), Host.name == bindparam('name')))
        group_stmt_none = select(Host).where(and_(Host.is_hostgroup == True, Host.parent_id.is_(None), Host.name == bindparam('name')))
        # resolve one path label per level; `last` is the id of the previous level
        last = None
        for label in path:
            stmt = group_stmt if last is not None else group_stmt_none
            result = db.execute(stmt, {'name': label, 'parent': last}).fetchall()
            if len(result) == 0:
                return JSONResponse(status_code=404, content={'message': f"No hostgroup named: {params.hostgroup}"})
            # Note: this should be enforced by the database
            assert len(result) == 1, f"Result: {[x[0].name for x in result]}"
            current: Host = result[0][0]
            last = current.id
        # `path` is non-empty (checked above), so `current` is always bound here
        target: Host = current
    else:
        return JSONResponse(status_code=400, content={
            'message': "Either Host or Hostgroup needs to be set!"
        })

    # Match the pillar itself plus every "name:..." sub-pillar. autoescape=True
    # escapes '%' and '_' in `name`, so "my_app" no longer also matches
    # "myXapp:..." as the previous unescaped LIKE pattern did.
    delete_stmt = delete(Pillar).where(and_(
        Pillar.host_id == target.id,
        or_(
            Pillar.pillar_name == name,
            Pillar.pillar_name.startswith(f"{name}:", autoescape=True),
        ),
    ))
    db.execute(delete_stmt)
    return JSONResponse(status_code=200, content={'message': 'ok'})