diff --git a/pillar_tool/db/migrations/versions/2026_02_21_2338-ec7c818f92b5_renamed_bad_parameter.py b/pillar_tool/db/migrations/versions/2026_02_21_2338-ec7c818f92b5_renamed_bad_parameter.py new file mode 100644 index 0000000..471738d --- /dev/null +++ b/pillar_tool/db/migrations/versions/2026_02_21_2338-ec7c818f92b5_renamed_bad_parameter.py @@ -0,0 +1,34 @@ +"""renamed bad parameter + +Revision ID: ec7c818f92b5 +Revises: 58c2a8e7c302 +Create Date: 2026-02-21 23:38:00.609470 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'ec7c818f92b5' +down_revision: Union[str, Sequence[str], None] = '58c2a8e7c302' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('pillar_tool_pillar', sa.Column('parameter_type', sa.String(), nullable=False)) + op.drop_column('pillar_tool_pillar', 'type') + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('pillar_tool_pillar', sa.Column('type', sa.VARCHAR(), autoincrement=False, nullable=False)) + op.drop_column('pillar_tool_pillar', 'parameter_type') + # ### end Alembic commands ### diff --git a/pillar_tool/db/models/pillar_data.py b/pillar_tool/db/models/pillar_data.py index c20cafb..8280c75 100644 --- a/pillar_tool/db/models/pillar_data.py +++ b/pillar_tool/db/models/pillar_data.py @@ -11,7 +11,7 @@ class Pillar(Base): id = Column(UUID, primary_key=True) pillar_name = Column(String, nullable=False) host_id = Column(UUID, ForeignKey('pillar_tool_host.id'), nullable=True) - type = Column(String, nullable=False) + parameter_type = Column(String, nullable=False) value = Column(String, nullable=False) __table_args__ = ( diff --git a/pillar_tool/ptcli/cli/pillar.py b/pillar_tool/ptcli/cli/pillar.py index f019b93..0d94953 100644 --- a/pillar_tool/ptcli/cli/pillar.py +++ b/pillar_tool/ptcli/cli/pillar.py @@ -1,3 +1,5 @@ +import json + import click import requests @@ -22,4 +24,38 @@ def pillar_get(fqdn): pillar_data = response.json() click.echo(pillar_data) except requests.exceptions.RequestException as e: - click.echo(f"Error: {e}") \ No newline at end of file + click.echo(f"Error: {e}") + +@pillar.command("set") +@click.argument("name") +@click.option("--host") +@click.option("--hostgroup") +@click.option("--parameter-type") +@click.option("--value") +def pillar_set(name: str, host: str | None, hostgroup: str | None, parameter_type: str | None, value: str | None): + try: + if parameter_type == 'str': + pass # there is nothing to do here + elif parameter_type == 'int': + _ = int(value) + elif parameter_type == 'float': + _ = float(value) + elif parameter_type == 'bool': + _ = bool(value) + elif parameter_type == 'list': + value = json.loads(value) + elif parameter_type == 'dict': + value = json.loads(value) + else: + raise ValueError("Invalid parameter type") + except ValueError as e: + print(f"Failed to validate value: {e}") + else: + data = { + 'host': host, + 'hostgroup': hostgroup, + 'type': parameter_type, + 'value': json.dumps(value) + } + + requests.post(f"{base_url()}/pillar/{name}", headers=auth_header(), json=data) \ No newline at end of file diff --git a/pillar_tool/routers/pillar.py b/pillar_tool/routers/pillar.py index 7b8cca7..defc24b 100644 --- a/pillar_tool/routers/pillar.py +++ b/pillar_tool/routers/pillar.py @@ -1,18 +1,22 @@ +import json import uuid +from uuid import uuid4 -from sqlalchemy import select, insert, delete, bindparam +from sqlalchemy.dialects.postgresql import insert +from sqlalchemy import select, delete, bindparam from sqlalchemy.orm import Session from starlette.exceptions import HTTPException from starlette.requests import Request from fastapi import APIRouter, Depends from starlette.responses import JSONResponse -from pillar_tool.db import Host +from pillar_tool.db import Host, Pillar from pillar_tool.db.models.top_data import State, StateAssignment from pillar_tool.db.queries.pillar_queries import get_pillar_for_target from pillar_tool.schemas import PillarParams, get_model_from_query from pillar_tool.util.pillar_utilities import merge -from pillar_tool.util.validation import validate_state_name, validate_fqdn +from pillar_tool.util.validation import validate_state_name, validate_fqdn, validate_pillar_input_data, \ + split_and_validate_path router = APIRouter( prefix="/pillar", @@ -59,21 +63,70 @@ def pillar_get(req: Request, fqdn: str): @router.post("/{name}") def pillar_create(req: Request, name: str, params: PillarParams): - db = req.state.db + db: Session = req.state.db # ensure that value and type have been set in the request parameters if params.type is None or params.value is None: + return JSONResponse(status_code=400, content={ + 'message': "Both parameter type and value need to be set!" + }) + # validate pillar data + pillar_data = validate_pillar_input_data(params.value, params.type) - target_stmt = select(Host).where(Host.name == name) - result = db.execute(target_stmt).fetchall() + if params.host is not None: + target_stmt = select(Host).where(Host.name == params.host) + result = db.execute(target_stmt).fetchall() - if len(result) == 0: - return JSONResponse(status_code=404, content={}) + print(name, result) - # this should be enforced by the database - assert len(result) == 1 - target: Host = result[0][0] + if len(result) == 0: + return JSONResponse(status_code=404, content={}) + + # this should be enforced by the database + assert len(result) == 1 + target: Host = result[0][0] + elif params.hostgroup is not None: + path = split_and_validate_path(params.hostgroup) + last = None + group_stmt = select(Host).where(Host.is_hostgroup == True and Host.parent_id == bindparam('parent') ) + else: + return JSONResponse(status_code=400, content={'message': "Neither host nor hostgroup set"}) + + # if this is a dictionary value, parse it and create a separate entry for all the sub-pillars + if type(pillar_data) == dict: + def aux(prefix: str, input_value: dict) -> list[dict[str, str]]: + out = [] + for key, value in input_value.items(): + if type(value) is dict: + out += aux(f"{prefix}:{key}", value) + else: + out += [{ + 'name': f"{prefix}:{key}", + 'type': type(value).__name__, + 'value': json.dumps(value) + }] + + return out + + pillars_to_store = aux(name, pillar_data) + + else: + # build the pillar package + pillars_to_store= [ + { 'name': name, 'value': params.value, 'type': params.type } + ] + + # store pillar data + insert_stmt = insert(Pillar).values(id=bindparam('new_id'), host_id=target.id, pillar_name=bindparam('name'), parameter_type=bindparam('type'), value=bindparam('value')) + upsert_stmt = insert_stmt.on_conflict_do_update(constraint='pillar_unique_pillar_name', set_={'parameter_type': bindparam('type'), 'value': bindparam('value')} ) + + for instance in pillars_to_store: + instance['new_id'] = uuid4() + result = db.execute(upsert_stmt, instance) + print(result) + + return JSONResponse(status_code=200, content={'message': 'ok'}) diff --git a/pillar_tool/schemas.py b/pillar_tool/schemas.py index b39047f..fe9e31a 100644 --- a/pillar_tool/schemas.py +++ b/pillar_tool/schemas.py @@ -52,7 +52,8 @@ class StateParams(BaseModel): # Pillar operations class PillarParams(BaseModel): - target: str # must be host or hostgroup + host: str | None + hostgroup: str | None value: str | None # value if the pillar should be set type: str | None # type of pillar if pillar should be set diff --git a/pillar_tool/util/validation.py b/pillar_tool/util/validation.py index b7900ef..87a8098 100644 --- a/pillar_tool/util/validation.py +++ b/pillar_tool/util/validation.py @@ -1,3 +1,4 @@ +import json import re @@ -81,3 +82,28 @@ def split_and_validate_path(path: str) -> list[str] | None: return labels +def type_from_name(data_type: str) -> type | None: + match data_type: + case 'int': return int + case 'float': return float + case 'str': return str + case 'bool': return bool + case 'dict': return dict + case 'list': return list + case _: raise ValueError(f"Invalid pillar input: Unsupported data type: {data_type}") + +def name_from_type(value) -> str: + if type(value) is int: return 'int' + if type(value) is float: return 'float' + if type(value) is str: return 'str' + if type(value) is bool: return 'bool' + if type(value) is dict: return 'dict' + if type(value) is list: return 'list' + raise ValueError(f"Invalid pillar input: Unsupported data type: {type(value)}") + +def validate_pillar_input_data(value: str, data_type: str): + decoded_data = json.loads(value) + if type(decoded_data) is type_from_name(data_type): + return decoded_data + raise ValueError(f"Invalid pillar input: datatype does not match value") +