import logging
import os
import uuid

import pygit2
from pygit2.enums import BranchType
from sqlalchemy import select, delete, bindparam, func
from sqlalchemy.orm import Session
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.sql.operators import in_op, and_
from starlette.exceptions import HTTPException
from starlette.requests import Request
from fastapi import APIRouter
from starlette.responses import JSONResponse, Response

from pillar_tool.db import Host, StateAssignment
from pillar_tool.db.models.top_data import Environment, EnvironmentAssignment, State
from pillar_tool.util.files import recursive_list_dir
from pillar_tool.util.validation import validate_environment_name
from pillar_tool.util import config, Config
from pillar_tool.git.repository import checkout_remote_branch

logger = logging.getLogger(__name__)

router = APIRouter(
    prefix="/environment",
    tags=["environment"],
)


def _get_environment_or_404(db: Session, name: str) -> Environment:
    """Validate *name* and return the Environment row with that name.

    Raises:
        HTTPException: 400 if *name* fails ``validate_environment_name``,
            404 if no environment with that name exists.
    """
    if not validate_environment_name(name):
        raise HTTPException(status_code=400, detail="Invalid environment name format")

    # scalar_one_or_none raises MultipleResultsFound if the name is not unique,
    # so a duplicate surfaces as a server error instead of silently picking one.
    env = db.execute(
        select(Environment).where(Environment.name == name)
    ).scalar_one_or_none()
    if env is None:
        raise HTTPException(status_code=404, detail="No such environment exists")
    return env


def _state_name_from_path(path: str, root: str) -> str:
    """Convert an ``.sls`` file path below *root* into a dotted state name.

    ``<root>/a/b.sls`` -> ``a.b`` and ``<root>/a/b/init.sls`` -> ``a.b``.
    Only the leading *root* prefix and the trailing file suffix are stripped,
    so path components that merely contain ``.sls`` are left intact.
    """
    relative = path.removeprefix(root.rstrip("/") + "/")
    if relative.endswith("/init.sls"):
        relative = relative[: -len("/init.sls")]
    else:
        relative = relative.removesuffix(".sls")
    return relative.replace("/", ".")


def _get_or_create_environment(db: Session, name: str) -> Environment:
    """Return the Environment called *name*, inserting it first if missing.

    Raises:
        HTTPException: 404 if the INSERT ... RETURNING yields no row.
    """
    env = db.execute(
        select(Environment).where(Environment.name == name)
    ).scalars().first()
    if env is not None:
        return env

    created = db.execute(
        insert(Environment).values(id=uuid.uuid4(), name=name).returning(Environment)
    ).scalars().first()
    if created is None:
        raise HTTPException(status_code=404, detail=f"Failed to create non-existent environment '{name}'")
    return created


def _ensure_state_ids(db: Session, state_names: list[str]) -> dict[str, uuid.UUID]:
    """Return a ``state name -> State.id`` mapping, inserting unknown states."""
    state_ids = {
        state.name: state.id
        for state in db.scalars(select(State).where(State.name.in_(state_names)))
    }
    # dict.fromkeys de-duplicates (e.g. both foo.sls and foo/init.sls) in order.
    new_ids = {
        state_name: uuid.uuid4()
        for state_name in dict.fromkeys(state_names)
        if state_name not in state_ids
    }
    if new_ids:
        db.execute(insert(State).values(
            [{'id': state_id, 'name': state_name} for state_name, state_id in new_ids.items()]
        ))
    state_ids.update(new_ids)
    return state_ids


def _sync_state_assignments(db: Session, env: Environment, state_ids: set[uuid.UUID]) -> None:
    """Make *env*'s StateAssignment rows match exactly *state_ids*."""
    assigned = set(db.scalars(
        select(StateAssignment.state_id).where(StateAssignment.environment_id == env.id)
    ))
    to_delete = assigned - state_ids
    to_insert = state_ids - assigned

    # Guard both statements: an empty IN list / empty VALUES list is pointless
    # at best and an invalid statement at worst.
    if to_delete:
        db.execute(
            delete(StateAssignment).where(
                StateAssignment.environment_id == env.id,
                StateAssignment.state_id.in_(list(to_delete)),
            )
        )
    if to_insert:
        db.execute(insert(StateAssignment).values(
            [{'state_id': sid, 'environment_id': env.id} for sid in to_insert]
        ))


@router.get("")
def environments_get(req: Request):
    """
    Retrieve all environments.

    Fetches and returns a list of environment names from the database.

    Returns:
        JSONResponse: A JSON response with status code 200 containing a list of environment names (strings).
    """
    db: Session = req.state.db
    environments = db.scalars(select(Environment)).all()
    return JSONResponse(status_code=200, content=[env.name for env in environments])


@router.get("/{name}")
def environment_get(req: Request, name: str):
    """
    Retrieve a specific environment by name.

    Fetches and returns details of the specified environment.

    Args:
        req (Request): The incoming request object.
        name (str): The name of the environment to retrieve.

    Returns:
        JSONResponse: 200 with the environment name and the number of hosts
        assigned to it; 400 for an invalid name; 404 if not found.
    """
    db: Session = req.state.db
    env = _get_environment_or_404(db, name)

    # Count in the database instead of materialising every assigned host row.
    host_count_stmt = (
        select(func.count())
        .select_from(Host)
        .join(EnvironmentAssignment, Host.id == EnvironmentAssignment.host_id)
        .where(EnvironmentAssignment.environment_id == env.id)
    )
    host_count = db.scalar(host_count_stmt)

    return JSONResponse(status_code=200, content={
        'environment': env.name,
        'host_count': host_count,
    })


@router.post("/{name}")
def environment_create(req: Request, name: str):
    """
    Create a new environment.

    Creates a new environment record in the database with the provided parameters.

    Args:
        req (Request): The incoming request object.
        name (str): The name of the environment (must be unique).

    Returns:
        JSONResponse: 201 with the new id and name on success, 400 for an
        invalid name, 409 if an environment with that name already exists.
    """
    db: Session = req.state.db

    if not validate_environment_name(name):
        raise HTTPException(status_code=400, detail="Invalid environment name. Use only alphanumeric, underscore or dash characters.")

    # NOTE(review): check-then-insert is racy under concurrent requests; a
    # unique constraint on Environment.name would be the real guarantee.
    existing = db.execute(select(Environment).where(Environment.name == name)).first()
    if existing is not None:
        raise HTTPException(status_code=409, detail="Environment already exists")

    new_id = uuid.uuid4()
    db.execute(insert(Environment).values(id=new_id, name=name))

    return JSONResponse(status_code=201, content={
        'id': str(new_id),
        'name': name,
    })


@router.delete("/{name}")
def environment_delete(req: Request, name: str):
    """
    Delete an environment by name.

    Deletes the specified environment from the database. Returns 409 if hosts are still assigned to this environment.
    Returns 404 if no such environment exists.

    Args:
        req (Request): The incoming request object.
        name (str): The name of the environment to delete.

    Returns:
        Response: An empty 204 response on successful deletion, or a JSON
        response with the appropriate error code for conflicts.
    """
    db: Session = req.state.db
    env = _get_environment_or_404(db, name)

    # Refuse to delete while hosts are still assigned.
    host_ids = db.scalars(
        select(EnvironmentAssignment.host_id).where(EnvironmentAssignment.environment_id == env.id)
    ).all()
    if host_ids:
        host_names = db.scalars(select(Host.name).where(Host.id.in_(host_ids))).all()
        return JSONResponse(status_code=409, content={
            'message': "Cannot delete an environment that still has hosts assigned",
            'assigned_hosts': list(host_names),
        })

    db.execute(delete(Environment).where(Environment.id == env.id))
    # A 204 response must not carry a body (JSONResponse would send "{}").
    return Response(status_code=204)


@router.patch("/{name}")
def environment_patch(req: Request, name: str) -> JSONResponse:
    """
    Import (or re-sync) an environment from the git branch of the same name.

    Checks out the remote branch *name* of the state repository, creates the
    environment row if it does not exist yet, registers every ``.sls`` state
    found in the checkout and makes the environment's state assignments match
    exactly the states present on the branch.

    Args:
        req (Request): The incoming request object.
        name (str): The environment / branch name.

    Returns:
        JSONResponse: 200 with the environment name and the sorted list of
        state names now assigned to it.

    Raises:
        HTTPException: 400 for an invalid environment name; 404 if the branch
            cannot be checked out or the import fails for any other reason.
    """
    db: Session = req.state.db
    cfg: Config = config()

    # Same validation as the other endpoints. The name is also handed to git
    # as a branch name, and an unvalidated row could not be read or deleted
    # through GET/DELETE afterwards.
    if not validate_environment_name(name):
        raise HTTPException(status_code=400, detail="Invalid environment name format")

    try:
        checkout_remote_branch(cfg, name)
        env = _get_or_create_environment(db, name)

        repo_path = cfg.git.state_repo_path
        logger.info("Reading states that are available in '%s' (%s)", name, repo_path)
        state_names = [
            _state_name_from_path(path, repo_path)
            for path in recursive_list_dir(repo_path)
            if path.endswith(".sls")
        ]

        state_ids = _ensure_state_ids(db, state_names)
        _sync_state_assignments(db, env, set(state_ids.values()))
    except HTTPException:
        # Already carries the intended status and detail; do not re-wrap it.
        raise
    except Exception as exc:
        logger.exception("Failed to import environment '%s'", name)
        raise HTTPException(status_code=404, detail=str(exc)) from exc

    return JSONResponse(status_code=200, content={
        'environment': env.name,
        'states': sorted(state_ids),
    })