import uuid

from sqlalchemy import select, insert, delete, bindparam, and_, func
from sqlalchemy.orm import Session
from starlette.exceptions import HTTPException
from starlette.requests import Request
from fastapi import APIRouter
from starlette.responses import JSONResponse, Response

from pillar_tool.db import TopFile, Host
from pillar_tool.db.models.top_data import State, StateAssignment, Environment
from pillar_tool.schemas import StateParams
from pillar_tool.util.validation import validate_state_name

router = APIRouter(
    prefix="/state",
    tags=["state"],
)


def _get_state_or_404(db: Session, name: str) -> State:
    """Validate a state name and return the matching State row.

    Args:
        db (Session): The request's database session.
        name (str): The state name taken from the URL path.

    Returns:
        State: The first State row whose name equals ``name``.

    Raises:
        HTTPException: 400 if ``name`` fails ``validate_state_name``,
            404 if no state with that name exists.
    """
    if not validate_state_name(name):
        raise HTTPException(status_code=400, detail="Invalid state name format")
    state = db.execute(select(State).where(State.name == name)).scalars().first()
    if state is None:
        raise HTTPException(status_code=404, detail="No such state exists")
    return state


def _resolve_env_ids(db: Session, env_names) -> list:
    """Map environment names to Environment ids.

    Order is preserved and duplicate names are collapsed. All names are
    resolved up front so callers can abort before writing anything when an
    environment is unknown.

    Args:
        db (Session): The request's database session.
        env_names: Iterable of environment names (``None`` is treated as
            empty; assumed to come from ``StateParams.addenv``/``delenv``).

    Returns:
        list: The ``Environment.id`` of each distinct name, in input order.

    Raises:
        HTTPException: 404 if any name has no matching Environment row.
    """
    stmt_get_env = select(Environment).where(Environment.name == bindparam('env_name'))
    env_ids = []
    for env_name in dict.fromkeys(env_names or []):
        env = db.execute(stmt_get_env, {'env_name': env_name}).scalars().first()
        if env is None:
            raise HTTPException(status_code=404, detail="No such environment exists")
        env_ids.append(env.id)
    return env_ids


@router.get("")
def states_get(req: Request):
    """
    Retrieve all states.

    Fetches and returns a list of state names from the database.

    Returns:
        JSONResponse: A JSON response with status code 200 containing a list of state names (strings).
    """
    db: Session = req.state.db
    states = db.execute(select(State)).scalars().all()
    return JSONResponse(status_code=200, content=[state.name for state in states])


@router.get("/{name}")
def state_get(req: Request, name: str):
    """
    Retrieve a specific state by name.

    Args:
        req (Request): The incoming request object.
        name (str): The name of the state to retrieve.

    Returns:
        JSONResponse: 200 with ``{'state': <name>, 'assignment_count': <int>}``,
        where ``assignment_count`` is the number of StateAssignment rows for the state.

    Raises:
        HTTPException: 400 for an invalid name format, 404 if the state does not exist.
    """
    db: Session = req.state.db
    state = _get_state_or_404(db, name)

    # Count in the database instead of materialising every assignment row.
    count_stmt = (
        select(func.count())
        .select_from(StateAssignment)
        .where(StateAssignment.state_id == state.id)
    )
    assignments_count = db.execute(count_stmt).scalar_one()

    return JSONResponse(status_code=200, content={
        'state': state.name,
        'assignment_count': assignments_count,
    })


@router.post("/{name}")
def state_create(req: Request, name: str, patch_params: StateParams):
    """
    Create a new state and optionally assign it to environments.

    Args:
        req (Request): The incoming request object.
        name (str): The name of the state (must be unique).
        patch_params (StateParams): ``addenv`` lists environment names the new
            state is assigned to.

    Returns:
        JSONResponse: 201 with ``{'id': <uuid str>, 'name': <name>}``.

    Raises:
        HTTPException: 400 for an invalid name format, 409 if the state already
            exists, 404 if any environment in ``addenv`` does not exist.
    """
    db: Session = req.state.db

    if not validate_state_name(name):
        raise HTTPException(
            status_code=400,
            detail="Invalid state name. State names must start with a letter or underscore and contain only alphanumeric characters, underscores, or dashes.",
        )

    if db.execute(select(State.id).where(State.name == name)).first() is not None:
        raise HTTPException(status_code=409, detail="State already exists")

    # Resolve every environment before inserting anything, so an unknown
    # environment cannot leave a half-created state behind.
    env_ids = _resolve_env_ids(db, patch_params.addenv)

    new_id = uuid.uuid4()
    db.execute(insert(State).values(id=new_id, name=name))

    stmt_set_env = insert(StateAssignment).values(state_id=new_id, environment_id=bindparam('env_id'))
    for env_id in env_ids:
        db.execute(stmt_set_env, {'env_id': env_id})

    # NOTE(review): no commit here; presumably the session behind req.state.db
    # is committed by middleware — confirm.
    return JSONResponse(status_code=201, content={
        'id': str(new_id),
        'name': name,
    })


@router.delete("/{name}")
def state_delete(req: Request, name: str):
    """
    Delete a state by name.

    Refuses to delete a state that is still assigned to environments or still
    referenced by host top files.

    Args:
        req (Request): The incoming request object.
        name (str): The name of the state to delete.

    Returns:
        Response: 204 with an empty body on success, or a 409 JSONResponse listing
        ``assigned_envs`` or ``assigned_hosts`` that block the deletion.

    Raises:
        HTTPException: 400 for an invalid name format, 404 if the state does not exist.
    """
    db: Session = req.state.db
    state = _get_state_or_404(db, name)

    # Block deletion while environment assignments exist.
    env_ids = db.execute(
        select(StateAssignment.environment_id).where(StateAssignment.state_id == state.id)
    ).scalars().all()
    if env_ids:
        envs = db.execute(select(Environment).where(Environment.id.in_(env_ids))).scalars().all()
        return JSONResponse(status_code=409, content={
            'message': "Cannot delete a state that still has environment assignments",
            'assigned_envs': [e.name for e in envs],
        })

    # Block deletion while host top files still reference the state.
    host_ids = db.execute(
        select(TopFile.host_id).where(TopFile.state_id == state.id)
    ).scalars().all()
    if host_ids:
        hosts = db.execute(select(Host).where(Host.id.in_(host_ids))).scalars().all()
        return JSONResponse(status_code=409, content={
            'message': "Cannot delete a state that still has host assignments",
            'assigned_hosts': [h.name for h in hosts],
        })

    db.execute(delete(State).where(State.id == state.id))
    # 204 No Content must not carry a body; JSONResponse(content={}) would send "{}".
    return Response(status_code=204)


@router.patch("/{name}")
def state_patch(req: Request, name: str, patch_params: StateParams):
    """
    Add and/or remove environment assignments of an existing state.

    Args:
        req (Request): The incoming request object.
        name (str): The name of the state to modify.
        patch_params (StateParams): ``addenv`` lists environment names to assign,
            ``delenv`` lists environment names to unassign. Additions are applied
            before removals, so a name present in both ends up unassigned.

    Returns:
        Response: 204 with an empty body on success.

    Raises:
        HTTPException: 400 for an invalid name format, 404 if the state or any
            named environment does not exist.
    """
    db: Session = req.state.db
    state = _get_state_or_404(db, name)

    # Resolve everything up front so an unknown environment aborts the request
    # before any row is inserted or deleted.
    add_ids = _resolve_env_ids(db, patch_params.addenv)
    del_ids = _resolve_env_ids(db, patch_params.delenv)

    # Skip environments that are already assigned so PATCH stays idempotent
    # instead of inserting duplicate assignment rows.
    already_assigned = set(db.execute(
        select(StateAssignment.environment_id).where(StateAssignment.state_id == state.id)
    ).scalars())
    stmt_set_env = insert(StateAssignment).values(state_id=state.id, environment_id=bindparam('env_id'))
    for env_id in add_ids:
        if env_id not in already_assigned:
            db.execute(stmt_set_env, {'env_id': env_id})

    if del_ids:
        db.execute(delete(StateAssignment).where(and_(
            StateAssignment.state_id == state.id,
            StateAssignment.environment_id.in_(del_ids),
        )))

    # 204 No Content must not carry a body.
    return Response(status_code=204)