import uuid

from fastapi import APIRouter
from sqlalchemy import delete, func, insert, select
from sqlalchemy.orm import Session
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

from pillar_tool.db.models.top_data import State, StateAssignment
from pillar_tool.util.validation import validate_state_name

router = APIRouter(
    prefix="/state",
    tags=["state"],
)

# NOTE(review): no endpoint here commits; presumably the request-scoped
# session in `req.state.db` is committed/rolled back by middleware — confirm.


def _lookup_state(db: Session, name: str) -> State:
    """Validate *name* and return the matching State row.

    Shared by the read and delete endpoints so both report errors identically.

    Raises:
        HTTPException: 400 if the name fails `validate_state_name`,
            404 if no state with that name exists.
        sqlalchemy.exc.MultipleResultsFound: if more than one row matches
            (replaces a bare `assert`, which is stripped under `python -O`).
    """
    # Validate name before query
    if not validate_state_name(name):
        raise HTTPException(status_code=400, detail="Invalid state name format")

    state = db.execute(select(State).where(State.name == name)).scalar_one_or_none()
    if state is None:
        raise HTTPException(status_code=404, detail="No such state exists")
    return state


@router.get("")
def states_get(req: Request):
    """
    Retrieve all states.

    Fetches and returns a list of state names from the database.

    Returns:
        JSONResponse: A JSON response with status code 200 containing a list
        of state names (strings).
    """
    db: Session = req.state.db
    # Only the names are returned, so don't load full ORM objects.
    names = db.execute(select(State.name)).scalars().all()
    return JSONResponse(status_code=200, content=list(names))


@router.get("/{name}")
def state_get(req: Request, name: str):
    """
    Retrieve a specific state by name.

    Fetches and returns details of the specified state.

    Args:
        req (Request): The incoming request object.
        name (str): The name of the state to retrieve.

    Returns:
        JSONResponse: A JSON response with status code 200 and the state
        details (`state` name and `assignment_count`) on success.

    Raises:
        HTTPException: 400 for an invalid name, 404 if not found.
    """
    db: Session = req.state.db
    state = _lookup_state(db, name)

    # Let the database count the assignments instead of fetching every row.
    count_stmt = (
        select(func.count())
        .select_from(StateAssignment)
        .where(StateAssignment.state_id == state.id)
    )
    assignment_count = db.execute(count_stmt).scalar_one()

    return JSONResponse(status_code=200, content={
        'state': state.name,
        'assignment_count': assignment_count,
    })


@router.post("/{name}")
def state_create(req: Request, name: str):
    """
    Create a new state.

    Creates a new state record in the database with the provided parameters.

    Args:
        req (Request): The incoming request object.
        name (str): The name of the state (must be unique).

    Returns:
        JSONResponse: A JSON response with status code 201 and the new
        state's `id` and `name` on success.

    Raises:
        HTTPException: 400 for an invalid name, 409 if the state already exists.
    """
    db: Session = req.state.db

    # Validate name format
    if not validate_state_name(name):
        raise HTTPException(status_code=400, detail="Invalid state name. State names must start with a letter or underscore and contain only alphanumeric characters, underscores, or dashes.")

    # Check if state already exists.
    # NOTE(review): check-then-insert can race with a concurrent request; a
    # unique constraint on State.name (if present) would be the real guard.
    already_exists = db.execute(
        select(State.id).where(State.name == name)
    ).first() is not None
    if already_exists:
        raise HTTPException(status_code=409, detail="State already exists")

    new_id = uuid.uuid4()
    db.execute(insert(State).values(id=new_id, name=name))

    return JSONResponse(status_code=201, content={
        'id': str(new_id),
        'name': name,
    })


@router.delete("/{name}")
def state_delete(req: Request, name: str):
    """
    Delete a state by name.

    Deletes the specified state from the database. Refuses (409) while any
    host is still assigned to the state.

    Args:
        req (Request): The incoming request object.
        name (str): The name of the state to delete.

    Returns:
        Response: An empty response with status code 204 on successful
        deletion, or a 409 JSONResponse listing the names of still-assigned
        hosts.

    Raises:
        HTTPException: 400 for an invalid name, 404 if not found.
    """
    db: Session = req.state.db
    state = _lookup_state(db, name)

    # Subquery of host ids assigned to this state; reused for the existence
    # probe and the host-name lookup.
    assigned_host_ids = select(StateAssignment.host_id).where(
        StateAssignment.state_id == state.id
    )

    if db.execute(assigned_host_ids.limit(1)).first() is not None:
        # Imported lazily, as in the original code (possibly to avoid an
        # import cycle — unconfirmed).
        from pillar_tool.db import Host

        host_names = db.execute(
            select(Host.name).where(Host.id.in_(assigned_host_ids))
        ).scalars().all()
        return JSONResponse(status_code=409, content={
            'message': "Cannot delete a state that still has host assignments",
            'assigned_hosts': list(host_names),
        })

    # Delete the state
    db.execute(delete(State).where(State.id == state.id))
    # 204 No Content must not carry a body; a JSONResponse would send "{}".
    return Response(status_code=204)