177 lines
5.4 KiB
Python

import uuid

from fastapi import APIRouter
from sqlalchemy import delete, func, insert, select
from sqlalchemy.orm import Session
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

from pillar_tool.db.models.top_data import State, StateAssignment
from pillar_tool.util.validation import validate_state_name
# Every endpoint in this module is mounted under /state and grouped under the "state" tag.
router = APIRouter(prefix="/state", tags=["state"])
@router.get("")
def states_get(req: Request):
    """
    List the names of every state.

    Args:
        req (Request): The incoming request; ``req.state.db`` holds the database session.

    Returns:
        JSONResponse: Status 200 whose body is a JSON array of state names (strings).
    """
    session: Session = req.state.db
    rows = session.execute(select(State)).fetchall()
    # Each row is a 1-tuple holding the State entity.
    names = [row[0].name for row in rows]
    return JSONResponse(status_code=200, content=names)
@router.get("/{name}")
def state_get(req: Request, name: str):
    """
    Retrieve a specific state by name.

    Fetches the specified state together with the number of StateAssignment
    rows that reference it.

    Args:
        req (Request): The incoming request object; ``req.state.db`` holds the
            database session.
        name (str): The name of the state to retrieve.

    Returns:
        JSONResponse: Status 200 with ``{'state': <name>, 'assignment_count': <int>}``.

    Raises:
        HTTPException: 400 if ``name`` fails ``validate_state_name``;
            404 if no state with that name exists.
    """
    db: Session = req.state.db

    # Validate name before touching the database
    if not validate_state_name(name):
        raise HTTPException(status_code=400, detail="Invalid state name format")

    # scalar_one_or_none() raises MultipleResultsFound on duplicate names.
    # Unlike the previous `assert`, this check is not stripped under `python -O`.
    state = db.execute(
        select(State).where(State.name == name)
    ).scalar_one_or_none()
    if state is None:
        raise HTTPException(status_code=404, detail="No such state exists")

    # Let the database count the assignments instead of loading every row.
    assignments_count = db.scalar(
        select(func.count())
        .select_from(StateAssignment)
        .where(StateAssignment.state_id == state.id)
    )

    return JSONResponse(status_code=200, content={
        'state': state.name,
        'assignment_count': assignments_count
    })
@router.post("/{name}")
def state_create(req: Request, name: str):
    """
    Create a new state called ``name``.

    Args:
        req (Request): The incoming request object; ``req.state.db`` holds the
            database session used for the lookup and the insert.
        name (str): Name of the new state; rejected if another state already uses it.

    Returns:
        JSONResponse: Status 201 with ``{'id': <uuid string>, 'name': <name>}``.

    Raises:
        HTTPException: 400 if the name fails ``validate_state_name``;
            409 if a state with this name already exists.
    """
    session = req.state.db

    if not validate_state_name(name):
        raise HTTPException(
            status_code=400,
            detail="Invalid state name. State names must start with a letter or underscore and contain only alphanumeric characters, underscores, or dashes.",
        )

    # NOTE(review): check-then-insert is not atomic; two concurrent requests could
    # both pass this check. Confirm whether the schema enforces unique names.
    duplicates = session.execute(select(State).where(State.name == name)).fetchall()
    if duplicates:
        raise HTTPException(status_code=409, detail="State already exists")

    state_id = uuid.uuid4()
    # NOTE(review): no commit here; presumably whoever owns the session commits it.
    session.execute(insert(State).values(id=state_id, name=name))
    return JSONResponse(status_code=201, content={'id': str(state_id), 'name': name})
@router.delete("/{name}")
def state_delete(req: Request, name: str):
    """
    Delete a state by name.

    Deletes the specified state, unless hosts are still assigned to it.

    Args:
        req (Request): The incoming request object; ``req.state.db`` holds the
            database session.
        name (str): The name of the state to delete.

    Returns:
        Response: Status 204 with an empty body on successful deletion.
        JSONResponse: Status 409 with ``{'message': ..., 'assigned_hosts': [<host name>, ...]}``
            if hosts are still assigned to this state.

    Raises:
        HTTPException: 400 if ``name`` fails ``validate_state_name``;
            404 if no state with that name exists.
    """
    db: Session = req.state.db

    # Validate name format
    if not validate_state_name(name):
        raise HTTPException(status_code=400, detail="Invalid state name format")

    # scalar_one_or_none() replaces the previous `assert`, which `python -O` strips.
    state = db.execute(
        select(State).where(State.name == name)
    ).scalar_one_or_none()
    if state is None:
        raise HTTPException(status_code=404, detail="No such state exists")

    # Refuse to delete while hosts still reference this state. One query fetches
    # the referencing host ids directly.
    host_ids = db.execute(
        select(StateAssignment.host_id).where(StateAssignment.state_id == state.id)
    ).scalars().all()
    if host_ids:
        # Kept as a local import, as in the original. Presumably this avoids an
        # import cycle; unverified.
        from pillar_tool.db import Host
        host_names = db.execute(
            select(Host.name).where(Host.id.in_(host_ids))
        ).scalars().all()
        return JSONResponse(status_code=409, content={
            'message': "Cannot delete a state that still has host assignments",
            'assigned_hosts': list(host_names)
        })

    # Delete the state
    db.execute(delete(State).where(State.id == state.id))
    # 204 No Content must not carry a body. JSONResponse(content={}) would send "{}"
    # with Content-Length: 2, which HTTP forbids and some servers reject.
    return Response(status_code=204)