Files
LLM-Labs-Local/scripts/ensure_transformerlab_user.py
2026-03-31 18:35:14 -06:00

149 lines
5.4 KiB
Python

#!/usr/bin/env python3
from __future__ import annotations
import argparse
import asyncio
import os
import sys
from pathlib import Path
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the default-user sync.

    Returns:
        The namespace produced by ``argparse`` with the attributes
        ``transformerlab_dir``, ``email``, ``password``, ``first_name``
        and ``last_name``.  The first three are required; ``first_name``
        defaults to ``"Student"`` and ``last_name`` to ``""``.
    """
    # (flag, add_argument keyword options), registered in this order so the
    # generated --help output lists them in the same sequence.
    option_specs = (
        (
            "--transformerlab-dir",
            {"required": True, "help": "Path to the managed TransformerLab home directory."},
        ),
        ("--email", {"required": True, "help": "Email address for the default user."}),
        ("--password", {"required": True, "help": "Password for the default user."}),
        (
            "--first-name",
            {"default": "Student", "help": "Optional first name for the default user."},
        ),
        ("--last-name", {"default": "", "help": "Optional last name for the default user."}),
    )

    cli = argparse.ArgumentParser(
        description="Ensure a known verified TransformerLab user exists for single-user courseware installs."
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def bootstrap_source(transformerlab_dir: Path) -> None:
    """Make the TransformerLab sources importable and load ``.env`` defaults.

    ``<transformerlab_dir>/src`` is pushed to the front of ``sys.path``.  If
    ``<transformerlab_dir>/.env`` exists, each ``KEY=VALUE`` line is copied
    into ``os.environ`` unless the variable is already set there.

    Args:
        transformerlab_dir: TransformerLab home directory holding ``src`` and,
            optionally, a ``.env`` file.

    Raises:
        SystemExit: If ``<transformerlab_dir>/src`` is not a directory.
    """
    source_dir = transformerlab_dir / "src"
    if not source_dir.is_dir():
        raise SystemExit(f"TransformerLab source directory not found: {source_dir}")
    sys.path.insert(0, str(source_dir))

    env_file = transformerlab_dir / ".env"
    if not env_file.exists():
        return

    # Decode explicitly instead of relying on the locale encoding; "utf-8-sig"
    # also drops a leading BOM so it cannot end up glued to the first key.
    for raw_line in env_file.read_text(encoding="utf-8-sig").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue

        key, _, value = line.partition("=")
        key = key.strip()
        # Accept shell-style "export KEY=value" lines.
        if key.startswith("export "):
            key = key[len("export "):].strip()
        # A line such as "=value" has no variable name; os.environ rejects an
        # empty key, so skip it rather than abort the whole sync.
        if not key:
            continue

        value = value.strip()
        # Remove only one matching pair of surrounding quotes, so quote
        # characters that are part of the value itself are preserved.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ("'", '"'):
            value = value[1:-1]

        # setdefault: values already present in the environment win.
        os.environ.setdefault(key, value)
async def ensure_user(args: argparse.Namespace) -> int:
    """Create or repair the default TransformerLab user and its team link.

    Looks the user up by ``args.email``.  If absent, it is created as an
    active, verified superuser.  If present, its password hash, the
    ``is_active`` / ``is_verified`` / ``is_superuser`` flags and its names
    are brought in line with ``args``.  Afterwards the user is given a
    ``UserTeam`` row with the owner role if it has none, or an existing row
    is promoted to owner.

    Args:
        args: Namespace carrying ``email``, ``password``, ``first_name`` and
            ``last_name``.

    Returns:
        ``0`` in every non-raising path, including when the database file
        does not exist yet and the sync is skipped.
    """
    # The sqlalchemy/transformerlab imports are local to this function; in
    # this script main() calls bootstrap_source() before ensure_user().
    from sqlalchemy import select
    from transformerlab.db.constants import DATABASE_FILE_NAME
    from transformerlab.shared.models.models import OAuthAccount, TeamRole, User, UserTeam
    from transformerlab.shared.models.user_model import (
        AsyncSessionLocal,
        SQLAlchemyUserDatabaseWithOAuth,
        create_personal_team,
    )
    from transformerlab.models.users import UserCreate, UserManager

    db_file = Path(DATABASE_FILE_NAME)
    if not db_file.exists():
        print(f"TransformerLab database is not ready yet at {db_file}; skipping default-user sync.")
        return 0

    async with AsyncSessionLocal() as session:
        manager = UserManager(SQLAlchemyUserDatabaseWithOAuth(session, User, OAuthAccount))

        lookup = await session.execute(select(User).where(User.email == args.email))
        user = lookup.unique().scalar_one_or_none()

        created = user is None
        if created:
            new_user = UserCreate(
                email=args.email,
                password=args.password,
                is_active=True,
                is_superuser=True,
                is_verified=True,
                first_name=args.first_name or None,
                last_name=args.last_name or None,
            )
            user = await manager.create(new_user, safe=False, request=None)
            dirty = True
        else:
            dirty = False

            # Reset the hash when the password no longer matches; otherwise
            # store the upgraded hash if the helper handed one back.
            hasher = manager.password_helper
            verified, new_hash = hasher.verify_and_update(args.password, user.hashed_password)
            if not verified:
                user.hashed_password = hasher.hash(args.password)
                dirty = True
            elif new_hash:
                user.hashed_password = new_hash
                dirty = True

            # Force every account flag on.
            for flag in ("is_active", "is_verified", "is_superuser"):
                if not getattr(user, flag):
                    setattr(user, flag, True)
                    dirty = True

            # An empty first name leaves the stored one alone, whereas an
            # empty last name is written back as None.
            if args.first_name and user.first_name != args.first_name:
                user.first_name = args.first_name
                dirty = True
            wanted_last_name = args.last_name or None
            if user.last_name != wanted_last_name:
                user.last_name = wanted_last_name
                dirty = True

        if dirty:
            session.add(user)
            await session.commit()
            await session.refresh(user)

        # NOTE(review): scalar_one_or_none() raises when more than one row
        # matches - confirm a user can only ever have one UserTeam row.
        membership_rows = await session.execute(
            select(UserTeam).where(UserTeam.user_id == str(user.id))
        )
        membership = membership_rows.scalar_one_or_none()

        if membership is None:
            personal_team = await create_personal_team(session, user)
            membership = UserTeam(
                user_id=str(user.id),
                team_id=personal_team.id,
                role=TeamRole.OWNER.value,
            )
            session.add(membership)
            await session.commit()
            print(f"Created personal team '{personal_team.name}' for {args.email}.")
        elif membership.role != TeamRole.OWNER.value:
            membership.role = TeamRole.OWNER.value
            session.add(membership)
            await session.commit()
            print(f"Updated team role to owner for {args.email}.")

    action = "Created" if created else "Verified"
    print(f"{action} default TransformerLab user {args.email}.")
    return 0
def main() -> int:
    """Run the default-user sync from the command line.

    Parses the CLI options, strips surrounding whitespace from the email,
    password and both name fields, prepares the TransformerLab import path
    and environment, then runs ``ensure_user`` on a fresh event loop.

    Returns:
        The integer returned by ``ensure_user``, intended as the process
        exit status.
    """
    args = parse_args()

    # Normalise the text options in place before they are used.
    for field in ("email", "password", "first_name", "last_name"):
        setattr(args, field, getattr(args, field).strip())

    transformerlab_home = Path(args.transformerlab_dir)
    bootstrap_source(transformerlab_home)
    return asyncio.run(ensure_user(args))
# Script entry point: exit the process with the status returned by main().
if __name__ == "__main__":
    sys.exit(main())