From 37fa6bcbb388d90d0b1038dcb32c0ec667f2fde8 Mon Sep 17 00:00:00 2001 From: Linus Vogel Date: Mon, 16 Feb 2026 22:43:49 +0100 Subject: [PATCH] worked on pillar endpoints --- pillar_tool/db/queries/pillar_queries.py | 79 +++++++++--------------- pillar_tool/main.py | 4 +- pillar_tool/ptcli/cli/hostgroup.py | 8 +-- pillar_tool/ptcli/cli/pillar.py | 18 +++++- pillar_tool/routers/pillar.py | 44 ++++++++++--- 5 files changed, 87 insertions(+), 66 deletions(-) diff --git a/pillar_tool/db/queries/pillar_queries.py b/pillar_tool/db/queries/pillar_queries.py index 9ab218c..d840416 100644 --- a/pillar_tool/db/queries/pillar_queries.py +++ b/pillar_tool/db/queries/pillar_queries.py @@ -1,7 +1,9 @@ -from typing import Any +import json +from collections import defaultdict from pillar_tool.db.models.pillar_data import * +from uuid import UUID from sqlalchemy import select, insert, union from sqlalchemy.orm import Session @@ -9,57 +11,32 @@ from sqlalchemy.orm import Session def get_pillar_name_sequence(name: str) -> list[str]: return name.split(':') - -def generate_host_hierarchy(db: Session, labels: list[str]) -> list[Host]: - path_consumed = [] - out = [] - last_parent_id = None - for label in labels: - path_consumed += label - stmt = select(Host).where(Host.name == label and Host.parent_id == last_parent_id) - result = list(db.execute(stmt).fetchall()) - if not result: - raise RuntimeError(f"No such host(-group): '{':'.join(path_consumed)}'") - # NOTE: this is an assertion because the schema should enforce this - assert len(result) == 1 - instance = Host(result[0]) - last_parent_id = instance.id - out.append(instance) - - return out +def decode_pillar_value(pillar: Pillar) -> str | int | float | bool | list | dict: + match pillar.type: + case 'string': return pillar.value + case 'integer': return int(pillar.value) + case 'float': return float(pillar.value) + case 'boolean': return bool(pillar.value) + case 'array': return json.loads(pillar.value) + case 'dict': return json.loads(pillar.value) + raise RuntimeError(f"Failed to decode pillar value: Invalid type '{pillar.type}'") -def get_values_for_host(db: Session, host: str) -> dict: - labels = get_pillar_name_sequence(host) - hierarchy = generate_host_hierarchy(db, labels) +def get_pillar_for_target(db: Session, target: UUID) -> dict: + pillar_stmt = select(Pillar).where(Pillar.host_id == target) + result = db.execute(pillar_stmt).fetchall() - # TODO: generate host hierarchy - # TODO: find all values assigned o this host hierarchy and sort by depth - # TODO: build the pillar structure - - return {} - - -def create_pillar_host(db: Session, host_id: UUID, name: str, value: Any) -> None: - # TODO: generate host hierarchy - # get the involved host or hostgroup - res = db.execute(select(Host).where(Host.id == host_id)).fetchone() - if res is None: - # TODO: handle this error with a custom Exception - raise RuntimeError(f"No Host or Hostgroup with id {host_id} exists!") - host = res[0][0] - - - - # TODO: generate pillar path from name - - # TODO: find if pillar already exists - # TODO: create new pillar if it doesn't exist - # TODO: assign value to new or existing pillar - return - - - -def create_pillar_host_group(db: Session, host_group: UUID, name: str, value: Any) -> None: - pass + out = {} + for row in result: + row: Pillar = row[0] + name = row.pillar_name + value = decode_pillar_value(row) + labels = get_pillar_name_sequence(name) + current = out + l = len(labels) + for i, label in enumerate(labels): + if label not in current: + current[label] = {} if i < l-1 else value + print(json.dumps(out, indent=4)) + pass \ No newline at end of file diff --git a/pillar_tool/main.py b/pillar_tool/main.py index acbbd31..6a8dd78 100644 --- a/pillar_tool/main.py +++ b/pillar_tool/main.py @@ -27,6 +27,7 @@ from pillar_tool.routers.host import router as host_router from pillar_tool.routers.hostgroup import router as hostgroup_router from pillar_tool.routers.environment import router as environment_router from pillar_tool.routers.state import router as state_router +from pillar_tool.routers.pillar import router as pillar_router # run any pending migrations run_db_migrations() @@ -69,11 +70,12 @@ app.add_middleware(BaseHTTPMiddleware, dispatch=db_connection_middleware) app.add_middleware(BaseHTTPMiddleware, dispatch=request_logging_middleware) app.exception_handler(Exception)(on_general_error) -# Setup the api router +# Set up the api router app.include_router(host_router) app.include_router(hostgroup_router) app.include_router(environment_router) app.include_router(state_router) +app.include_router(pillar_router) @app.get("/") async def root(): diff --git a/pillar_tool/ptcli/cli/hostgroup.py b/pillar_tool/ptcli/cli/hostgroup.py index 8fe836e..ae25a42 100644 --- a/pillar_tool/ptcli/cli/hostgroup.py +++ b/pillar_tool/ptcli/cli/hostgroup.py @@ -15,7 +15,7 @@ def hostgroup(): def hostgroup_list(): click.echo("Listing known hostgroups...") try: - response = requests.get(f'{base_url}/hostgroup', headers=auth_header()) + response = requests.get(f'{base_url()}/hostgroup', headers=auth_header()) response.raise_for_status() click.echo("Hostgroups:") @@ -36,7 +36,7 @@ def hostgroup_show(path: str): data = HostgroupParams( path=path ) - response = requests.get(f'{base_url}/hostgroup/{name}', headers=auth_header(), params=data.model_dump()) + response = requests.get(f'{base_url()}/hostgroup/{name}', headers=auth_header(), params=data.model_dump()) response.raise_for_status() except requests.exceptions.HTTPError as e: raise click.ClickException(f"Failed to show hostgroup:\n{e}") @@ -53,7 +53,7 @@ def hostgroup_create(path: str): data = HostgroupParams( path=path ) - response = requests.post(f'{base_url}/hostgroup/{name}', headers=auth_header(), json=data.model_dump()) + response = requests.post(f'{base_url()}/hostgroup/{name}', headers=auth_header(), json=data.model_dump()) response.raise_for_status() except requests.exceptions.HTTPError as e: raise click.ClickException(f"Failed to create hostgroup:\n{e}") @@ -68,7 +68,7 @@ def hostgroup_delete(path: str): name = labels[-1] prefix = "/".join(labels[:-1]) if len(labels) > 1 else None query_params = f"?path={prefix}" if prefix is not None else '' - response = requests.delete(f'{base_url}/hostgroup/{name}{query_params}', headers=auth_header()) + response = requests.delete(f'{base_url()}/hostgroup/{name}{query_params}', headers=auth_header()) response.raise_for_status() except requests.exceptions.HTTPError as e: raise click.ClickException(f"Failed to delete hostgroup:\n{e}") diff --git a/pillar_tool/ptcli/cli/pillar.py b/pillar_tool/ptcli/cli/pillar.py index dffe677..f019b93 100644 --- a/pillar_tool/ptcli/cli/pillar.py +++ b/pillar_tool/ptcli/cli/pillar.py @@ -6,4 +6,20 @@ from .cli_main import main, auth_header, base_url @main.group("pillar") def pillar(): - pass \ No newline at end of file + pass + + +@pillar.command("get") +@click.argument("fqdn") +def pillar_get(fqdn): + """Get pillar data for a given FQDN.""" + try: + response = requests.get( + f"{base_url()}/pillar/{fqdn}", + headers=auth_header(), + ) + response.raise_for_status() + pillar_data = response.json() + click.echo(pillar_data) + except requests.exceptions.RequestException as e: + click.echo(f"Error: {e}") \ No newline at end of file diff --git a/pillar_tool/routers/pillar.py b/pillar_tool/routers/pillar.py index 4afebc3..ecd22e7 100644 --- a/pillar_tool/routers/pillar.py +++ b/pillar_tool/routers/pillar.py @@ -1,15 +1,17 @@ import uuid -from fastapi.params import Depends -from sqlalchemy import select, insert, delete +from sqlalchemy import select, insert, delete, bindparam from sqlalchemy.orm import Session from starlette.exceptions import HTTPException from starlette.requests import Request -from fastapi import APIRouter +from fastapi import APIRouter, Depends from starlette.responses import JSONResponse +from pillar_tool.db import Host from pillar_tool.db.models.top_data import State, StateAssignment +from pillar_tool.db.queries.pillar_queries import get_pillar_for_target from pillar_tool.schemas import PillarParams, get_model_from_query +from pillar_tool.util.pillar_utilities import merge from pillar_tool.util.validation import validate_state_name router = APIRouter( @@ -19,8 +21,8 @@ router = APIRouter( # Note: there is no list of all pillars, as this would not be helpful -@router.get("/{name}") -def state_get(req: Request, name: str, params: Depends(get_model_from_query(PillarParams))): +@router.get("/{fqdn}") +def pillar_get(req: Request, fqdn: str): # TODO: implement # this function should: # - get the affected host hierarchy @@ -30,14 +32,38 @@ def state_get(req: Request, name: str, params: Depends(get_model_from_query(Pill # if any error happens, return non-200 status and an empty dictionary so that salt does not shit itself db: Session = req.state.db + # get the host hierarchy + host_stmt = select(Host).where(Host.name == fqdn and Host.is_hostgroup == False) + result = db.execute(host_stmt).fetchall() + if len(result) == 0: + return JSONResponse(status=404, content={}) + # NOTE: should be enforced by the database + assert len(result) == 1 -@router.post("/{name}") -def state_create(req: Request, name: str): + host: Host = result[0][0] + path: list[Host] = [host] + parent_stmt = select(Host).where(Host.id == bindparam('parent')) + while path[-1].parent_id is not None: + result = db.execute(parent_stmt, {'parent': path[-1].parent_id}).fetchall() + # NOTE: should be enforced by the database + assert len(result) == 1 + tmp: Host = result[0][0] + path.append(tmp) + + path.reverse() + out = merge(get_pillar_for_target(db, host.id) for host in path) + + return JSONResponse(status_code=200, content={}) + + + +@router.post("/{fqdn}") +def pillar_create(req: Request, fqdn: str, params: PillarParams): # TODO: implement db = req.state.db -@router.delete("/{name}") -def state_delete(req: Request, name: str): +@router.delete("/{fqdn}") +def pillar_delete(req: Request, fqdn: str): # TODO: implement db = req.state.db