diff --git a/pillar_tool/routers/pillar.py b/pillar_tool/routers/pillar.py index 6e97ba8..64dd7b8 100644 --- a/pillar_tool/routers/pillar.py +++ b/pillar_tool/routers/pillar.py @@ -2,8 +2,9 @@ import json import uuid from uuid import uuid4 +from sqlalchemy.dialects import postgresql from sqlalchemy.dialects.postgresql import insert -from sqlalchemy import select, delete, bindparam +from sqlalchemy import select, delete, bindparam, and_, or_ from sqlalchemy.orm import Session from starlette.exceptions import HTTPException from starlette.requests import Request @@ -78,8 +79,6 @@ def pillar_create(req: Request, name: str, params: PillarParams): target_stmt = select(Host).where(Host.name == params.host) result = db.execute(target_stmt).fetchall() - print(name, result) - if len(result) == 0: return JSONResponse(status_code=404, content={}) @@ -92,15 +91,18 @@ def pillar_create(req: Request, name: str, params: PillarParams): return JSONResponse(status_code=400, content={'message': "No target specified"}) last = None current = None - group_stmt = select(Host).where(Host.is_hostgroup == True and Host.parent_id == bindparam('parent') and Host.name == bindparam('name') ) + # Note: both statements need to be present, since '==' will not work for None and 'is' will not work for a UUID + group_stmt = select(Host).where(and_(Host.is_hostgroup == True, Host.parent_id == bindparam('parent'), Host.name == bindparam('name'))) + group_stmt_none = select(Host).where(and_(Host.is_hostgroup == True, Host.parent_id.is_(None), Host.name == bindparam('name'))) for label in path: - result = db.execute(group_stmt, {'name': label, 'parent': last}).fetchall() + result = db.execute(group_stmt if last is not None else group_stmt_none, {'name': label, 'parent': last}).fetchall() if len(result) == 0: return JSONResponse(status_code=404, content={'message': f"No hostgroup named: {params.hostgroup}"}) # Note: this should be enforced by the database - assert len(result) == 1 + assert len(result) == 1, f"Result: {[x[0].name for x in result]}" current: Host = result[0][0] last = current.id + target: Host = current else: return JSONResponse(status_code=400, content={'message': "Neither host nor hostgroup set"}) @@ -128,6 +130,7 @@ def pillar_create(req: Request, name: str, params: PillarParams): { 'name': name, 'value': params.value, 'type': params.type } ] + print(f"Pillars to store: {pillars_to_store}") # store pillar data insert_stmt = insert(Pillar).values(id=bindparam('new_id'), host_id=target.id, pillar_name=bindparam('name'), parameter_type=bindparam('type'), value=bindparam('value')) upsert_stmt = insert_stmt.on_conflict_do_update(constraint='pillar_unique_pillar_name', set_={'parameter_type': bindparam('type'), 'value': bindparam('value')} ) @@ -135,7 +138,6 @@ def pillar_create(req: Request, name: str, params: PillarParams): for instance in pillars_to_store: instance['new_id'] = uuid4() result = db.execute(upsert_stmt, instance) - print(result) return JSONResponse(status_code=200, content={'message': 'ok'}) diff --git a/pillar_tool/util/validation.py b/pillar_tool/util/validation.py index 87a8098..ab1523e 100644 --- a/pillar_tool/util/validation.py +++ b/pillar_tool/util/validation.py @@ -84,18 +84,18 @@ def split_and_validate_path(path: str) -> list[str] | None: def type_from_name(data_type: str) -> type | None: match data_type: - case 'int': return int + case 'integer': return int case 'float': return float - case 'str': return str + case 'string': return str case 'bool': return bool case 'dict': return dict case 'list': return list case _: raise ValueError(f"Invalid pillar input: Unsupported data type: {data_type}") def name_from_type(value) -> str: - if type(value) is int: return 'int' + if type(value) is int: return 'integer' if type(value) is float: return 'float' - if type(value) is str: return 'str' + if type(value) is str: return 'string' if type(value) is bool: return 'bool' if type(value) is dict: return 'dict' if type(value) is list: return 'list'