diff --git a/pillar_tool/main.py b/pillar_tool/main.py index 2cbd018..bf16e30 100644 --- a/pillar_tool/main.py +++ b/pillar_tool/main.py @@ -26,6 +26,7 @@ from pillar_tool.db.database import run_db_migrations from pillar_tool.routers.host import router as host_router from pillar_tool.routers.hostgroup import router as hostgroup_router from pillar_tool.routers.environment import router as environment_router +from pillar_tool.routers.state import router as state_router # run any pending migrations run_db_migrations() @@ -72,6 +73,7 @@ app.exception_handler(Exception)(on_general_error) app.include_router(host_router) app.include_router(hostgroup_router) app.include_router(environment_router) +app.include_router(state_router) @app.get("/") async def root(): diff --git a/pillar_tool/ptcli/cli/state.py b/pillar_tool/ptcli/cli/state.py index a29ef3e..a1b38ae 100644 --- a/pillar_tool/ptcli/cli/state.py +++ b/pillar_tool/ptcli/cli/state.py @@ -2,8 +2,73 @@ import click import requests from .cli_main import main, auth_header, base_url +from ...schemas import StateParams @main.group("state") def state(): - pass \ No newline at end of file + pass + + +@state.command("list") +def state_list(): + click.echo("Listing known states...") + try: + response = requests.get(f'{base_url()}/state', headers=auth_header()) + response.raise_for_status() + + click.echo("States:") + for st in response.json(): + click.echo(f" - {st}") + except requests.exceptions.HTTPError as e: + raise click.ClickException(f"Failed to list states:\n{e}") + + +@state.command("show") +@click.argument("name") +def state_show(name: str): + click.echo(f"Showing state '{name}'...") + try: + data = StateParams() + response = requests.get(f'{base_url()}/state/{name}', headers=auth_header(), params=data.model_dump()) + response.raise_for_status() + + click.echo("State details:") + for key, value in response.json().items(): + if isinstance(value, dict): + click.echo(f" {key}:") + for sub_key, sub_val in value.items(): + click.echo(f" {sub_key}: {sub_val}") + else: + click.echo(f" {key}: {value}") + except requests.exceptions.HTTPError as e: + raise click.ClickException(f"Failed to show state:\n{e}") + + +@state.command("create") +@click.argument("name") +def state_create(name: str): + click.echo(f"Creating state '{name}'...") + try: + data = StateParams() + response = requests.post(f'{base_url()}/state/{name}', headers=auth_header(), json=data.model_dump()) + response.raise_for_status() + + click.echo("State created successfully:") + for key, value in response.json().items(): + click.echo(f" {key}: {value}") + except requests.exceptions.HTTPError as e: + raise click.ClickException(f"Failed to create state:\n{e}") + + +@state.command("delete") +@click.argument("name") +def state_delete(name: str): + click.echo(f"Deleting state '{name}'...") + try: + response = requests.delete(f'{base_url()}/state/{name}', headers=auth_header()) + response.raise_for_status() + + click.echo("State deleted successfully.") + except requests.exceptions.HTTPError as e: + raise click.ClickException(f"Failed to delete state:\n{e}") \ No newline at end of file diff --git a/pillar_tool/routers/environment.py b/pillar_tool/routers/environment.py index 9825b36..0c89258 100644 --- a/pillar_tool/routers/environment.py +++ b/pillar_tool/routers/environment.py @@ -10,7 +10,7 @@ from starlette.responses import JSONResponse from pillar_tool.db import Host from pillar_tool.db.models.top_data import Environment, EnvironmentAssignment -from pillar_tool.schemas import HostgroupParams, get_hostgroup_params_from_query, get_model_from_query +from pillar_tool.schemas import HostgroupParams, get_model_from_query from pillar_tool.util.validation import split_and_validate_path, validate_environment_name router = APIRouter( diff --git a/pillar_tool/routers/hostgroup.py b/pillar_tool/routers/hostgroup.py index 22b371c..988e77b 100644 --- a/pillar_tool/routers/hostgroup.py +++ b/pillar_tool/routers/hostgroup.py @@ -9,7 +9,7 @@ from fastapi import APIRouter, Query, Depends from starlette.responses import JSONResponse from pillar_tool.db import Host -from pillar_tool.schemas import HostgroupParams, get_hostgroup_params_from_query, get_model_from_query +from pillar_tool.schemas import HostgroupParams, get_model_from_query from pillar_tool.util.validation import split_and_validate_path router = APIRouter( diff --git a/pillar_tool/routers/state.py b/pillar_tool/routers/state.py index e69de29..767c165 100644 --- a/pillar_tool/routers/state.py +++ b/pillar_tool/routers/state.py @@ -0,0 +1,177 @@ +import uuid + +from sqlalchemy import select, insert, delete +from sqlalchemy.orm import Session +from starlette.exceptions import HTTPException +from starlette.requests import Request +from fastapi import APIRouter +from starlette.responses import JSONResponse + +from pillar_tool.db.models.top_data import State, StateAssignment +from pillar_tool.util.validation import validate_state_name + +router = APIRouter( + prefix="/state", + tags=["state"], +) + + +@router.get("") +def states_get(req: Request): + """ + Retrieve all states. + + Fetches and returns a list of state names from the database. + + Returns: + JSONResponse: A JSON response with status code 200 containing a list of state names (strings). + """ + db: Session = req.state.db + + result = db.execute(select(State)).fetchall() + states: list[State] = list(map(lambda x: x[0], result)) + + return JSONResponse(status_code=200, content=[state.name for state in states]) + + +@router.get("/{name}") +def state_get(req: Request, name: str): + """ + Retrieve a specific state by name. + + Fetches and returns details of the specified state. + Returns 404 if no such state exists. + + Args: + req (Request): The incoming request object. + name (str): The name of the state to retrieve. + + Returns: + JSONResponse: A JSON response with status code 200 and the state details on success, + or 404 if not found. + """ + db: Session = req.state.db + + # Validate name before query + if not validate_state_name(name): + raise HTTPException(status_code=400, detail="Invalid state name format") + + stmt = select(State).where(State.name == name) + result = db.execute(stmt).fetchall() + + if len(result) == 0: + raise HTTPException(status_code=404, detail="No such state exists") + + assert len(result) == 1 + + state: State = result[0][0] + + # Get assigned hosts count as an example of additional info + assignments_stmt = select(StateAssignment).where( + StateAssignment.state_id == state.id + ) + assignments_count = db.execute(assignments_stmt).fetchall().__len__() + + return JSONResponse(status_code=200, content={ + 'state': state.name, + 'assignment_count': assignments_count + }) + + +@router.post("/{name}") +def state_create(req: Request, name: str): + """ + Create a new state. + + Creates a new state record in the database with the provided parameters. + + Args: + req (Request): The incoming request object. + name (str): The name of the state (must be unique). + + Returns: + JSONResponse: A JSON response with status code 201 on success, + or appropriate error codes (e.g., 409 if already exists, 400 for invalid format). + """ + db = req.state.db + + # Validate name format + if not validate_state_name(name): + raise HTTPException(status_code=400, + detail="Invalid state name. State names must start with a letter or underscore and contain only alphanumeric characters, underscores, or dashes.") + + # Check if state already exists + stmt_check = select(State).where(State.name == name) + existing = db.execute(stmt_check).fetchall() + + if len(existing) > 0: + raise HTTPException(status_code=409, detail="State already exists") + + new_id = uuid.uuid4() + db.execute(insert(State).values(id=new_id, name=name)) + + return JSONResponse(status_code=201, content={ + 'id': str(new_id), + 'name': name + }) + + +@router.delete("/{name}") +def state_delete(req: Request, name: str): + """ + Delete a state by name. + + Deletes the specified state from the database. + Returns 409 if hosts are still assigned to this state. + Returns 404 if no such state exists. + + Args: + req (Request): The incoming request object. + name (str): The name of the state to delete. + + Returns: + JSONResponse: A JSON response with status code 204 on successful deletion, + or appropriate error codes for conflicts or not found. + """ + db = req.state.db + + # Validate name format + if not validate_state_name(name): + raise HTTPException(status_code=400, detail="Invalid state name format") + + stmt = select(State).where(State.name == name) + result = db.execute(stmt).fetchall() + + if len(result) == 0: + raise HTTPException(status_code=404, detail="No such state exists") + + assert len(result) == 1 + + state: State = result[0][0] + + # Check for assigned hosts before deleting + assignments_stmt = select(StateAssignment).where( + StateAssignment.state_id == state.id + ) + assignments = db.execute(assignments_stmt).fetchall() + + if len(assignments) > 0: + host_ids_stmt = select(StateAssignment.host_id).where( + StateAssignment.state_id == state.id + ) + host_ids = [row[0] for row in db.execute(host_ids_stmt).fetchall()] + + # Get host names (could optimize this) + from pillar_tool.db import Host + hosts_stmt = select(Host).where(Host.id.in_(host_ids)) + hosts: list[Host] = list(map(lambda x: x[0], db.execute(hosts_stmt).fetchall())) + + return JSONResponse(status_code=409, content={ + 'message': "Cannot delete a state that still has host assignments", + 'assigned_hosts': [h.name for h in hosts] + }) + + # Delete the state + db.execute(delete(State).where(State.id == state.id)) + + return JSONResponse(status_code=204, content={}) \ No newline at end of file diff --git a/pillar_tool/schemas.py b/pillar_tool/schemas.py index 27752dc..03e1d86 100644 --- a/pillar_tool/schemas.py +++ b/pillar_tool/schemas.py @@ -1,4 +1,4 @@ -from typing import Type, Callable +from typing import Callable from pydantic import BaseModel from starlette.requests import Request @@ -46,9 +46,12 @@ class HostCreateParams(BaseModel): class HostgroupParams(BaseModel): path: str | None +# State operations +class StateParams(BaseModel): + pass # No parameters needed for state operations currently + + -def get_hostgroup_params_from_query(req: Request): - return HostgroupParams(**dict(req.query_params)) def get_model_from_query[T](model: T) -> Callable[[Request], T]: def aux(req: Request) -> T: diff --git a/pillar_tool/util/validation.py b/pillar_tool/util/validation.py index af5e4cb..b7900ef 100644 --- a/pillar_tool/util/validation.py +++ b/pillar_tool/util/validation.py @@ -6,6 +6,8 @@ FQDN_REGEX = re.compile(r'^([a-zA-Z0-9.-]+\.)+[a-zA-Z]{2,}$') ENV_NAME_REGEX = re.compile(r'^[a-zA-Z0-9_-]+$') +STATE_NAME_REGEX = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_-]*(\.[a-zA-Z_][a-zA-Z0-9_-]*)*$') + # TODO: improve doc comment for this function def validate_environment_name(name: str) -> bool: @@ -24,6 +26,23 @@ def validate_environment_name(name: str) -> bool: """ return bool(ENV_NAME_REGEX.match(name)) +def validate_state_name(name: str) -> bool: + """ + Validates a state name. + + Args: + name: The state name to validate (e.g., "active", "pending_removal") + + Returns: + True if the name contains only alphanumeric characters, underscores, or dashes, + and starts with an alphabetic character or underscore. + False otherwise. + + Note: + State names cannot be empty and must match the pattern [a-zA-Z_][a-zA-Z0-9_-]* + """ + return bool(STATE_NAME_REGEX.match(name)) + def validate_fqdn(fqdn: str) -> bool: """