cli mostly works now, possibly ready for a first test

This commit is contained in:
Linus Vogel 2026-05-10 12:37:30 +02:00
parent c7b7e8f6f8
commit a4c9bf8e6f
10 changed files with 94 additions and 59 deletions

View File

@ -13,7 +13,7 @@ def get_pillar_name_sequence(name: str) -> list[str]:
def decode_pillar_value(pillar: Pillar) -> str | int | float | bool | list | dict: def decode_pillar_value(pillar: Pillar) -> str | int | float | bool | list | dict:
match pillar.parameter_type: match pillar.parameter_type:
case 'string': return pillar.value case 'string': return str(pillar.value)
case 'integer': return int(pillar.value) case 'integer': return int(pillar.value)
case 'float': return float(pillar.value) case 'float': return float(pillar.value)
case 'boolean': return bool(pillar.value) case 'boolean': return bool(pillar.value)
@ -32,10 +32,17 @@ def get_pillar_for_target(db: Session, target: UUID) -> dict:
name = row.pillar_name name = row.pillar_name
value = decode_pillar_value(row) value = decode_pillar_value(row)
labels = get_pillar_name_sequence(name) labels = get_pillar_name_sequence(name)
print(f"Labels: {labels}, Value: {value}")
current = out current = out
l = len(labels) l = len(labels)
for i, label in enumerate(labels): for i, label in enumerate(labels):
if label not in current: if i < l-1:
current[label] = {} if i < l-1 else value if label not in current:
current[label] = {}
elif type(current[label]) is not dict:
current[label] = {}
current = current[label]
else:
current[label] = json.loads(str(value))
return out return out

View File

@ -1,6 +1,5 @@
from .host import host from .host import host
from .hostgroup import hostgroup from .hostgroup import hostgroup
from .query import query
from .state import state from .state import state
from .pillar import pillar from .pillar import pillar
from .environment import environment from .environment import environment

View File

@ -39,7 +39,7 @@ def hostgroup_show(path: str):
response = requests.get(f'{base_url()}/hostgroup/{name}', headers=auth_header(), params=data.model_dump()) response = requests.get(f'{base_url()}/hostgroup/{name}', headers=auth_header(), params=data.model_dump())
response.raise_for_status() response.raise_for_status()
except requests.exceptions.HTTPError as e: except requests.exceptions.HTTPError as e:
raise click.ClickException(f"Failed to show hostgroup:\n{e}") raise click.ClickException(f"Failed to show hostgroup:\n{e.response.text}")
@hostgroup.command("create") @hostgroup.command("create")
@ -56,7 +56,7 @@ def hostgroup_create(path: str):
response = requests.post(f'{base_url()}/hostgroup/{name}', headers=auth_header(), json=data.model_dump()) response = requests.post(f'{base_url()}/hostgroup/{name}', headers=auth_header(), json=data.model_dump())
response.raise_for_status() response.raise_for_status()
except requests.exceptions.HTTPError as e: except requests.exceptions.HTTPError as e:
raise click.ClickException(f"Failed to create hostgroup:\n{e}") raise click.ClickException(f"Failed to create hostgroup:\n{e.response.text}")
@hostgroup.command("delete") @hostgroup.command("delete")
@ -71,4 +71,4 @@ def hostgroup_delete(path: str):
response = requests.delete(f'{base_url()}/hostgroup/{name}{query_params}', headers=auth_header()) response = requests.delete(f'{base_url()}/hostgroup/{name}{query_params}', headers=auth_header())
response.raise_for_status() response.raise_for_status()
except requests.exceptions.HTTPError as e: except requests.exceptions.HTTPError as e:
raise click.ClickException(f"Failed to delete hostgroup:\n{e}") raise click.ClickException(f"Failed to delete hostgroup:\n{e.response.text}")

View File

@ -14,19 +14,20 @@ def pillar():
@pillar.command("get") @pillar.command("get")
@click.argument("fqdn") @click.argument("target")
def pillar_get(fqdn): def pillar_get(target: str):
"""Get pillar data for a given FQDN.""" """Get pillar data for a given host or hostgroup."""
try: try:
response = requests.get( response = requests.get(
f"{base_url()}/pillar/{fqdn}", f"{base_url()}/pillar/{target.replace('/', "%%2f")}",
headers=auth_header(), headers=auth_header()
) )
response.raise_for_status() response.raise_for_status()
pillar_data = response.json() pillar_data = response.json()
click.echo(yaml.dump({fqdn: pillar_data})) click.echo(json.dumps(pillar_data, indent=2))
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
click.echo(f"Error: {e}") click.echo(f"Error: {e.response.json()}")
@pillar.command("set") @pillar.command("set")
@click.argument("name") @click.argument("name")

View File

@ -1,9 +0,0 @@
import click
import requests
from .cli_main import main, auth_header, base_url
@main.group("query")
def query():
pass

View File

@ -45,7 +45,7 @@ def top_assign(host: str, state: str):
click.echo("Assigning state to host...") click.echo("Assigning state to host...")
try: try:
response = requests.post(f'{base_url()}/top/assign/{host}/{state}', headers=auth_header()) response = requests.post(f'{base_url()}/top/assign/{host.replace('/', "%%2f")}/{state}', headers=auth_header())
response.raise_for_status() response.raise_for_status()
click.echo("Assigned state") click.echo("Assigned state")

View File

@ -1,11 +1,11 @@
import uuid import uuid
from typing import Annotated
from sqlalchemy import select, insert, bindparam, delete, and_ from sqlalchemy import select, insert, bindparam, delete, and_
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.requests import Request from starlette.requests import Request
from fastapi import APIRouter, Query, Depends from fastapi import APIRouter, Depends
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from pillar_tool.db import Host from pillar_tool.db import Host
@ -45,7 +45,7 @@ def hostgroups_get(req: Request):
return JSONResponse(status_code=200, content=all_hostgroup_names) return JSONResponse(status_code=200, content=all_hostgroup_names)
@router.get("/{name}") @router.get("/{name}")
def hostgroup_get(req: Request, name: str, params: HostgroupParams = Depends(get_model_from_query(HostgroupParams))): def hostgroup_get(req: Request, name: str, params: HostgroupParams):
""" """
Retrieve a specific host group by name with additional details Retrieve a specific host group by name with additional details
@ -66,7 +66,11 @@ def hostgroup_get(req: Request, name: str, params: HostgroupParams = Depends(get
# decode the path # decode the path
last = None last = None
ancestors = [] ancestors = []
path = split_and_validate_path(params.path) if params.path else [] path = []
if params:
path = split_and_validate_path(params.path) if params.path else []
print("test")
# get the path from the db # get the path from the db
path_stmt = select(Host).where(and_(Host.name == bindparam('name') and Host.parent_id == bindparam('parent_id'))) path_stmt = select(Host).where(and_(Host.name == bindparam('name') and Host.parent_id == bindparam('parent_id')))
@ -117,7 +121,7 @@ def hostgroup_create(req: Request, name: str, params: HostgroupParams):
""" """
db = req.state.db db = req.state.db
path = params.path path = params.path
labels = split_and_validate_path(path) if path is not None else [] labels = ( split_and_validate_path(path) if path is not None else [] ) or []
labels += [ name ] labels += [ name ]
stmt = select(Host).where(and_(Host.name == bindparam('name'), Host.is_hostgroup == True, Host.parent_id == bindparam('last'))) stmt = select(Host).where(and_(Host.name == bindparam('name'), Host.is_hostgroup == True, Host.parent_id == bindparam('last')))

View File

@ -26,8 +26,8 @@ router = APIRouter(
# Note: there is no list of all pillars, as this would not be helpful # Note: there is no list of all pillars, as this would not be helpful
@router.get("/{fqdn}") @router.get("/{target}")
def pillar_get(req: Request, fqdn: str): def pillar_get(req: Request, target: str) -> JSONResponse:
# TODO: implement # TODO: implement
# this function should: # this function should:
# - get the affected host hierarchy # - get the affected host hierarchy
@ -37,27 +37,54 @@ def pillar_get(req: Request, fqdn: str):
# if any error happens, return non-200 status and an empty dictionary so that salt does not shit itself # if any error happens, return non-200 status and an empty dictionary so that salt does not shit itself
db: Session = req.state.db db: Session = req.state.db
# get the host hierarchy # if the target is a hostgroup with path, then split the path and get to the target host this way
host_stmt = select(Host).where(and_(Host.name == fqdn, Host.is_hostgroup == False)) target = target.replace("%%2F", "%%2f")
result = db.execute(host_stmt).fetchall() if "%%2f" in target:
if len(result) == 0: path_labels = target.split("%%2f")
return JSONResponse(status=404, content={})
# NOTE: should be enforced by the database
assert len(result) == 1
host: Host = result[0][0]
path: list[Host] = [host] host_stmt_remain = select(Host).where(and_(Host.name == bindparam('frag'), Host.parent_id == bindparam('parent_id')))
parent_stmt = select(Host).where(Host.id == bindparam('parent')) host_stmt_first = select(Host).where(and_(Host.name == bindparam('frag'), Host.parent_id == None))
while path[-1].parent_id is not None: host_stmt = host_stmt_first
result = db.execute(parent_stmt, {'parent': path[-1].parent_id}).fetchall()
parent_id = None
path: list[Host] = []
for fragment in path_labels:
result = db.execute(host_stmt, {"frag": fragment, "parent_id": parent_id}).fetchall()
host_stmt = host_stmt_remain
if len(result) == 0:
return JSONResponse(status_code=404, content={"message": f"No such path fragment: {fragment} with parent_id {parent_id}"})
assert len(result) == 1 # Note: that the db should enforce this
current: Host = result[0][0]
parent_id = current.id
path.append(current)
else:
# get the host hierarchy from a fqdn or unique hostgroup name
host_stmt = select(Host).where(Host.name == target)
result = db.execute(host_stmt).fetchall()
if len(result) == 0:
return JSONResponse(status_code=404, content={'message': f'No such target: {target}'})
# NOTE: should be enforced by the database # NOTE: should be enforced by the database
assert len(result) == 1 if len(result) > 1:
tmp: Host = result[0][0] return JSONResponse(status_code=400, content={'message': f'Multiple targets: {target}'})
path.append(tmp)
host: Host = result[0][0]
path: list[Host] = [host]
parent_stmt = select(Host).where(Host.id == bindparam('parent'))
while path[-1].parent_id is not None:
result = db.execute(parent_stmt, {'parent': path[-1].parent_id}).fetchall()
# NOTE: should be enforced by the database
assert len(result) == 1
tmp: Host = result[0][0]
path.append(tmp)
path.reverse()
path.reverse()
out = merge([get_pillar_for_target(db, host.id) for host in path]) out = merge([get_pillar_for_target(db, host.id) for host in path])
return JSONResponse(status_code=200, content=out) return JSONResponse(status_code=200, content=out)
@ -126,11 +153,10 @@ def pillar_create(req: Request, name: str, params: PillarParams):
else: else:
# build the pillar package # build the pillar package
pillars_to_store= [ pillars_to_store = [
{ 'name': name, 'value': params.value, 'type': params.type } { 'name': name, 'value': params.value, 'type': params.type }
] ]
print(f"Pillars to store: {pillars_to_store}")
# store pillar data # store pillar data
insert_stmt = insert(Pillar).values(id=bindparam('new_id'), host_id=target.id, pillar_name=bindparam('name'), parameter_type=bindparam('type'), value=bindparam('value')) insert_stmt = insert(Pillar).values(id=bindparam('new_id'), host_id=target.id, pillar_name=bindparam('name'), parameter_type=bindparam('type'), value=bindparam('value'))
upsert_stmt = insert_stmt.on_conflict_do_update(constraint='pillar_unique_pillar_name', set_={'parameter_type': bindparam('type'), 'value': bindparam('value')} ) upsert_stmt = insert_stmt.on_conflict_do_update(constraint='pillar_unique_pillar_name', set_={'parameter_type': bindparam('type'), 'value': bindparam('value')} )

View File

@ -112,12 +112,19 @@ def top_state_assign(req: Request, host_name: str, state_name: str):
db: Session = req.state.db db: Session = req.state.db
# get the host in question # get the host in question
host_stmt = select(Host).where(Host.name == host_name) path_labels = host_name.replace("%%2F", "%%2f").split("%%2f")
host_res = db.execute(host_stmt).fetchall() parent_id = None
if len(host_res) != 1: for path in path_labels:
return JSONResponse(status_code=404, content={"error": f"Host '{host_name} not found"}) host_stmt = select(Host).where(and_(Host.name == path, Host.parent_id == parent_id))
host_res = db.execute(host_stmt).fetchall()
host: Host = host_res[0][0] if len(host_res) != 1:
return JSONResponse(status_code=404, content={"error": f"Host '{host_name} not found"})
current: Host = host_res[0][0]
parent_id = current.id
host: Host = current
parent_stmt = select(Host).where(Host.id == bindparam("parent_id")) parent_stmt = select(Host).where(Host.id == bindparam("parent_id"))
parents: list[Host] = [] parents: list[Host] = []

View File

@ -53,10 +53,10 @@ class StateParams(BaseModel):
# Pillar operations # Pillar operations
class PillarParams(BaseModel): class PillarParams(BaseModel):
host: str | None host: str | None = None
hostgroup: str | None hostgroup: str | None = None
value: str | None # value if the pillar should be set value: str | None = None # value if the pillar should be set
type: str | None # type of pillar if pillar should be set type: str | None = None # type of pillar if pillar should be set
def get_model_from_query[T](model: T) -> Callable[[Request], T]: def get_model_from_query[T](model: T) -> Callable[[Request], T]: