#!/usr/bin/env python3
"""Ensure a known, verified TransformerLab user exists for single-user courseware installs.

The script puts the managed TransformerLab source tree on ``sys.path``, loads
its ``.env`` settings, creates (or reconciles) the default user, makes sure
that user has a team membership with the owner role, and seeds that team's
workspace. Plugins are seeded from the source tree. Experiments, models and
model metadata are seeded from other workspaces under the same home directory.
"""

from __future__ import annotations

import argparse
import asyncio
import os
import shutil
import sys
from collections.abc import Iterable
from pathlib import Path
from typing import Any

DEFAULT_WORKSPACE_PLUGINS = ("fastchat_server",)
DEFAULT_WORKSPACE_EXPERIMENTS = ("alpha", "beta", "gamma")
DEFAULT_WORKSPACE_MODELS = ("unsloth_Llama-3.2-1B-Instruct",)
DEFAULT_MODEL_METADATA_FILES = ("_tlab_complete_provenance.json",)

# Optional shell-style prefix accepted on .env lines ("export KEY=value").
_ENV_EXPORT_PREFIX = "export "
_QUOTE_CHARS = ("'", '"')


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the default-user sync."""
    parser = argparse.ArgumentParser(
        description="Ensure a known verified TransformerLab user exists for single-user courseware installs."
    )
    parser.add_argument(
        "--transformerlab-dir", required=True, help="Path to the managed TransformerLab home directory."
    )
    parser.add_argument("--email", required=True, help="Email address for the default user.")
    parser.add_argument("--password", required=True, help="Password for the default user.")
    parser.add_argument("--first-name", default="Student", help="Optional first name for the default user.")
    parser.add_argument("--last-name", default="", help="Optional last name for the default user.")
    return parser.parse_args()


def _unquote(value: str) -> str:
    """Remove one layer of matching single or double quotes surrounding *value*."""
    if len(value) >= 2 and value[0] == value[-1] and value[0] in _QUOTE_CHARS:
        return value[1:-1]
    return value


def _parse_env_line(line: str) -> tuple[str, str] | None:
    """Parse one ``KEY=value`` line of a dotenv-style file.

    Surrounding whitespace is trimmed from both the key and the value. An
    optional leading ``export`` is accepted. A single pair of matching quotes
    around the value is removed.

    Returns:
        ``(key, value)``, or ``None`` for blank lines, ``#`` comments, lines
        without ``=``, and lines whose key is empty. ``os.environ`` rejects
        empty variable names with ``ValueError``.
    """
    stripped = line.strip()
    if not stripped or stripped.startswith("#") or "=" not in stripped:
        return None
    key, value = stripped.split("=", 1)
    key = key.strip()
    if key.startswith(_ENV_EXPORT_PREFIX):
        key = key[len(_ENV_EXPORT_PREFIX):].strip()
    if not key:
        return None
    return key, _unquote(value.strip())


def bootstrap_source(transformerlab_dir: Path) -> None:
    """Make the TransformerLab package importable and load its ``.env`` settings.

    Prepends ``<transformerlab_dir>/src`` to ``sys.path``. Then copies the
    ``KEY=value`` pairs from ``<transformerlab_dir>/.env`` (if present) into
    ``os.environ``, never overriding variables that are already set.

    Raises:
        SystemExit: if the ``src`` directory does not exist.
    """
    source_dir = transformerlab_dir / "src"
    if not source_dir.is_dir():
        raise SystemExit(f"TransformerLab source directory not found: {source_dir}")
    sys.path.insert(0, str(source_dir))

    env_file = transformerlab_dir / ".env"
    if not env_file.is_file():
        return
    for line in env_file.read_text(encoding="utf-8").splitlines():
        parsed = _parse_env_line(line)
        if parsed is None:
            continue
        key, value = parsed
        # setdefault: variables already present in the environment win over .env.
        os.environ.setdefault(key, value)


def target_workspace(transformerlab_dir: Path, team_id: str) -> Path:
    """Return the workspace directory for *team_id*: ``<dir>/orgs/<team_id>/workspace``."""
    return transformerlab_dir / "orgs" / team_id / "workspace"


def workspace_team_id(workspace: Path, transformerlab_dir: Path) -> str | None:
    """Return the team id encoded in an ``<dir>/orgs/<team_id>/workspace`` path.

    Returns ``None`` when *workspace* is not under ``<dir>/orgs`` or its second
    path component below ``orgs`` is not ``workspace``.
    """
    orgs_dir = transformerlab_dir / "orgs"
    try:
        relative = workspace.relative_to(orgs_dir)
    except ValueError:
        return None
    if len(relative.parts) >= 2 and relative.parts[1] == "workspace":
        return relative.parts[0]
    return None


def candidate_workspaces(transformerlab_dir: Path, excluded_team_id: str) -> list[Path]:
    """List existing workspaces that may supply seed content.

    The root ``<dir>/workspace`` comes first (if it is a directory), followed by
    every ``<dir>/orgs/*/workspace`` directory in sorted order. The workspace
    belonging to *excluded_team_id* is skipped.
    """
    candidates: list[Path] = []
    root_workspace = transformerlab_dir / "workspace"
    if root_workspace.is_dir():
        candidates.append(root_workspace)

    orgs_dir = transformerlab_dir / "orgs"
    if not orgs_dir.is_dir():
        return candidates

    for workspace in sorted(orgs_dir.glob("*/workspace")):
        if not workspace.is_dir():
            continue
        if workspace_team_id(workspace, transformerlab_dir) == excluded_team_id:
            continue
        candidates.append(workspace)
    return candidates


def copy_dir_if_missing(source: Path | None, target: Path, label: str) -> bool:
    """Copy directory *source* to *target* unless *target* already exists.

    Returns True only when a copy was made. A ``None`` or non-directory
    *source* is silently ignored.
    """
    if source is None or not source.is_dir() or target.exists():
        return False
    target.parent.mkdir(parents=True, exist_ok=True)
    shutil.copytree(source, target)
    print(f"Seeded {label} from {source}.")
    return True


def copy_file_if_missing(source: Path | None, target: Path, label: str) -> bool:
    """Copy file *source* to *target* (with metadata) unless *target* already exists.

    Returns True only when a copy was made. A ``None`` or non-file *source* is
    silently ignored.
    """
    if source is None or not source.is_file() or target.exists():
        return False
    target.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(source, target)
    print(f"Seeded {label} from {source}.")
    return True


def find_workspace_seed(transformerlab_dir: Path, category: str, name: str, excluded_team_id: str) -> Path | None:
    """Return the first ``<workspace>/<category>/<name>`` that exists among the candidate workspaces."""
    for workspace in candidate_workspaces(transformerlab_dir, excluded_team_id):
        candidate = workspace / category / name
        if candidate.exists():
            return candidate
    return None


def seed_workspace(transformerlab_dir: Path, team_id: str) -> None:
    """Populate *team_id*'s workspace with the default plugins, experiments and models.

    Plugins are copied from the TransformerLab source tree. Experiments, models
    and model metadata files are copied from the first candidate workspace that
    has them. Existing targets are never overwritten, so repeated runs only
    fill in what is missing.
    """
    workspace = target_workspace(transformerlab_dir, team_id)
    workspace.mkdir(parents=True, exist_ok=True)

    plugin_root = transformerlab_dir / "src" / "transformerlab" / "plugins"
    for plugin in DEFAULT_WORKSPACE_PLUGINS:
        copy_dir_if_missing(plugin_root / plugin, workspace / "plugins" / plugin, f"plugin '{plugin}'")

    for experiment in DEFAULT_WORKSPACE_EXPERIMENTS:
        source = find_workspace_seed(transformerlab_dir, "experiments", experiment, team_id)
        copy_dir_if_missing(source, workspace / "experiments" / experiment, f"experiment '{experiment}'")

    for model in DEFAULT_WORKSPACE_MODELS:
        source = find_workspace_seed(transformerlab_dir, "models", model, team_id)
        copy_dir_if_missing(source, workspace / "models" / model, f"model '{model}'")

    # Metadata is seeded whenever a seed copy exists and the target lacks one,
    # regardless of whether a model directory was copied on this run.
    for metadata_name in DEFAULT_MODEL_METADATA_FILES:
        source = find_workspace_seed(transformerlab_dir, "models", metadata_name, team_id)
        copy_file_if_missing(source, workspace / "models" / metadata_name, f"model metadata '{metadata_name}'")


def _sync_existing_user(user: Any, password_helper: Any, args: argparse.Namespace) -> bool:
    """Reconcile an existing user record with the requested credentials and profile.

    The password hash is replaced when ``args.password`` does not verify, or
    with the upgraded hash the helper returns. The account is forced active,
    verified and superuser. An empty ``args.first_name`` leaves the stored
    first name alone. ``args.last_name`` is always applied, with an empty value
    stored as ``None``.

    Returns:
        True if any attribute of *user* was modified.
    """
    changed = False
    verified, new_hash = password_helper.verify_and_update(args.password, user.hashed_password)
    if not verified:
        user.hashed_password = password_helper.hash(args.password)
        changed = True
    elif new_hash:
        user.hashed_password = new_hash
        changed = True

    for flag in ("is_active", "is_verified", "is_superuser"):
        if not getattr(user, flag):
            setattr(user, flag, True)
            changed = True

    if args.first_name and user.first_name != args.first_name:
        user.first_name = args.first_name
        changed = True

    desired_last_name = args.last_name or None
    if user.last_name != desired_last_name:
        user.last_name = desired_last_name
        changed = True

    return changed


def _pick_membership(memberships: Iterable[Any], owner_role: str) -> Any | None:
    """Choose which team membership to treat as the user's own.

    Prefers the first membership whose ``role`` equals *owner_role*, otherwise
    the first membership overall. Returns ``None`` when there are none.
    """
    candidates = list(memberships)
    for membership in candidates:
        if membership.role == owner_role:
            return membership
    return candidates[0] if candidates else None


async def ensure_user(args: argparse.Namespace) -> int:
    """Create or reconcile the default user, ensure an owner membership, and seed its workspace.

    Does nothing (besides printing a notice) when the TransformerLab database
    file does not exist yet. Always returns 0.
    """
    # Project imports are deferred: they only resolve after bootstrap_source()
    # has put the TransformerLab source tree on sys.path.
    from sqlalchemy import select

    from transformerlab.db.constants import DATABASE_FILE_NAME
    from transformerlab.models.users import UserCreate, UserManager
    from transformerlab.shared.models.models import OAuthAccount, TeamRole, User, UserTeam
    from transformerlab.shared.models.user_model import (
        AsyncSessionLocal,
        SQLAlchemyUserDatabaseWithOAuth,
        create_personal_team,
    )

    database_path = Path(DATABASE_FILE_NAME)
    if not database_path.exists():
        print(f"TransformerLab database is not ready yet at {database_path}; skipping default-user sync.")
        return 0

    async with AsyncSessionLocal() as session:
        user_db = SQLAlchemyUserDatabaseWithOAuth(session, User, OAuthAccount)
        user_manager = UserManager(user_db)

        result = await session.execute(select(User).where(User.email == args.email))
        user = result.unique().scalar_one_or_none()

        created = user is None
        if created:
            # NOTE(review): safe=False presumably lets the privileged flags below
            # be honoured by the user manager — confirm against fastapi-users.
            user = await user_manager.create(
                UserCreate(
                    email=args.email,
                    password=args.password,
                    is_active=True,
                    is_superuser=True,
                    is_verified=True,
                    first_name=args.first_name or None,
                    last_name=args.last_name or None,
                ),
                safe=False,
                request=None,
            )
            changed = True
        else:
            changed = _sync_existing_user(user, user_manager.password_helper, args)

        if changed:
            session.add(user)
            await session.commit()
            await session.refresh(user)

        # A user may belong to several teams; scalar_one_or_none() would raise
        # MultipleResultsFound. Fetch them all (deterministically ordered) and
        # prefer an existing owner membership.
        team_stmt = select(UserTeam).where(UserTeam.user_id == str(user.id)).order_by(UserTeam.team_id)
        team_result = await session.execute(team_stmt)
        user_team = _pick_membership(team_result.scalars().all(), TeamRole.OWNER.value)

        if user_team is None:
            personal_team = await create_personal_team(session, user)
            user_team = UserTeam(user_id=str(user.id), team_id=personal_team.id, role=TeamRole.OWNER.value)
            session.add(user_team)
            await session.commit()
            print(f"Created personal team '{personal_team.name}' for {args.email}.")
        elif user_team.role != TeamRole.OWNER.value:
            user_team.role = TeamRole.OWNER.value
            session.add(user_team)
            await session.commit()
            print(f"Updated team role to owner for {args.email}.")

        seed_workspace(Path(args.transformerlab_dir), str(user_team.team_id))

    action = "Created" if created else "Verified"
    print(f"{action} default TransformerLab user {args.email}.")
    return 0


def main() -> int:
    """Entry point: parse and normalize arguments, bootstrap TransformerLab, sync the user."""
    args = parse_args()
    # Leading/trailing whitespace is trimmed from every argument, including the password.
    args.email = args.email.strip()
    args.password = args.password.strip()
    args.first_name = args.first_name.strip()
    args.last_name = args.last_name.strip()
    bootstrap_source(Path(args.transformerlab_dir))
    return asyncio.run(ensure_user(args))


if __name__ == "__main__":
    raise SystemExit(main())