Initial snapshot before transformerlab recovery
This commit is contained in:
@@ -0,0 +1,141 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Create or refresh a default local TransformerLab user."
|
||||
)
|
||||
parser.add_argument("--transformerlab-dir", required=True, help="Path to the TransformerLab home directory")
|
||||
parser.add_argument("--email", required=True, help="Email address for the default user")
|
||||
parser.add_argument("--password", required=True, help="Password for the default user")
|
||||
parser.add_argument("--first-name", default="", help="Optional first name")
|
||||
parser.add_argument("--last-name", default="", help="Optional last name")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_environment(transformerlab_dir: Path) -> None:
|
||||
env_file = transformerlab_dir / ".env"
|
||||
if env_file.exists():
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except ImportError:
|
||||
return
|
||||
load_dotenv(env_file)
|
||||
|
||||
|
||||
async def ensure_user(
|
||||
transformerlab_dir: Path,
|
||||
email: str,
|
||||
password: str,
|
||||
first_name: str,
|
||||
last_name: str,
|
||||
) -> None:
|
||||
src_dir = transformerlab_dir / "src"
|
||||
if not src_dir.exists():
|
||||
raise FileNotFoundError(f"TransformerLab source directory not found: {src_dir}")
|
||||
|
||||
sys.path.insert(0, str(src_dir))
|
||||
load_environment(transformerlab_dir)
|
||||
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
from sqlalchemy import select
|
||||
from transformerlab.models.users import UserCreate, UserManager, UserUpdate
|
||||
from transformerlab.services.provider_service import initialize_team_local_provider
|
||||
from transformerlab.shared.models.models import TeamRole, User, UserTeam
|
||||
from transformerlab.shared.models.user_model import AsyncSessionLocal, create_personal_team
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
user_db = SQLAlchemyUserDatabase(session, User)
|
||||
user_manager = UserManager(user_db)
|
||||
|
||||
stmt = select(User).where(User.email == email)
|
||||
result = await session.execute(stmt)
|
||||
existing_user = result.unique().scalar_one_or_none()
|
||||
|
||||
created = existing_user is None
|
||||
if created:
|
||||
await user_manager.create(
|
||||
UserCreate(
|
||||
email=email,
|
||||
password=password,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
first_name=first_name or None,
|
||||
last_name=last_name or None,
|
||||
),
|
||||
safe=False,
|
||||
request=None,
|
||||
)
|
||||
else:
|
||||
await user_manager.update(
|
||||
UserUpdate(
|
||||
password=password,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
first_name=first_name or existing_user.first_name,
|
||||
last_name=last_name or existing_user.last_name,
|
||||
),
|
||||
existing_user,
|
||||
safe=False,
|
||||
request=None,
|
||||
)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
user = result.unique().scalar_one()
|
||||
user_id = str(user.id)
|
||||
|
||||
team_stmt = select(UserTeam).where(UserTeam.user_id == user_id).limit(1)
|
||||
team_result = await session.execute(team_stmt)
|
||||
user_team = team_result.scalar_one_or_none()
|
||||
|
||||
if user_team is None:
|
||||
personal_team = await create_personal_team(session, user)
|
||||
user_team = UserTeam(user_id=user_id, team_id=personal_team.id, role=TeamRole.OWNER.value)
|
||||
session.add(user_team)
|
||||
await session.commit()
|
||||
team_id = personal_team.id
|
||||
else:
|
||||
team_id = user_team.team_id
|
||||
|
||||
try:
|
||||
await initialize_team_local_provider(session, team_id, user_id)
|
||||
except Exception as exc:
|
||||
print(f"warning: failed to initialize local provider for {email}: {exc}")
|
||||
|
||||
print(
|
||||
f"{'created' if created else 'updated'} default TransformerLab user {email} "
|
||||
f"(team_id={team_id})"
|
||||
)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
transformerlab_dir = Path(args.transformerlab_dir).expanduser().resolve()
|
||||
|
||||
try:
|
||||
asyncio.run(
|
||||
ensure_user(
|
||||
transformerlab_dir=transformerlab_dir,
|
||||
email=args.email,
|
||||
password=args.password,
|
||||
first_name=args.first_name,
|
||||
last_name=args.last_name,
|
||||
)
|
||||
)
|
||||
except Exception as exc:
|
||||
print(f"error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user