243 lines
9.0 KiB
Python
243 lines
9.0 KiB
Python
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import os
|
|
import shutil
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Plugin directories copied from the bundled source tree (src/transformerlab/plugins) into the workspace.
DEFAULT_WORKSPACE_PLUGINS = (
    "fastchat_server",
)

# Experiment directories copied, when found, from another workspace on this install.
DEFAULT_WORKSPACE_EXPERIMENTS = (
    "alpha",
    "beta",
    "gamma",
)

# Model directories copied, when found, from another workspace on this install.
DEFAULT_WORKSPACE_MODELS = (
    "unsloth_Llama-3.2-1B-Instruct",
)

# Files located directly inside a workspace's "models" directory that are copied when found elsewhere.
DEFAULT_MODEL_METADATA_FILES = (
    "_tlab_complete_provenance.json",
)
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command line for this bootstrap script.

    ``--transformerlab-dir``, ``--email`` and ``--password`` are mandatory.
    ``--first-name`` defaults to ``"Student"`` and ``--last-name`` to the empty string.
    """
    cli = argparse.ArgumentParser(
        description="Ensure a known verified TransformerLab user exists for single-user courseware installs."
    )

    mandatory_options = (
        ("--transformerlab-dir", "Path to the managed TransformerLab home directory."),
        ("--email", "Email address for the default user."),
        ("--password", "Password for the default user."),
    )
    for flag, help_text in mandatory_options:
        cli.add_argument(flag, required=True, help=help_text)

    cli.add_argument("--first-name", default="Student", help="Optional first name for the default user.")
    cli.add_argument("--last-name", default="", help="Optional last name for the default user.")

    return cli.parse_args()
|
|
|
|
|
|
def _parse_env_line(line: str) -> tuple[str, str] | None:
    """Parse one ``.env`` line into ``(key, value)``, or return None for lines to ignore.

    Ignored: blank lines, ``#`` comments, lines without ``=``, and lines whose key is empty.
    A leading ``export`` keyword is dropped from the key. Surrounding whitespace is trimmed from
    key and value, and one *matching* pair of surrounding quotes is removed from the value.
    """
    stripped = line.strip()
    if not stripped or stripped.startswith("#") or "=" not in stripped:
        return None

    key, value = stripped.split("=", 1)
    key = key.strip()

    # Shell-style "export KEY=value" lines are common in .env files; the variable name is KEY.
    key_parts = key.split(None, 1)
    if len(key_parts) == 2 and key_parts[0] == "export":
        key = key_parts[1].strip()
    if not key:
        # os.environ rejects an empty variable name, so skip "=value" lines instead of crashing.
        return None

    value = value.strip()
    # Only strip a balanced pair of quotes so values such as abc' keep their literal quote.
    if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
        value = value[1:-1]
    return key, value


def bootstrap_source(transformerlab_dir: Path) -> None:
    """Make the TransformerLab sources importable and load defaults from its ``.env`` file.

    ``<transformerlab_dir>/src`` is put at the front of ``sys.path``. Each ``KEY=value`` entry of
    ``<transformerlab_dir>/.env`` (if present, read as UTF-8) is exported with
    ``os.environ.setdefault``, so variables already set in the environment win over the file.

    Raises:
        SystemExit: if ``<transformerlab_dir>/src`` is not a directory.
    """
    source_dir = transformerlab_dir / "src"
    if not source_dir.is_dir():
        raise SystemExit(f"TransformerLab source directory not found: {source_dir}")

    sys.path.insert(0, str(source_dir))

    env_file = transformerlab_dir / ".env"
    if not env_file.exists():
        return

    for raw_line in env_file.read_text(encoding="utf-8").splitlines():
        parsed = _parse_env_line(raw_line)
        if parsed is None:
            continue
        key, value = parsed
        os.environ.setdefault(key, value)
|
|
|
|
|
|
def target_workspace(transformerlab_dir: Path, team_id: str) -> Path:
    """Return the workspace path for *team_id*: ``<transformerlab_dir>/orgs/<team_id>/workspace``."""
    return transformerlab_dir.joinpath("orgs", team_id, "workspace")
|
|
|
|
|
|
def workspace_team_id(workspace: Path, transformerlab_dir: Path) -> str | None:
    """Extract the team id from a path of the form ``<transformerlab_dir>/orgs/<team>/workspace[/...]``.

    Returns ``None`` when *workspace* is not inside ``<transformerlab_dir>/orgs`` or when the
    component after the team id is not ``workspace``.
    """
    try:
        parts = workspace.relative_to(transformerlab_dir / "orgs").parts
    except ValueError:
        # Not located under the orgs directory at all.
        return None

    if len(parts) < 2 or parts[1] != "workspace":
        return None
    return parts[0]
|
|
|
|
|
|
def candidate_workspaces(transformerlab_dir: Path, excluded_team_id: str) -> list[Path]:
    """List existing workspace directories that may hold seed content, in search order.

    ``<transformerlab_dir>/workspace`` comes first when it is a directory. It is followed by every
    ``<transformerlab_dir>/orgs/<team>/workspace`` directory in sorted path order, except the one
    belonging to *excluded_team_id*.
    """
    root_workspace = transformerlab_dir / "workspace"
    found: list[Path] = [root_workspace] if root_workspace.is_dir() else []

    orgs_dir = transformerlab_dir / "orgs"
    if orgs_dir.is_dir():
        # Each glob hit has the shape orgs/<team>/workspace, so its team id is the parent dir's name.
        found.extend(
            ws
            for ws in sorted(orgs_dir.glob("*/workspace"))
            if ws.is_dir() and ws.parent.name != excluded_team_id
        )
    return found
|
|
|
|
|
|
def copy_dir_if_missing(source: Path | None, target: Path, label: str) -> bool:
    """Copy the directory *source* to *target* unless it is absent or *target* already exists.

    The tree is first copied into a hidden staging directory next to *target* and then renamed
    into place. A failed or interrupted copy therefore never leaves a half-populated *target*,
    which later runs would otherwise treat as "already seeded" and skip forever.

    Args:
        source: Directory to copy; ``None`` or a non-directory means "nothing to seed".
        target: Destination path; its parent directories are created as needed.
        label: Human-readable name used in the progress message.

    Returns:
        True if the directory was copied, False if the copy was skipped.

    Raises:
        OSError (including ``shutil.Error``): if copying or renaming fails; the staging copy is removed first.
    """
    if source is None or not source.is_dir() or target.exists():
        return False

    target.parent.mkdir(parents=True, exist_ok=True)
    staging = target.with_name(f".{target.name}.seeding")
    if staging.is_dir():
        # Leftover from an earlier run that stopped mid-copy; start from a clean slate.
        shutil.rmtree(staging)

    try:
        shutil.copytree(source, staging)
        staging.rename(target)
    except BaseException:
        shutil.rmtree(staging, ignore_errors=True)
        raise

    print(f"Seeded {label} from {source}.")
    return True
|
|
|
|
|
|
def copy_file_if_missing(source: Path | None, target: Path, label: str) -> bool:
    """Copy the file *source* to *target* unless it is absent or *target* already exists.

    The file is copied (with metadata, via ``shutil.copy2``) to a hidden staging file next to
    *target*. It is then moved into place with ``os.replace``, so an interrupted copy never leaves
    a truncated *target* that later runs would skip.

    Args:
        source: File to copy; ``None`` or a non-regular-file means "nothing to seed".
        target: Destination path; its parent directories are created as needed.
        label: Human-readable name used in the progress message.

    Returns:
        True if the file was copied, False if the copy was skipped.

    Raises:
        OSError: if copying or moving fails; the staging file is removed first.
    """
    if source is None or not source.is_file() or target.exists():
        return False

    target.parent.mkdir(parents=True, exist_ok=True)
    staging = target.with_name(f".{target.name}.seeding")
    try:
        shutil.copy2(source, staging)
        os.replace(staging, target)
    except BaseException:
        staging.unlink(missing_ok=True)
        raise

    print(f"Seeded {label} from {source}.")
    return True
|
|
|
|
|
|
def find_workspace_seed(transformerlab_dir: Path, category: str, name: str, excluded_team_id: str) -> Path | None:
    """Return the first existing ``<workspace>/<category>/<name>`` across candidate workspaces.

    Workspaces are searched in the order produced by ``candidate_workspaces``, skipping the one
    owned by *excluded_team_id*. Any existing path counts, whether a file or a directory.
    Returns ``None`` when no workspace has it.
    """
    possible_paths = (
        ws / category / name
        for ws in candidate_workspaces(transformerlab_dir, excluded_team_id)
    )
    return next((path for path in possible_paths if path.exists()), None)
|
|
|
|
|
|
def seed_workspace(transformerlab_dir: Path, team_id: str) -> None:
    """Populate *team_id*'s workspace with default plugins, experiments, models and model metadata.

    The workspace directory is created if needed. Existing entries are never overwritten, because
    the copy helpers skip any target that already exists.

    - Plugins come from the bundled source tree ``<transformerlab_dir>/src/transformerlab/plugins``.
    - Experiments, models and model-metadata files come from the first other workspace on this
      install that has them (see ``find_workspace_seed``). Entries found nowhere are skipped.
    """
    workspace = target_workspace(transformerlab_dir, team_id)
    workspace.mkdir(parents=True, exist_ok=True)

    bundled_plugins = transformerlab_dir / "src" / "transformerlab" / "plugins"
    for plugin in DEFAULT_WORKSPACE_PLUGINS:
        copy_dir_if_missing(bundled_plugins / plugin, workspace / "plugins" / plugin, f"plugin '{plugin}'")

    for experiment in DEFAULT_WORKSPACE_EXPERIMENTS:
        source = find_workspace_seed(transformerlab_dir, "experiments", experiment, team_id)
        copy_dir_if_missing(source, workspace / "experiments" / experiment, f"experiment '{experiment}'")

    for model in DEFAULT_WORKSPACE_MODELS:
        source = find_workspace_seed(transformerlab_dir, "models", model, team_id)
        copy_dir_if_missing(source, workspace / "models" / model, f"model '{model}'")

    # Metadata files are seeded whenever a source file is found, whether or not a model was copied
    # in this run; copy_file_if_missing already ignores a missing (None) source.
    for metadata_name in DEFAULT_MODEL_METADATA_FILES:
        source = find_workspace_seed(transformerlab_dir, "models", metadata_name, team_id)
        copy_file_if_missing(source, workspace / "models" / metadata_name, f"model metadata '{metadata_name}'")
|
|
|
|
|
|
def _sync_existing_user(user, password_helper, args: argparse.Namespace) -> bool:
    """Bring an existing user row in line with *args*; return True if any attribute changed.

    - Re-hashes ``args.password`` when it no longer verifies against the stored hash. If it
      verifies but ``verify_and_update`` returns a replacement hash (presumably a re-hash under an
      updated scheme), that hash is stored instead.
    - Forces ``is_active``, ``is_verified`` and ``is_superuser`` to True.
    - Sets ``first_name`` only when ``args.first_name`` is non-empty.
    - Sets ``last_name`` to ``args.last_name``; an empty value clears it to ``None``.
    """
    changed = False

    verified, new_hash = password_helper.verify_and_update(args.password, user.hashed_password)
    if not verified:
        user.hashed_password = password_helper.hash(args.password)
        changed = True
    elif new_hash:
        user.hashed_password = new_hash
        changed = True

    for flag in ("is_active", "is_verified", "is_superuser"):
        if not getattr(user, flag):
            setattr(user, flag, True)
            changed = True

    if args.first_name and user.first_name != args.first_name:
        user.first_name = args.first_name
        changed = True

    desired_last_name = args.last_name or None
    if user.last_name != desired_last_name:
        user.last_name = desired_last_name
        changed = True

    return changed


def _pick_user_team(memberships, owner_role):
    """Return the first membership with *owner_role*, else the first membership, else None."""
    for membership in memberships:
        if membership.role == owner_role:
            return membership
    return memberships[0] if memberships else None


async def ensure_user(args: argparse.Namespace) -> int:
    """Create or reconcile the default user, ensure it has an owned team, and seed that team's workspace.

    If the TransformerLab database file does not exist yet, prints a notice and returns 0 without
    touching anything.

    Otherwise:
    - With no account matching ``args.email`` (case-insensitive), creates one that is active,
      verified and a superuser. An existing account is updated via ``_sync_existing_user``.
    - If the user has no team membership, creates a personal team and an owner membership.
      Otherwise it makes sure the selected membership has the owner role.
    - Seeds the workspace of that team.

    Returns:
        0 on success.
    """
    from sqlalchemy import func, select

    from transformerlab.db.constants import DATABASE_FILE_NAME
    from transformerlab.shared.models.models import OAuthAccount, TeamRole, User, UserTeam
    from transformerlab.shared.models.user_model import (
        AsyncSessionLocal,
        SQLAlchemyUserDatabaseWithOAuth,
        create_personal_team,
    )
    from transformerlab.models.users import UserCreate, UserManager

    database_path = Path(DATABASE_FILE_NAME)
    if not database_path.exists():
        print(f"TransformerLab database is not ready yet at {database_path}; skipping default-user sync.")
        return 0

    async with AsyncSessionLocal() as session:
        user_db = SQLAlchemyUserDatabaseWithOAuth(session, User, OAuthAccount)
        user_manager = UserManager(user_db)

        # Match emails case-insensitively. fastapi-users' user database looks emails up with lower()
        # (assumed to hold for SQLAlchemyUserDatabaseWithOAuth too — verify). An exact-case lookup
        # here could miss an existing account, and user_manager.create() would then reject it.
        stmt = select(User).where(func.lower(User.email) == func.lower(args.email))
        result = await session.execute(stmt)
        user = result.unique().scalar_one_or_none()

        created = user is None
        if created:
            user = await user_manager.create(
                UserCreate(
                    email=args.email,
                    password=args.password,
                    is_active=True,
                    is_superuser=True,
                    is_verified=True,
                    first_name=args.first_name or None,
                    last_name=args.last_name or None,
                ),
                safe=False,
                request=None,
            )
            changed = True
        else:
            changed = _sync_existing_user(user, user_manager.password_helper, args)

        if changed:
            session.add(user)
            await session.commit()
            await session.refresh(user)

        # A user can belong to several teams, in which case scalar_one_or_none() would raise
        # MultipleResultsFound. Fetch all memberships in a stable order and prefer an owned team.
        team_stmt = select(UserTeam).where(UserTeam.user_id == str(user.id)).order_by(UserTeam.team_id)
        team_result = await session.execute(team_stmt)
        user_team = _pick_user_team(team_result.scalars().all(), TeamRole.OWNER.value)

        if user_team is None:
            personal_team = await create_personal_team(session, user)
            user_team = UserTeam(user_id=str(user.id), team_id=personal_team.id, role=TeamRole.OWNER.value)
            session.add(user_team)
            await session.commit()
            print(f"Created personal team '{personal_team.name}' for {args.email}.")
        elif user_team.role != TeamRole.OWNER.value:
            user_team.role = TeamRole.OWNER.value
            session.add(user_team)
            await session.commit()
            print(f"Updated team role to owner for {args.email}.")

        seed_workspace(Path(args.transformerlab_dir), str(user_team.team_id))

    action = "Created" if created else "Verified"
    print(f"{action} default TransformerLab user {args.email}.")
    return 0
|
|
|
|
|
|
def main() -> int:
    """Script entry point: parse and normalise arguments, prepare imports, then sync the default user.

    Returns:
        The exit status produced by ``ensure_user``.

    Raises:
        SystemExit: if ``--email`` or ``--password`` is blank after trimming whitespace, or if
            ``bootstrap_source`` cannot find the TransformerLab sources.
    """
    args = parse_args()

    # Trim surrounding whitespace from every text option, including the password, before any of
    # them is compared against or written to the database.
    for field in ("email", "password", "first_name", "last_name"):
        setattr(args, field, getattr(args, field).strip())

    # Refuse blank credentials rather than creating (or resetting) a superuser with an empty password.
    if not args.email:
        raise SystemExit("--email must not be empty.")
    if not args.password:
        raise SystemExit("--password must not be empty.")

    bootstrap_source(Path(args.transformerlab_dir))
    return asyncio.run(ensure_user(args))
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
|