implemented the state endpoint

This commit is contained in:
Linus Vogel 2026-02-15 10:39:22 +01:00
parent 02437e9c08
commit c31a61336e
7 changed files with 272 additions and 6 deletions

View File

@ -26,6 +26,7 @@ from pillar_tool.db.database import run_db_migrations
from pillar_tool.routers.host import router as host_router
from pillar_tool.routers.hostgroup import router as hostgroup_router
from pillar_tool.routers.environment import router as environment_router
from pillar_tool.routers.state import router as state_router
# run any pending migrations
run_db_migrations()
@ -72,6 +73,7 @@ app.exception_handler(Exception)(on_general_error)
app.include_router(host_router)
app.include_router(hostgroup_router)
app.include_router(environment_router)
app.include_router(state_router)
@app.get("/")
async def root():

View File

@ -2,8 +2,73 @@ import click
import requests
from .cli_main import main, auth_header, base_url
from ...schemas import StateParams
@main.group("state")
def state():
pass
pass
@state.command("list")
def state_list():
click.echo("Listing known states...")
try:
response = requests.get(f'{base_url()}/state', headers=auth_header())
response.raise_for_status()
click.echo("States:")
for st in response.json():
click.echo(f" - {st}")
except requests.exceptions.HTTPError as e:
raise click.ClickException(f"Failed to list states:\n{e}")
@state.command("show")
@click.argument("name")
def state_show(name: str):
click.echo(f"Showing state '{name}'...")
try:
data = StateParams()
response = requests.get(f'{base_url()}/state/{name}', headers=auth_header(), params=data.model_dump())
response.raise_for_status()
click.echo("State details:")
for key, value in response.json().items():
if isinstance(value, dict):
click.echo(f" {key}:")
for sub_key, sub_val in value.items():
click.echo(f" {sub_key}: {sub_val}")
else:
click.echo(f" {key}: {value}")
except requests.exceptions.HTTPError as e:
raise click.ClickException(f"Failed to show state:\n{e}")
@state.command("create")
@click.argument("name")
def state_create(name: str):
click.echo(f"Creating state '{name}'...")
try:
data = StateParams()
response = requests.post(f'{base_url()}/state/{name}', headers=auth_header(), json=data.model_dump())
response.raise_for_status()
click.echo("State created successfully:")
for key, value in response.json().items():
click.echo(f" {key}: {value}")
except requests.exceptions.HTTPError as e:
raise click.ClickException(f"Failed to create state:\n{e}")
@state.command("delete")
@click.argument("name")
def state_delete(name: str):
click.echo(f"Deleting state '{name}'...")
try:
response = requests.delete(f'{base_url()}/state/{name}', headers=auth_header())
response.raise_for_status()
click.echo("State deleted successfully.")
except requests.exceptions.HTTPError as e:
raise click.ClickException(f"Failed to delete state:\n{e}")

View File

@ -10,7 +10,7 @@ from starlette.responses import JSONResponse
from pillar_tool.db import Host
from pillar_tool.db.models.top_data import Environment, EnvironmentAssignment
from pillar_tool.schemas import HostgroupParams, get_hostgroup_params_from_query, get_model_from_query
from pillar_tool.schemas import HostgroupParams, get_model_from_query
from pillar_tool.util.validation import split_and_validate_path, validate_environment_name
router = APIRouter(

View File

@ -9,7 +9,7 @@ from fastapi import APIRouter, Query, Depends
from starlette.responses import JSONResponse
from pillar_tool.db import Host
from pillar_tool.schemas import HostgroupParams, get_hostgroup_params_from_query, get_model_from_query
from pillar_tool.schemas import HostgroupParams, get_model_from_query
from pillar_tool.util.validation import split_and_validate_path
router = APIRouter(

View File

@ -0,0 +1,177 @@
import uuid
from sqlalchemy import select, insert, delete
from sqlalchemy.orm import Session
from starlette.exceptions import HTTPException
from starlette.requests import Request
from fastapi import APIRouter
from starlette.responses import JSONResponse
from pillar_tool.db.models.top_data import State, StateAssignment
from pillar_tool.util.validation import validate_state_name
router = APIRouter(
prefix="/state",
tags=["state"],
)
@router.get("")
def states_get(req: Request):
"""
Retrieve all states.
Fetches and returns a list of state names from the database.
Returns:
JSONResponse: A JSON response with status code 200 containing a list of state names (strings).
"""
db: Session = req.state.db
result = db.execute(select(State)).fetchall()
states: list[State] = list(map(lambda x: x[0], result))
return JSONResponse(status_code=200, content=[state.name for state in states])
@router.get("/{name}")
def state_get(req: Request, name: str):
"""
Retrieve a specific state by name.
Fetches and returns details of the specified state.
Returns 404 if no such state exists.
Args:
req (Request): The incoming request object.
name (str): The name of the state to retrieve.
Returns:
JSONResponse: A JSON response with status code 200 and the state details on success,
or 404 if not found.
"""
db: Session = req.state.db
# Validate name before query
if not validate_state_name(name):
raise HTTPException(status_code=400, detail="Invalid state name format")
stmt = select(State).where(State.name == name)
result = db.execute(stmt).fetchall()
if len(result) == 0:
raise HTTPException(status_code=404, detail="No such state exists")
assert len(result) == 1
state: State = result[0][0]
# Get assigned hosts count as an example of additional info
assignments_stmt = select(StateAssignment).where(
StateAssignment.state_id == state.id
)
assignments_count = db.execute(assignments_stmt).fetchall().__len__()
return JSONResponse(status_code=200, content={
'state': state.name,
'assignment_count': assignments_count
})
@router.post("/{name}")
def state_create(req: Request, name: str):
"""
Create a new state.
Creates a new state record in the database with the provided parameters.
Args:
req (Request): The incoming request object.
name (str): The name of the state (must be unique).
Returns:
JSONResponse: A JSON response with status code 201 on success,
or appropriate error codes (e.g., 409 if already exists, 400 for invalid format).
"""
db = req.state.db
# Validate name format
if not validate_state_name(name):
raise HTTPException(status_code=400,
detail="Invalid state name. State names must start with a letter or underscore and contain only alphanumeric characters, underscores, or dashes.")
# Check if state already exists
stmt_check = select(State).where(State.name == name)
existing = db.execute(stmt_check).fetchall()
if len(existing) > 0:
raise HTTPException(status_code=409, detail="State already exists")
new_id = uuid.uuid4()
db.execute(insert(State).values(id=new_id, name=name))
return JSONResponse(status_code=201, content={
'id': str(new_id),
'name': name
})
@router.delete("/{name}")
def state_delete(req: Request, name: str):
"""
Delete a state by name.
Deletes the specified state from the database.
Returns 409 if hosts are still assigned to this state.
Returns 404 if no such state exists.
Args:
req (Request): The incoming request object.
name (str): The name of the state to delete.
Returns:
JSONResponse: A JSON response with status code 204 on successful deletion,
or appropriate error codes for conflicts or not found.
"""
db = req.state.db
# Validate name format
if not validate_state_name(name):
raise HTTPException(status_code=400, detail="Invalid state name format")
stmt = select(State).where(State.name == name)
result = db.execute(stmt).fetchall()
if len(result) == 0:
raise HTTPException(status_code=404, detail="No such state exists")
assert len(result) == 1
state: State = result[0][0]
# Check for assigned hosts before deleting
assignments_stmt = select(StateAssignment).where(
StateAssignment.state_id == state.id
)
assignments = db.execute(assignments_stmt).fetchall()
if len(assignments) > 0:
host_ids_stmt = select(StateAssignment.host_id).where(
StateAssignment.state_id == state.id
)
host_ids = [row[0] for row in db.execute(host_ids_stmt).fetchall()]
# Get host names (could optimize this)
from pillar_tool.db import Host
hosts_stmt = select(Host).where(Host.id.in_(host_ids))
hosts: list[Host] = list(map(lambda x: x[0], db.execute(hosts_stmt).fetchall()))
return JSONResponse(status_code=409, content={
'message': "Cannot delete a state that still has host assignments",
'assigned_hosts': [h.name for h in hosts]
})
# Delete the state
db.execute(delete(State).where(State.id == state.id))
return JSONResponse(status_code=204, content={})

View File

@ -1,4 +1,4 @@
from typing import Type, Callable
from typing import Callable
from pydantic import BaseModel
from starlette.requests import Request
@ -46,9 +46,12 @@ class HostCreateParams(BaseModel):
class HostgroupParams(BaseModel):
path: str | None
# State operations
class StateParams(BaseModel):
pass # No parameters needed for state operations currently
def get_hostgroup_params_from_query(req: Request):
return HostgroupParams(**dict(req.query_params))
def get_model_from_query[T](model: T) -> Callable[[Request], T]:
def aux(req: Request) -> T:

View File

@ -6,6 +6,8 @@ FQDN_REGEX = re.compile(r'^([a-zA-Z0-9.-]+\.)+[a-zA-Z]{2,}$')
ENV_NAME_REGEX = re.compile(r'^[a-zA-Z0-9_-]+$')
STATE_NAME_REGEX = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_-]*(\.[a-zA-Z_][a-zA-Z0-9_-]*)*$')
# TODO: improve doc comment for this function
def validate_environment_name(name: str) -> bool:
@ -24,6 +26,23 @@ def validate_environment_name(name: str) -> bool:
"""
return bool(ENV_NAME_REGEX.match(name))
def validate_state_name(name: str) -> bool:
"""
Validates a state name.
Args:
name: The state name to validate (e.g., "active", "pending_removal")
Returns:
True if the name contains only alphanumeric characters, underscores, or dashes,
and starts with an alphabetic character or underscore.
False otherwise.
Note:
State names cannot be empty and must match the pattern [a-zA-Z_][a-zA-Z0-9_-]*
"""
return bool(STATE_NAME_REGEX.match(name))
def validate_fqdn(fqdn: str) -> bool:
"""