diff --git a/pillar_tool/db/queries/pillar_queries.py b/pillar_tool/db/queries/pillar_queries.py index 5e9e184..7ddb7e9 100644 --- a/pillar_tool/db/queries/pillar_queries.py +++ b/pillar_tool/db/queries/pillar_queries.py @@ -13,7 +13,7 @@ def get_pillar_name_sequence(name: str) -> list[str]: def decode_pillar_value(pillar: Pillar) -> str | int | float | bool | list | dict: match pillar.parameter_type: - case 'string': return pillar.value + case 'string': return str(pillar.value) case 'integer': return int(pillar.value) case 'float': return float(pillar.value) case 'boolean': return bool(pillar.value) @@ -32,10 +32,17 @@ def get_pillar_for_target(db: Session, target: UUID) -> dict: name = row.pillar_name value = decode_pillar_value(row) labels = get_pillar_name_sequence(name) + print(f"Labels: {labels}, Value: {value}") current = out l = len(labels) for i, label in enumerate(labels): - if label not in current: - current[label] = {} if i < l-1 else value + if i < l-1: + if label not in current: + current[label] = {} + elif type(current[label]) is not dict: + current[label] = {} + current = current[label] + else: + current[label] = json.loads(str(value)) return out diff --git a/pillar_tool/ptcli/cli/__init__.py b/pillar_tool/ptcli/cli/__init__.py index 674f2c5..e7eaed0 100644 --- a/pillar_tool/ptcli/cli/__init__.py +++ b/pillar_tool/ptcli/cli/__init__.py @@ -1,6 +1,5 @@ from .host import host from .hostgroup import hostgroup -from .query import query from .state import state from .pillar import pillar from .environment import environment diff --git a/pillar_tool/ptcli/cli/hostgroup.py b/pillar_tool/ptcli/cli/hostgroup.py index ae25a42..11eb351 100644 --- a/pillar_tool/ptcli/cli/hostgroup.py +++ b/pillar_tool/ptcli/cli/hostgroup.py @@ -39,7 +39,7 @@ def hostgroup_show(path: str): response = requests.get(f'{base_url()}/hostgroup/{name}', headers=auth_header(), params=data.model_dump()) response.raise_for_status() except requests.exceptions.HTTPError as e: - raise click.ClickException(f"Failed to show hostgroup:\n{e}") + raise click.ClickException(f"Failed to show hostgroup:\n{e.response.text}") @hostgroup.command("create") @@ -56,7 +56,7 @@ def hostgroup_create(path: str): response = requests.post(f'{base_url()}/hostgroup/{name}', headers=auth_header(), json=data.model_dump()) response.raise_for_status() except requests.exceptions.HTTPError as e: - raise click.ClickException(f"Failed to create hostgroup:\n{e}") + raise click.ClickException(f"Failed to create hostgroup:\n{e.response.text}") @hostgroup.command("delete") @@ -71,4 +71,4 @@ def hostgroup_delete(path: str): response = requests.delete(f'{base_url()}/hostgroup/{name}{query_params}', headers=auth_header()) response.raise_for_status() except requests.exceptions.HTTPError as e: - raise click.ClickException(f"Failed to delete hostgroup:\n{e}") + raise click.ClickException(f"Failed to delete hostgroup:\n{e.response.text}") diff --git a/pillar_tool/ptcli/cli/pillar.py b/pillar_tool/ptcli/cli/pillar.py index 1f634dd..6a921bc 100644 --- a/pillar_tool/ptcli/cli/pillar.py +++ b/pillar_tool/ptcli/cli/pillar.py @@ -14,19 +14,20 @@ def pillar(): @pillar.command("get") -@click.argument("fqdn") -def pillar_get(fqdn): - """Get pillar data for a given FQDN.""" +@click.argument("target") +def pillar_get(target: str): + """Get pillar data for a given host or hostgroup.""" try: response = requests.get( - f"{base_url()}/pillar/{fqdn}", - headers=auth_header(), + f"{base_url()}/pillar/{target.replace('/', "%%2f")}", + headers=auth_header() ) + response.raise_for_status() pillar_data = response.json() - click.echo(yaml.dump({fqdn: pillar_data})) + click.echo(json.dumps(pillar_data, indent=2)) except requests.exceptions.RequestException as e: - click.echo(f"Error: {e}") + click.echo(f"Error: {e.response.json()}") @pillar.command("set") @click.argument("name") diff --git a/pillar_tool/ptcli/cli/query.py b/pillar_tool/ptcli/cli/query.py deleted file mode 100644 index 0d5fc29..0000000 --- a/pillar_tool/ptcli/cli/query.py +++ /dev/null @@ -1,9 +0,0 @@ -import click -import requests - -from .cli_main import main, auth_header, base_url - - -@main.group("query") -def query(): - pass \ No newline at end of file diff --git a/pillar_tool/ptcli/cli/top.py b/pillar_tool/ptcli/cli/top.py index 2e42d0e..7617fd3 100644 --- a/pillar_tool/ptcli/cli/top.py +++ b/pillar_tool/ptcli/cli/top.py @@ -45,7 +45,7 @@ def top_assign(host: str, state: str): click.echo("Assigning state to host...") try: - response = requests.post(f'{base_url()}/top/assign/{host}/{state}', headers=auth_header()) + response = requests.post(f'{base_url()}/top/assign/{host.replace('/', "%%2f")}/{state}', headers=auth_header()) response.raise_for_status() click.echo("Assigned state") diff --git a/pillar_tool/routers/hostgroup.py b/pillar_tool/routers/hostgroup.py index ec46eea..bccef02 100644 --- a/pillar_tool/routers/hostgroup.py +++ b/pillar_tool/routers/hostgroup.py @@ -1,11 +1,11 @@ import uuid - +from typing import Annotated from sqlalchemy import select, insert, bindparam, delete, and_ from sqlalchemy.orm import Session from starlette.exceptions import HTTPException from starlette.requests import Request -from fastapi import APIRouter, Query, Depends +from fastapi import APIRouter, Depends from starlette.responses import JSONResponse from pillar_tool.db import Host @@ -45,7 +45,7 @@ def hostgroups_get(req: Request): return JSONResponse(status_code=200, content=all_hostgroup_names) @router.get("/{name}") -def hostgroup_get(req: Request, name: str, params: HostgroupParams = Depends(get_model_from_query(HostgroupParams))): +def hostgroup_get(req: Request, name: str, params: HostgroupParams): """ Retrieve a specific host group by name with additional details @@ -66,7 +66,11 @@ def hostgroup_get(req: Request, name: str, params: HostgroupParams = Depends(get # decode the path last = None ancestors = [] - path = split_and_validate_path(params.path) if params.path else [] + path = [] + if params: + path = split_and_validate_path(params.path) if params.path else [] + + print("test") # get the path from the db path_stmt = select(Host).where(and_(Host.name == bindparam('name') and Host.parent_id == bindparam('parent_id'))) @@ -117,7 +121,7 @@ def hostgroup_create(req: Request, name: str, params: HostgroupParams): """ db = req.state.db path = params.path - labels = split_and_validate_path(path) if path is not None else [] + labels = ( split_and_validate_path(path) if path is not None else [] ) or [] labels += [ name ] stmt = select(Host).where(and_(Host.name == bindparam('name'), Host.is_hostgroup == True, Host.parent_id == bindparam('last'))) diff --git a/pillar_tool/routers/pillar.py b/pillar_tool/routers/pillar.py index d3d5563..72fc17f 100644 --- a/pillar_tool/routers/pillar.py +++ b/pillar_tool/routers/pillar.py @@ -26,8 +26,8 @@ router = APIRouter( # Note: there is no list of all pillars, as this would not be helpful -@router.get("/{fqdn}") -def pillar_get(req: Request, fqdn: str): +@router.get("/{target}") +def pillar_get(req: Request, target: str) -> JSONResponse: # TODO: implement # this function should: # - get the affected host hierarchy @@ -37,27 +37,54 @@ def pillar_get(req: Request, fqdn: str): # if any error happens, return non-200 status and an empty dictionary so that salt does not shit itself db: Session = req.state.db - # get the host hierarchy - host_stmt = select(Host).where(and_(Host.name == fqdn, Host.is_hostgroup == False)) - result = db.execute(host_stmt).fetchall() - if len(result) == 0: - return JSONResponse(status=404, content={}) - # NOTE: should be enforced by the database - assert len(result) == 1 + # if the target is a hostgroup with path, then split the path and get to the target host this way + target = target.replace("%%2F", "%%2f") + if "%%2f" in target: + path_labels = target.split("%%2f") - host: Host = result[0][0] - path: list[Host] = [host] - parent_stmt = select(Host).where(Host.id == bindparam('parent')) - while path[-1].parent_id is not None: - result = db.execute(parent_stmt, {'parent': path[-1].parent_id}).fetchall() + + host_stmt_remain = select(Host).where(and_(Host.name == bindparam('frag'), Host.parent_id == bindparam('parent_id'))) + host_stmt_first = select(Host).where(and_(Host.name == bindparam('frag'), Host.parent_id == None)) + host_stmt = host_stmt_first + + parent_id = None + path: list[Host] = [] + for fragment in path_labels: + result = db.execute(host_stmt, {"frag": fragment, "parent_id": parent_id}).fetchall() + host_stmt = host_stmt_remain + if len(result) == 0: + return JSONResponse(status_code=404, content={"message": f"No such path fragment: {fragment} with parent_id {parent_id}"}) + assert len(result) == 1 # Note: that the db should enforce this + + current: Host = result[0][0] + parent_id = current.id + path.append(current) + + + else: + # get the host hierarchy from a fqdn or unique hostgroup name + host_stmt = select(Host).where(Host.name == target) + result = db.execute(host_stmt).fetchall() + if len(result) == 0: + return JSONResponse(status_code=404, content={'message': f'No such target: {target}'}) # NOTE: should be enforced by the database - assert len(result) == 1 - tmp: Host = result[0][0] - path.append(tmp) + if len(result) > 1: + return JSONResponse(status_code=400, content={'message': f'Multiple targets: {target}'}) + + host: Host = result[0][0] + path: list[Host] = [host] + parent_stmt = select(Host).where(Host.id == bindparam('parent')) + while path[-1].parent_id is not None: + result = db.execute(parent_stmt, {'parent': path[-1].parent_id}).fetchall() + # NOTE: should be enforced by the database + assert len(result) == 1 + tmp: Host = result[0][0] + path.append(tmp) + + path.reverse() + - path.reverse() out = merge([get_pillar_for_target(db, host.id) for host in path]) - return JSONResponse(status_code=200, content=out) @@ -126,11 +153,10 @@ def pillar_create(req: Request, name: str, params: PillarParams): else: # build the pillar package - pillars_to_store= [ + pillars_to_store = [ { 'name': name, 'value': params.value, 'type': params.type } ] - print(f"Pillars to store: {pillars_to_store}") # store pillar data insert_stmt = insert(Pillar).values(id=bindparam('new_id'), host_id=target.id, pillar_name=bindparam('name'), parameter_type=bindparam('type'), value=bindparam('value')) upsert_stmt = insert_stmt.on_conflict_do_update(constraint='pillar_unique_pillar_name', set_={'parameter_type': bindparam('type'), 'value': bindparam('value')} ) diff --git a/pillar_tool/routers/top.py b/pillar_tool/routers/top.py index 454ea75..4a386cf 100644 --- a/pillar_tool/routers/top.py +++ b/pillar_tool/routers/top.py @@ -112,12 +112,19 @@ def top_state_assign(req: Request, host_name: str, state_name: str): db: Session = req.state.db # get the host in question - host_stmt = select(Host).where(Host.name == host_name) - host_res = db.execute(host_stmt).fetchall() - if len(host_res) != 1: - return JSONResponse(status_code=404, content={"error": f"Host '{host_name} not found"}) + path_labels = host_name.replace("%%2F", "%%2f").split("%%2f") + parent_id = None + for path in path_labels: + host_stmt = select(Host).where(and_(Host.name == path, Host.parent_id == parent_id)) + host_res = db.execute(host_stmt).fetchall() - host: Host = host_res[0][0] + if len(host_res) != 1: + return JSONResponse(status_code=404, content={"error": f"Host '{host_name} not found"}) + + current: Host = host_res[0][0] + parent_id = current.id + + host: Host = current parent_stmt = select(Host).where(Host.id == bindparam("parent_id")) parents: list[Host] = [] diff --git a/pillar_tool/schemas.py b/pillar_tool/schemas.py index 4d6ee0c..d9d5da8 100644 --- a/pillar_tool/schemas.py +++ b/pillar_tool/schemas.py @@ -53,10 +53,10 @@ class StateParams(BaseModel): # Pillar operations class PillarParams(BaseModel): - host: str | None - hostgroup: str | None - value: str | None # value if the pillar should be set - type: str | None # type of pillar if pillar should be set + host: str | None = None + hostgroup: str | None = None + value: str | None = None # value if the pillar should be set + type: str | None = None # type of pillar if pillar should be set def get_model_from_query[T](model: T) -> Callable[[Request], T]: