242 lines
8.3 KiB
Python

import uuid

from fastapi import APIRouter
from sqlalchemy import select, insert, delete, bindparam, and_, func
from sqlalchemy.orm import Session
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

from pillar_tool.db import TopFile, Host
from pillar_tool.db.models.top_data import State, StateAssignment, Environment
from pillar_tool.schemas import StateParams
from pillar_tool.util.validation import validate_state_name

# Router for the state endpoints: every route declared below is mounted
# under the "/state" prefix and tagged "state".
router = APIRouter(tags=["state"], prefix="/state")
@router.get("")
def states_get(req: Request):
    """
    List every state.

    Selects all ``State`` rows and responds with their names.

    Returns:
        JSONResponse: status 200 whose body is a JSON array of state names (strings).
    """
    session: Session = req.state.db
    rows = session.execute(select(State)).fetchall()
    # Each row is a 1-tuple holding the State entity.
    names = [row[0].name for row in rows]
    return JSONResponse(status_code=200, content=names)
@router.get("/{name}")
def state_get(req: Request, name: str):
    """
    Retrieve a specific state by name.

    Args:
        req (Request): The incoming request; ``req.state.db`` holds the DB session.
        name (str): The name of the state to retrieve.

    Returns:
        JSONResponse: status 200 with ``{'state': <name>, 'assignment_count': <int>}``,
        where ``assignment_count`` is the number of ``StateAssignment`` rows
        referencing this state.

    Raises:
        HTTPException: 400 if ``name`` fails ``validate_state_name``;
            404 if no state with that name exists.
    """
    db: Session = req.state.db

    # Validate name before query
    if not validate_state_name(name):
        raise HTTPException(status_code=400, detail="Invalid state name format")

    # scalar_one_or_none() raises if more than one row matches, replacing an
    # `assert` that would silently disappear under `python -O`.
    state = db.execute(select(State).where(State.name == name)).scalar_one_or_none()
    if state is None:
        raise HTTPException(status_code=404, detail="No such state exists")

    # Let the database count the assignments instead of loading every row.
    count_stmt = (
        select(func.count())
        .select_from(StateAssignment)
        .where(StateAssignment.state_id == state.id)
    )
    assignments_count = db.execute(count_stmt).scalar_one()

    return JSONResponse(status_code=200, content={
        'state': state.name,
        'assignment_count': assignments_count,
    })
@router.post("/{name}")
def state_create(req: Request, name: str, patch_params: StateParams):
    """
    Create a new state, optionally assigning it to environments.

    Args:
        req (Request): The incoming request; ``req.state.db`` holds the DB session.
        name (str): The name of the state (must be unique).
        patch_params (StateParams): Request body; every environment name in
            ``addenv`` gets an assignment to the new state.

    Returns:
        JSONResponse: status 201 with ``{'id': <uuid string>, 'name': <name>}``.

    Raises:
        HTTPException: 400 if ``name`` is not a valid state name; 409 if a state
            with that name already exists; 404 if any name in ``addenv`` does
            not match an existing environment.
    """
    db: Session = req.state.db

    # Validate name format
    if not validate_state_name(name):
        raise HTTPException(status_code=400,
                            detail="Invalid state name. State names must start with a letter or underscore and contain only alphanumeric characters, underscores, or dashes.")

    # Check if state already exists
    existing = db.execute(select(State.id).where(State.name == name)).first()
    if existing is not None:
        raise HTTPException(status_code=409, detail="State already exists")

    # Resolve every environment BEFORE writing anything, so an unknown name
    # cannot leave a half-created state behind. dict.fromkeys drops repeated
    # names (keeping order) so no assignment row is inserted twice.
    stmt_get_env_id = select(Environment.id).where(Environment.name == bindparam('env_name'))
    env_ids = []
    for env in dict.fromkeys(patch_params.addenv or ()):
        env_id = db.execute(stmt_get_env_id, {'env_name': env}).scalar()
        if env_id is None:
            raise HTTPException(status_code=404, detail="No such environment exists")
        env_ids.append(env_id)

    new_id = uuid.uuid4()
    db.execute(insert(State).values(id=new_id, name=name))

    stmt_set_env = insert(StateAssignment).values(state_id=new_id, environment_id=bindparam('env_id'))
    for env_id in env_ids:
        db.execute(stmt_set_env, {'env_id': env_id})

    return JSONResponse(status_code=201, content={
        'id': str(new_id),
        'name': name,
    })
@router.delete("/{name}")
def state_delete(req: Request, name: str):
    """
    Delete a state by name, provided nothing references it any more.

    Args:
        req (Request): The incoming request; ``req.state.db`` holds the DB session.
        name (str): The name of the state to delete.

    Returns:
        Response: an empty 204 response on successful deletion, or a
        JSONResponse with status 409 when the state still has environment
        assignments (listed under ``assigned_envs``) or is still referenced by
        top files (hosts listed under ``assigned_hosts``).

    Raises:
        HTTPException: 400 if ``name`` fails ``validate_state_name``;
            404 if no state with that name exists.
    """
    db: Session = req.state.db

    # Validate name format
    if not validate_state_name(name):
        raise HTTPException(status_code=400, detail="Invalid state name format")

    # scalar_one_or_none() raises if more than one row matches, replacing an
    # `assert` that would silently disappear under `python -O`.
    state = db.execute(select(State).where(State.name == name)).scalar_one_or_none()
    if state is None:
        raise HTTPException(status_code=404, detail="No such state exists")

    # Refuse to delete while any environment still has this state assigned.
    has_env_assignments = db.execute(
        select(StateAssignment.state_id)
        .where(StateAssignment.state_id == state.id)
        .limit(1)
    ).first() is not None
    if has_env_assignments:
        # One joined query replaces the former id lookup + IN (...) lookup.
        env_rows = db.execute(
            select(Environment.id, Environment.name)
            .join(StateAssignment, StateAssignment.environment_id == Environment.id)
            .where(StateAssignment.state_id == state.id)
            .distinct()
        ).all()
        return JSONResponse(status_code=409, content={
            'message': "Cannot delete a state that still has environment assignments",
            'assigned_envs': [env_name for _env_id, env_name in env_rows]
        })

    # Refuse to delete while any top file still references this state.
    has_top_files = db.execute(
        select(TopFile.state_id)
        .where(TopFile.state_id == state.id)
        .limit(1)
    ).first() is not None
    if has_top_files:
        host_rows = db.execute(
            select(Host.id, Host.name)
            .join(TopFile, TopFile.host_id == Host.id)
            .where(TopFile.state_id == state.id)
            .distinct()
        ).all()
        return JSONResponse(status_code=409, content={
            'message': "Cannot delete a state that still has host assignments",
            'assigned_hosts': [host_name for _host_id, host_name in host_rows]
        })

    # Delete the state
    db.execute(delete(State).where(State.id == state.id))
    # A 204 response must not carry a body; JSONResponse(content={}) would
    # serialize "{}" into it, so send an empty Response instead.
    return Response(status_code=204)
@router.patch("/{name}")
def state_patch(req: Request, name: str, patch_params: StateParams):
    """
    Change which environments a state is assigned to.

    Every environment named in ``patch_params.addenv`` is assigned to the
    state; every one named in ``patch_params.delenv`` is unassigned. Additions
    are applied before removals, so a name present in both lists ends up
    unassigned. All names are resolved before anything is modified.

    Args:
        req (Request): The incoming request; ``req.state.db`` holds the DB session.
        name (str): The name of the state to modify.
        patch_params (StateParams): Request body carrying ``addenv`` and
            ``delenv`` lists of environment names.

    Returns:
        Response: an empty 204 response on success.

    Raises:
        HTTPException: 400 if ``name`` fails ``validate_state_name``; 404 if no
            single state has that name, or if any environment name in
            ``addenv``/``delenv`` does not exist.
    """
    db: Session = req.state.db

    # Same name check as the other endpoints in this router.
    if not validate_state_name(name):
        raise HTTPException(status_code=400, detail="Invalid state name format")

    selected_state_res = db.execute(select(State).where(State.name == name)).fetchall()
    if len(selected_state_res) != 1:
        raise HTTPException(status_code=404, detail="No such state exists")
    state: State = selected_state_res[0][0]

    # Statement that maps an environment name to its id.
    stmt_get_env_id = select(Environment.id).where(Environment.name == bindparam('env_name'))

    def resolve_env_ids(env_names):
        """Map environment names to ids, raising 404 on the first unknown name."""
        ids = []
        for env in env_names or ():
            env_id = db.execute(stmt_get_env_id, {'env_name': env}).scalar()
            if env_id is None:
                raise HTTPException(status_code=404, detail="No such environment exists")
            ids.append(env_id)
        return ids

    # Resolve both lists up front so an unknown name is rejected before any
    # assignment has been changed.
    add_ids = resolve_env_ids(patch_params.addenv)
    del_ids = resolve_env_ids(patch_params.delenv)

    # Skip environments that are already assigned (or repeated in addenv) so
    # the same assignment row is never inserted twice.
    assigned = set(db.execute(
        select(StateAssignment.environment_id).where(StateAssignment.state_id == state.id)
    ).scalars())
    stmt_set_env = insert(StateAssignment).values(state_id=state.id, environment_id=bindparam('env_id'))
    for env_id in add_ids:
        if env_id in assigned:
            continue
        db.execute(stmt_set_env, {'env_id': env_id})
        assigned.add(env_id)

    stmt_del_env = delete(StateAssignment).where(and_(
        StateAssignment.state_id == state.id,
        StateAssignment.environment_id == bindparam('env_id'),
    ))
    for env_id in del_ids:
        db.execute(stmt_del_env, {'env_id': env_id})

    # A 204 response must not carry a body, so send an empty Response.
    return Response(status_code=204)