Skip to content
39 changes: 36 additions & 3 deletions src/zenml/cli/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,24 +75,32 @@ def list_projects(ctx: click.Context, /, **kwargs: Any) -> None:
required=False,
help="The display name of the project.",
)
@click.option(
"--set-default",
"set_default",
is_flag=True,
help="Set this project as the default project.",
)
@click.argument("project_name", type=str, required=True)
def register_project(
project_name: str,
set_project: bool = False,
display_name: Optional[str] = None,
set_default: bool = False,
) -> None:
"""Register a new project.

Args:
project_name: The name of the project to register.
set_project: Whether to set the project as active.
display_name: The display name of the project.
set_default: Whether to set the project as the default project.
"""
check_zenml_pro_project_availability()
client = Client()
with console.status("Creating project...\n"):
try:
client.create_project(
project = client.create_project(
project_name,
description="",
display_name=display_name,
Expand All @@ -105,26 +113,51 @@ def register_project(
client.set_active_project(project_name)
cli_utils.declare(f"The active project has been set to {project_name}")

if set_default:
client.update_user(
name_id_or_prefix=client.active_user.id,
updated_default_project_id=project.id,
)
cli_utils.declare(
f"The default project has been set to {project.name}"
)


@project.command("set")
@click.argument("project_name_or_id", type=str, required=True)
def set_project(project_name_or_id: str) -> None:
@click.option(
"--default",
"default",
is_flag=True,
help="Set this project as the default project.",
)
def set_project(project_name_or_id: str, default: bool = False) -> None:
"""Set the active project.

Args:
project_name_or_id: The name or ID of the project to set as active.
default: Whether to set the project as the default project.
"""
check_zenml_pro_project_availability()
client = Client()
with console.status("Setting project...\n"):
try:
client.set_active_project(project_name_or_id)
project = client.set_active_project(project_name_or_id)
cli_utils.declare(
f"The active project has been set to {project_name_or_id}"
)
except Exception as e:
cli_utils.error(str(e))

if default:
client.update_user(
name_id_or_prefix=client.active_user.id,
updated_default_project_id=project.id,
)
cli_utils.declare(
f"The default project has been set to {project.name}"
)


@project.command("describe")
@click.argument("project_name_or_id", type=str, required=False)
Expand Down
5 changes: 5 additions & 0 deletions src/zenml/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,7 @@ def update_user(
old_password: Optional[str] = None,
updated_is_admin: Optional[bool] = None,
updated_metadata: Optional[Dict[str, Any]] = None,
updated_default_project_id: Optional[UUID] = None,
active: Optional[bool] = None,
) -> UserResponse:
"""Update a user.
Expand All @@ -891,6 +892,7 @@ def update_user(
update.
updated_is_admin: Whether the user should be an admin.
updated_metadata: The new metadata for the user.
updated_default_project_id: The new default project ID for the user.
active: Use to activate or deactivate the user.

Returns:
Expand Down Expand Up @@ -928,6 +930,9 @@ def update_user(
if updated_metadata is not None:
user_update.user_metadata = updated_metadata

if updated_default_project_id is not None:
user_update.default_project_id = updated_default_project_id

return self.zen_store.update_user(
user_id=user.id, user_update=user_update
)
Expand Down
5 changes: 4 additions & 1 deletion src/zenml/config/global_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ class GlobalConfiguration(BaseModel, metaclass=GlobalConfigMetaClass):
version: Optional[str] = None
store: Optional[SerializeAsAny[StoreConfiguration]] = None
active_stack_id: Optional[uuid.UUID] = None
active_project_id: Optional[uuid.UUID] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see this being used anywhere. Why do you need it ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is used below as an argument to validate_active_config. If we were to keep using the name, we would keep a project with the same name when connecting to a different server, and therefore ignore the default project configured for the user.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you then be okay with replacing active_project_name with active_project_id ? this is in fact how it's done on the client/repo (i.e. zenml init) configuration side of things.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, I was wondering whether we should keep it in there for backwards compatibility somehow, but I'm happy to remove it entirely!

active_project_name: Optional[str] = None

_zen_store: Optional["BaseZenStore"] = None
Expand Down Expand Up @@ -393,14 +394,16 @@ def _sanitize_config(self) -> None:
if ENV_ZENML_SERVER in os.environ:
return
active_project, active_stack = self.zen_store.validate_active_config(
self.active_project_name,
self.active_project_id or self.active_project_name,
self.active_stack_id,
config_name="global",
)
if active_project:
self.active_project_id = active_project.id
self.active_project_name = active_project.name
self._active_project = active_project
else:
self.active_project_id = None
self.active_project_name = None
self._active_project = None

Expand Down
17 changes: 17 additions & 0 deletions src/zenml/models/v2/core/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@ class UserUpdate(UserBase, BaseUpdate):
"accounts. Required when updating the password.",
max_length=STR_FIELD_MAX_LENGTH,
)
default_project_id: Optional[UUID] = Field(
default=None,
title="The default project ID for the user.",
)

@model_validator(mode="after")
def user_email_updates(self) -> "UserUpdate":
Expand Down Expand Up @@ -279,6 +283,10 @@ class UserResponseBody(BaseDatedResponseBody):
is_admin: bool = Field(
title="Whether the account is an administrator.",
)
default_project_id: Optional[UUID] = Field(
default=None,
title="The default project ID for the user.",
)


class UserResponseMetadata(BaseResponseMetadata):
Expand Down Expand Up @@ -422,6 +430,15 @@ def user_metadata(self) -> Dict[str, Any]:
"""
return self.get_metadata().user_metadata

@property
def default_project_id(self) -> Optional[UUID]:
"""The `default_project_id` property.

Returns:
the value of the property.
"""
return self.get_body().default_project_id

# Helper methods
@classmethod
def _get_crypt_context(cls) -> "CryptContext":
Expand Down
Loading