From 7b613b9b9d165dd8db5016184748eb468daa1b6f Mon Sep 17 00:00:00 2001 From: Linus Vogel Date: Sat, 25 Apr 2026 21:31:20 +0200 Subject: [PATCH] importing environments now works --- pillar_tool/routers/environment.py | 70 ++++++++++++++++++++++++++---- 1 file changed, 61 insertions(+), 9 deletions(-) diff --git a/pillar_tool/routers/environment.py b/pillar_tool/routers/environment.py index 94b8fce..b0ca887 100644 --- a/pillar_tool/routers/environment.py +++ b/pillar_tool/routers/environment.py @@ -3,15 +3,17 @@ import uuid import pygit2 from pygit2.enums import BranchType -from sqlalchemy import select, insert, delete +from sqlalchemy import select, delete, bindparam from sqlalchemy.orm import Session +from sqlalchemy.dialects.postgresql import insert +from sqlalchemy.sql.operators import in_op, and_ from starlette.exceptions import HTTPException from starlette.requests import Request from fastapi import APIRouter from starlette.responses import JSONResponse -from pillar_tool.db import Host -from pillar_tool.db.models.top_data import Environment, EnvironmentAssignment +from pillar_tool.db import Host, StateAssignment +from pillar_tool.db.models.top_data import Environment, EnvironmentAssignment, State from pillar_tool.util.files import recursive_list_dir from pillar_tool.util.validation import validate_environment_name from pillar_tool.util import config, Config @@ -191,7 +193,18 @@ def environment_patch(req: Request, name: str) -> JSONResponse: try: checkout_remote_branch(cfg, name) - # TODO: read the file tree and find all sls files and init.sls files to enumerate the available states + # create the environment if it did not exist already + select_env_res = db.execute(select(Environment).where(Environment.name == name)).fetchall() + print(select_env_res) + if len(select_env_res) == 0: + insert_env_res = db.execute(insert(Environment).values(id=uuid.uuid4(), name=name).returning(Environment)).fetchall() + print(insert_env_res) + if len(insert_env_res) == 0: + raise HTTPException(status_code=404, detail=f"Failed to create non-existent environment '{name}'") + else: + env: Environment = insert_env_res[0][0] + else: + env: Environment = select_env_res[0][0] # Branch has been checked out print(f"Reading states that are available in '{name}':") @@ -199,13 +212,52 @@ def environment_patch(req: Request, name: str) -> JSONResponse: all_files = recursive_list_dir(cfg.git.state_repo_path) sls_files = filter(lambda f: f.endswith(".sls"), all_files) state_file_paths = map(lambda x: x.replace("/init.sls", "").replace(".sls", ""), sls_files) - state_names = map(lambda x: x.replace(f"{cfg.git.state_repo_path}/", "").replace("/", "."), state_file_paths) + state_names = list(map(lambda x: x.replace(f"{cfg.git.state_repo_path}/", "").replace("/", "."), state_file_paths)) - for state_name in state_names: - print(f"Checking state '{state_name}'") + # get all the existing states and the to be created ones + select_res = db.execute(select(State).where(State.name.in_(state_names))).fetchall() + states_known = {} + for row in select_res: + state: State = row[0] + states_known[state.name] = state - # TODO: the environment should be a field of a state in the database - # TODO: ensure that each named state exists in the current environment + states_new = { + state_name: State(name=state_name, id=uuid.uuid4()) for state_name in state_names if state_name not in states_known + } + + states = {} + states.update(states_known) + states.update(states_new) + + # insert all states that don't exist already + insert_stmt = insert(State) + insert_assignment_stmt = insert(StateAssignment) + for state in states_new.values(): + db.execute(insert_stmt.values(id=state.id, name=state.name)) + + + # ensure that all the state assignments exist + state_ids = list(map(lambda x: x.id, states.values())) + state_assignments_known = [ + x[0].id + for x in db.execute(select(StateAssignment).where(StateAssignment.environment_id == env.id)).fetchall() + ] + state_assignments_new = [ + state_id + for state_id in map(lambda x: x.id, states.values()) if state_id not in state_assignments_known + ] + state_assignments_to_delete = [ + sid + for sid in filter(lambda x: x not in state_ids, state_assignments_known) + ] + + delete_stmt = delete(StateAssignment) + for sid in state_assignments_to_delete: + db.execute(delete_stmt.where(and_(StateAssignment.state_id == sid, StateAssignment.environment_id == env.id))) + + insert_stmt = insert(StateAssignment) + for sid in state_assignments_new: + db.execute(insert_stmt.values(state_id=sid, environment_id=env.id)) except Exception as exc: