#!/usr/bin/env python3
"""Ensure a known, verified TransformerLab user exists for single-user courseware installs.

The script puts ``<transformerlab-dir>/src`` on ``sys.path`` and loads
``<transformerlab-dir>/.env`` into the environment without overriding
variables that are already set. It then creates or repairs the default account
so that it is active, verified, a superuser, accepts the supplied password and
owns a team. An account that already matches is not rewritten.
"""

from __future__ import annotations

import argparse
import asyncio
import os
import sys
from pathlib import Path


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the default-user sync."""
    parser = argparse.ArgumentParser(
        description="Ensure a known verified TransformerLab user exists for single-user courseware installs."
    )
    parser.add_argument("--transformerlab-dir", required=True, help="Path to the managed TransformerLab home directory.")
    parser.add_argument("--email", required=True, help="Email address for the default user.")
    parser.add_argument("--password", required=True, help="Password for the default user.")
    parser.add_argument("--first-name", default="Student", help="Optional first name for the default user.")
    parser.add_argument("--last-name", default="", help="Optional last name for the default user.")
    return parser.parse_args()


def _unquote_env_value(value: str) -> str:
    """Remove one pair of matching surrounding quotes (``"..."`` or ``'...'``) from *value*."""
    if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
        return value[1:-1]
    return value


def bootstrap_source(transformerlab_dir: Path) -> None:
    """Make the TransformerLab package importable and load its ``.env`` file.

    Args:
        transformerlab_dir: Managed TransformerLab home; must contain ``src/``.

    Raises:
        SystemExit: If ``<transformerlab_dir>/src`` is not a directory.
    """
    source_dir = transformerlab_dir / "src"
    if not source_dir.is_dir():
        raise SystemExit(f"TransformerLab source directory not found: {source_dir}")
    sys.path.insert(0, str(source_dir))

    env_file = transformerlab_dir / ".env"
    if not env_file.exists():
        return
    for raw_line in env_file.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        # Skip blank lines, comments and anything that is not KEY=VALUE.
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", 1)
        key = key.strip()
        # Accept shell-style ``export KEY=value`` lines as well.
        if key.startswith("export "):
            key = key[len("export "):].strip()
        # An empty name (e.g. a line "=value") would make os.environ raise ValueError.
        if not key:
            continue
        # setdefault: variables already in the environment win over the file,
        # and the first occurrence of a duplicated key wins.
        os.environ.setdefault(key, _unquote_env_value(value.strip()))


def _apply_desired_state(user, args: argparse.Namespace, password_helper) -> bool:
    """Bring an existing user's stored fields in line with *args*.

    Mutates *user* in place (password hash, status flags, names) and
    returns True if any field was changed.
    """
    changed = False

    verified, new_hash = password_helper.verify_and_update(args.password, user.hashed_password)
    if not verified:
        # Stored hash does not match the requested password: reset it.
        user.hashed_password = password_helper.hash(args.password)
        changed = True
    elif new_hash:
        # Password matches, but the helper supplied a replacement hash
        # (presumably a re-hash with current parameters); store it.
        user.hashed_password = new_hash
        changed = True

    for flag in ("is_active", "is_verified", "is_superuser"):
        if not getattr(user, flag):
            setattr(user, flag, True)
            changed = True

    # An empty first name leaves the stored value untouched ...
    if args.first_name and user.first_name != args.first_name:
        user.first_name = args.first_name
        changed = True
    # ... whereas an empty last name clears the stored value to None.
    desired_last_name = args.last_name or None
    if user.last_name != desired_last_name:
        user.last_name = desired_last_name
        changed = True

    return changed


async def ensure_user(args: argparse.Namespace) -> int:
    """Create or repair the default user and make sure it owns a team.

    Args:
        args: Parsed CLI arguments (``email``, ``password``, ``first_name``,
            ``last_name``), already whitespace-trimmed by :func:`main`.

    Returns:
        Process exit code; always 0, including when the database file does
        not exist yet (the sync is skipped in that case).
    """
    # Imported lazily: the transformerlab modules are only importable after
    # bootstrap_source() has put TransformerLab's src/ directory on sys.path.
    from sqlalchemy import func, select

    from transformerlab.db.constants import DATABASE_FILE_NAME
    from transformerlab.shared.models.models import OAuthAccount, TeamRole, User, UserTeam
    from transformerlab.shared.models.user_model import (
        AsyncSessionLocal,
        SQLAlchemyUserDatabaseWithOAuth,
        create_personal_team,
    )
    from transformerlab.models.users import UserCreate, UserManager

    database_path = Path(DATABASE_FILE_NAME)
    if not database_path.exists():
        print(f"TransformerLab database is not ready yet at {database_path}; skipping default-user sync.")
        return 0

    async with AsyncSessionLocal() as session:
        user_db = SQLAlchemyUserDatabaseWithOAuth(session, User, OAuthAccount)
        user_manager = UserManager(user_db)

        # Compare case-insensitively. An exact match misses a row whose stored
        # email differs only in letter case, and the create() call below would
        # then presumably fail on the duplicate account.
        # NOTE(review): assumes emails are unique case-insensitively (as in fastapi-users).
        stmt = select(User).where(func.lower(User.email) == func.lower(args.email))
        result = await session.execute(stmt)
        # unique() presumably required because User eager-loads a collection via JOIN.
        user = result.unique().scalar_one_or_none()

        created = user is None
        if created:
            user = await user_manager.create(
                UserCreate(
                    email=args.email,
                    password=args.password,
                    is_active=True,
                    is_superuser=True,
                    is_verified=True,
                    first_name=args.first_name or None,
                    last_name=args.last_name or None,
                ),
                # safe=False so the privileged flags above are honoured
                # (fastapi-users ignores them in safe mode -- confirm for this version).
                safe=False,
                request=None,
            )
            changed = True
        else:
            changed = _apply_desired_state(user, args, user_manager.password_helper)

        if changed:
            session.add(user)
            await session.commit()
            # Reload after commit so user.id is available for the team lookup.
            await session.refresh(user)

        # A user may belong to several teams; scalar_one_or_none() would raise
        # MultipleResultsFound in that case, so inspect every membership.
        team_stmt = select(UserTeam).where(UserTeam.user_id == str(user.id))
        memberships = (await session.execute(team_stmt)).scalars().all()
        owner_role = TeamRole.OWNER.value

        if not memberships:
            personal_team = await create_personal_team(session, user)
            # Capture the name before committing so the message below does not
            # depend on reloading expired attributes after the commit.
            team_name = personal_team.name
            session.add(UserTeam(user_id=str(user.id), team_id=personal_team.id, role=owner_role))
            await session.commit()
            print(f"Created personal team '{team_name}' for {args.email}.")
        elif not any(membership.role == owner_role for membership in memberships):
            # No owned team yet: promote one existing membership to owner.
            membership = memberships[0]
            membership.role = owner_role
            session.add(membership)
            await session.commit()
            print(f"Updated team role to owner for {args.email}.")

    action = "Created" if created else "Verified"
    print(f"{action} default TransformerLab user {args.email}.")
    return 0


def main() -> int:
    """Parse arguments, prepare the TransformerLab import path, and sync the default user."""
    args = parse_args()
    # Trim surrounding whitespace from the text arguments. The password is
    # trimmed too, so leading/trailing whitespace in it is never preserved.
    for name in ("email", "password", "first_name", "last_name"):
        setattr(args, name, getattr(args, name).strip())
    bootstrap_source(Path(args.transformerlab_dir))
    return asyncio.run(ensure_user(args))


if __name__ == "__main__":
    raise SystemExit(main())