diff --git a/pillar_tool/db/queries/pillar_queries.py b/pillar_tool/db/queries/pillar_queries.py index d840416..5e9e184 100644 --- a/pillar_tool/db/queries/pillar_queries.py +++ b/pillar_tool/db/queries/pillar_queries.py @@ -12,14 +12,14 @@ def get_pillar_name_sequence(name: str) -> list[str]: return name.split(':') def decode_pillar_value(pillar: Pillar) -> str | int | float | bool | list | dict: - match pillar.type: + match pillar.parameter_type: case 'string': return pillar.value case 'integer': return int(pillar.value) case 'float': return float(pillar.value) case 'boolean': return bool(pillar.value) case 'array': return json.loads(pillar.value) case 'dict': return json.loads(pillar.value) - raise RuntimeError(f"Failed to decode pillar value: Invalid type '{pillar.type}'") + raise RuntimeError(f"Failed to decode pillar value: Invalid type '{pillar.parameter_type}'") def get_pillar_for_target(db: Session, target: UUID) -> dict: @@ -38,5 +38,4 @@ def get_pillar_for_target(db: Session, target: UUID) -> dict: if label not in current: current[label] = {} if i < l-1 else value - print(json.dumps(out, indent=4)) - pass \ No newline at end of file + return out diff --git a/pillar_tool/ptcli/cli/cli_main.py b/pillar_tool/ptcli/cli/cli_main.py index 3b43fc2..c1c1e52 100644 --- a/pillar_tool/ptcli/cli/cli_main.py +++ b/pillar_tool/ptcli/cli/cli_main.py @@ -1,4 +1,9 @@ import base64 +from importlib.metadata import PackageNotFoundError +from os.path import dirname, realpath +from sysconfig import get_python_version +from importlib import metadata +import tomllib import click @@ -21,8 +26,21 @@ def base_url(): global _base_url return _base_url +def print_version_and_exit(ctx, _1, execute): + if not execute: return False + version_string = "version_undetectable" + try: + version_string = metadata.version('pillar_tool') + except PackageNotFoundError: + with (open(f"{realpath(dirname(__file__)) + '/../../..'}/pyproject.toml", "r") as f): + data = tomllib.loads(f.read()) + version_string = data['project']['version'] + print(f"pillar tool v{version_string}") + ctx.exit() + @click.group("command") -def main(): +@click.option("--version", '-v', is_flag=True, is_eager=True, callback=print_version_and_exit) +def main(**kwargs) -> None: global cfg, _base_url, _auth_header # load the configuration and store it diff --git a/pillar_tool/ptcli/cli/pillar.py b/pillar_tool/ptcli/cli/pillar.py index 0d94953..ce06957 100644 --- a/pillar_tool/ptcli/cli/pillar.py +++ b/pillar_tool/ptcli/cli/pillar.py @@ -34,9 +34,9 @@ def pillar_get(fqdn): @click.option("--value") def pillar_set(name: str, host: str | None, hostgroup: str | None, parameter_type: str | None, value: str | None): try: - if parameter_type == 'str': + if parameter_type == 'string': pass # there is nothing to do here - elif parameter_type == 'int': + elif parameter_type == 'integer': _ = int(value) elif parameter_type == 'float': _ = float(value) diff --git a/pillar_tool/routers/host.py b/pillar_tool/routers/host.py index 259663a..6217c18 100644 --- a/pillar_tool/routers/host.py +++ b/pillar_tool/routers/host.py @@ -178,6 +178,7 @@ async def host_delete(request: Request, fqdn: str): Removes the host entry and associated data from the database. TODO: Implement actual deletion logic - currently just a stub. + Args: request: FastAPI request object containing database session fqdn: Fully qualified domain name of the host to delete diff --git a/pillar_tool/routers/pillar.py b/pillar_tool/routers/pillar.py index defc24b..6e97ba8 100644 --- a/pillar_tool/routers/pillar.py +++ b/pillar_tool/routers/pillar.py @@ -55,9 +55,9 @@ def pillar_get(req: Request, fqdn: str): path.append(tmp) path.reverse() - out = merge(get_pillar_for_target(db, host.id) for host in path) + out = merge([get_pillar_for_target(db, host.id) for host in path]) - return JSONResponse(status_code=200, content={}) + return JSONResponse(status_code=200, content=out) @@ -88,8 +88,19 @@ def pillar_create(req: Request, name: str, params: PillarParams): target: Host = result[0][0] elif params.hostgroup is not None: path = split_and_validate_path(params.hostgroup) + if not path: + return JSONResponse(status_code=400, content={'message': "No target specified"}) last = None - group_stmt = select(Host).where(Host.is_hostgroup == True and Host.parent_id == bindparam('parent') ) + current = None + group_stmt = select(Host).where(Host.is_hostgroup == True and Host.parent_id == bindparam('parent') and Host.name == bindparam('name') ) + for label in path: + result = db.execute(group_stmt, {'name': label, 'parent': last}).fetchall() + if len(result) == 0: + return JSONResponse(status_code=404, content={'message': f"No hostgroup named: {params.hostgroup}"}) + # Note: this should be enforced by the database + assert len(result) == 1 + current: Host = result[0][0] + last = current.id else: return JSONResponse(status_code=400, content={'message': "Neither host nor hostgroup set"}) @@ -133,6 +144,6 @@ def pillar_create(req: Request, name: str, params: PillarParams): @router.delete("/{name}") -def pillar_delete(req: Request, name: str): +def pillar_delete(req: Request, name: str, params: PillarParams): # TODO: implement db = req.state.db diff --git a/pillar_tool/util/pillar_utilities.py b/pillar_tool/util/pillar_utilities.py index e0e4284..ce9d0cf 100644 --- a/pillar_tool/util/pillar_utilities.py +++ b/pillar_tool/util/pillar_utilities.py @@ -24,7 +24,7 @@ def apply_layer(base: dict, layer: dict): base[key] = value -def merge(*pillar_data, deep_copy=True) -> dict: +def merge(pillar_data, deep_copy=True) -> dict: """ Merges multiple pillar data dictionaries into one. @@ -48,5 +48,4 @@ def merge(*pillar_data, deep_copy=True) -> dict: for pillar in pillar_data[1:]: apply_layer(merged_pillar, pillar) - return merged_pillar \ No newline at end of file