#!/usr/bin/env python3
"""Create or refresh a default local TransformerLab user.

The script adds ``<transformerlab-dir>/src`` to ``sys.path``, optionally loads
``<transformerlab-dir>/.env``, and then creates (or updates) the user that has
the given email address, makes sure a ``UserTeam`` row exists for that user,
and attempts to initialize the team's local provider.

Exit status is 0 on success and 1 if any step raised an exception.
"""
from __future__ import annotations

import argparse
import asyncio
import sys
from pathlib import Path


def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments from ``sys.argv``.

    Returns:
        A namespace with ``transformerlab_dir``, ``email``, ``password``,
        ``first_name`` and ``last_name``. The two name fields default to the
        empty string when not supplied.
    """
    parser = argparse.ArgumentParser(
        description="Create or refresh a default local TransformerLab user."
    )
    parser.add_argument(
        "--transformerlab-dir",
        required=True,
        help="Path to the TransformerLab home directory",
    )
    parser.add_argument("--email", required=True, help="Email address for the default user")
    # NOTE(review): a password passed on the command line is visible to other
    # local users via the process list; kept as-is for interface compatibility.
    parser.add_argument("--password", required=True, help="Password for the default user")
    parser.add_argument("--first-name", default="", help="Optional first name")
    parser.add_argument("--last-name", default="", help="Optional last name")
    return parser.parse_args()


def load_environment(transformerlab_dir: Path) -> None:
    """Load ``<transformerlab_dir>/.env`` into the environment, if possible.

    Does nothing when the file does not exist or when ``python-dotenv`` is not
    importable; both cases are treated as "no extra environment to load".

    Args:
        transformerlab_dir: Directory expected to contain the ``.env`` file.
    """
    env_file = transformerlab_dir / ".env"
    if not env_file.exists():
        return

    try:
        from dotenv import load_dotenv
    except ImportError:
        # Best effort: dotenv support is optional for this script.
        return

    load_dotenv(env_file)


async def ensure_user(
    transformerlab_dir: Path,
    email: str,
    password: str,
    first_name: str,
    last_name: str,
) -> None:
    """Create or update the user identified by ``email`` and ensure it has a team.

    Args:
        transformerlab_dir: TransformerLab home directory; must contain ``src``.
        email: Email address used to look the user up (and to create it).
        password: Password set on the created or updated user.
        first_name: First name; when empty, a new user gets ``None`` and an
            existing user keeps its current value.
        last_name: Last name; same empty-string handling as ``first_name``.

    Raises:
        FileNotFoundError: If ``<transformerlab_dir>/src`` does not exist.
        Exception: Anything raised by the TransformerLab / database calls,
            except a failure of local-provider initialization, which is only
            reported as a warning.
    """
    src_dir = transformerlab_dir / "src"
    if not src_dir.exists():
        raise FileNotFoundError(f"TransformerLab source directory not found: {src_dir}")

    # The TransformerLab packages are imported from the checkout itself, so the
    # path (and any .env settings) must be in place before the imports below.
    sys.path.insert(0, str(src_dir))
    load_environment(transformerlab_dir)

    from fastapi_users.db import SQLAlchemyUserDatabase
    from sqlalchemy import select
    from transformerlab.models.users import UserCreate, UserManager, UserUpdate
    from transformerlab.services.provider_service import initialize_team_local_provider
    from transformerlab.shared.models.models import TeamRole, User, UserTeam
    from transformerlab.shared.models.user_model import AsyncSessionLocal, create_personal_team

    async with AsyncSessionLocal() as session:
        user_db = SQLAlchemyUserDatabase(session, User)
        user_manager = UserManager(user_db)

        stmt = select(User).where(User.email == email)
        result = await session.execute(stmt)
        existing_user = result.unique().scalar_one_or_none()
        created = existing_user is None

        # NOTE(review): safe=False presumably allows the privileged flags
        # (is_superuser / is_verified) to be set here - confirm against the
        # fastapi-users UserManager documentation.
        if created:
            await user_manager.create(
                UserCreate(
                    email=email,
                    password=password,
                    is_active=True,
                    is_superuser=False,
                    is_verified=True,
                    first_name=first_name or None,
                    last_name=last_name or None,
                ),
                safe=False,
                request=None,
            )
        else:
            await user_manager.update(
                UserUpdate(
                    password=password,
                    is_active=True,
                    is_superuser=False,
                    is_verified=True,
                    # Keep the stored names when none were given on the CLI.
                    first_name=first_name or existing_user.first_name,
                    last_name=last_name or existing_user.last_name,
                ),
                existing_user,
                safe=False,
                request=None,
            )

        # Re-read the row so the rest of the function works with the user as
        # it is after the create/update call.
        result = await session.execute(stmt)
        user = result.unique().scalar_one()
        user_id = str(user.id)

        team_stmt = select(UserTeam).where(UserTeam.user_id == user_id).limit(1)
        team_result = await session.execute(team_stmt)
        user_team = team_result.scalar_one_or_none()

        if user_team is None:
            personal_team = await create_personal_team(session, user)
            # Capture the id before committing so reading it does not depend
            # on whether the session expires attributes on commit.
            team_id = personal_team.id
            user_team = UserTeam(user_id=user_id, team_id=team_id, role=TeamRole.OWNER.value)
            session.add(user_team)
            await session.commit()
        else:
            team_id = user_team.team_id

        try:
            await initialize_team_local_provider(session, team_id, user_id)
        except Exception as exc:
            # Deliberately non-fatal: the user and team exist at this point.
            # Reported on stderr, like the error path in main().
            print(
                f"warning: failed to initialize local provider for {email}: {exc}",
                file=sys.stderr,
            )

    print(
        f"{'created' if created else 'updated'} default TransformerLab user {email} "
        f"(team_id={team_id})"
    )


def main() -> int:
    """Run the script and translate any failure into a process exit code.

    Returns:
        0 when ``ensure_user`` completed, 1 when it raised (the exception
        message is printed to stderr).
    """
    args = parse_args()
    transformerlab_dir = Path(args.transformerlab_dir).expanduser().resolve()

    try:
        asyncio.run(
            ensure_user(
                transformerlab_dir=transformerlab_dir,
                email=args.email,
                password=args.password,
                first_name=args.first_name,
                last_name=args.last_name,
            )
        )
    except Exception as exc:
        # Top-level boundary: report and signal failure via the exit code.
        print(f"error: {exc}", file=sys.stderr)
        return 1

    return 0


if __name__ == "__main__":
    raise SystemExit(main())