Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use env id for gcp settings #957

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .axlearn/axlearn.default.config
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

# Project, env_id, zone, bucket, and network.
project = "my-gcp-project"
env_id = "us-central2-b"
zone = "us-central2-b"
network = "projects/my-gcp-project/global/networks/default"
subnetwork = "projects/my-gcp-project/regions/us-central2/subnetworks/default"
Expand Down
58 changes: 49 additions & 9 deletions axlearn/cloud/gcp/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,43 @@ def _gcp_settings_from_active_config(key: str) -> Optional[str]:
return project_configs.get(key, None)


def default_project():
def default_project() -> Optional[str]:
    """Looks up the `project` entry of the active `gcp_settings` config.

    The project, together with env_id, identifies a `gcp_settings` entry.

    Returns:
        Whatever `_gcp_settings_from_active_config` yields for the "project" key.
    """
    active_project = _gcp_settings_from_active_config("project")
    return active_project


def default_zone():
def default_zone() -> Optional[str]:
    """Looks up the `zone` entry of the active `gcp_settings` config.

    In addition to naming the GCP zone, this value used to be paired with the
    project to identify a `gcp_settings` entry. That role is being taken over by
    env_id; see `default_env_id`.

    Returns:
        Whatever `_gcp_settings_from_active_config` yields for the "zone" key.
    """
    active_zone = _gcp_settings_from_active_config("zone")
    return active_zone


def default_env_id() -> Optional[str]:
    """Looks up the env_id of the active `gcp_settings` config.

    The env_id, together with project, identifies a `gcp_settings` entry.
    For backwards compatibility, a missing (or otherwise falsy) env_id falls
    back to the zone of the active config.

    Returns:
        The active config's "env_id" when it is truthy; otherwise whatever the
        active config holds under "zone".
    """
    active_env_id = _gcp_settings_from_active_config("env_id")
    if active_env_id:
        return active_env_id
    # No usable env_id: use the zone instead.
    return _gcp_settings_from_active_config("zone")


def gcp_settings(
key: str,
*,
Expand Down Expand Up @@ -106,10 +135,16 @@ def gcp_settings(
zone = flag_values.get("zone", None)
if key == "zone" and zone:
return zone

# For backwards compatibility, env_id defaults to zone if not specified.
env_id = flag_values.get("env_id", zone)
if key == "env_id" and env_id:
return env_id

required = required and default is None
config_file, configs = config.load_configs(CONFIG_NAMESPACE, required=required)
if project and zone:
config_name = _project_config_key(project, zone)
if project and env_id:
config_name = _project_config_key(project, env_id)
else:
# Try to infer from active config.
config_name = configs.get("_active", None)
Expand All @@ -118,16 +153,21 @@ def gcp_settings(
if required and not project_configs:
# TODO(markblee): Link to docs once available.
logging.error(
"Unknown settings for project=%s and zone=%s; "
"Unknown settings for project=%s and env_id=%s; "
"You may want to configure this project first; Please refer to the docs for details.",
project,
zone,
env_id,
)
sys.exit(1)

# Only set the default value if the field is omitted. Explicitly falsey values should not be
# defaulted.
value = project_configs.get(key, default)

if key == "env_id" and value is None:
# Fall back to "zone" for backwards compatibility.
value = project_configs.get("zone")

if required and value is None:
logging.error("Could not find key %s in settings.", key)
logging.error(
Expand All @@ -140,9 +180,9 @@ def gcp_settings(
return value


def _project_config_key(project: str, zone: str) -> str:
"""Constructs a toml-friendly name uniquely identified by project, zone."""
return f"{project}:{zone}"
def _project_config_key(project: str, env_id: str) -> str:
    """Builds the toml-friendly config name for a (project, env_id) pair.

    Args:
        project: The project name.
        env_id: The env_id paired with the project.

    Returns:
        `project` and `env_id` joined by a single colon.
    """
    return "{}:{}".format(project, env_id)


def main(argv: Sequence[str], *, namespace: str = CONFIG_NAMESPACE, fv: flags.FlagValues = FLAGS):
Expand Down
138 changes: 133 additions & 5 deletions axlearn/cloud/gcp/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,91 @@ def test_gcp_settings(self):

flag_values = flags.FlagValues()
flags.DEFINE_string("project", None, "The project name.", flag_values=flag_values)
flags.DEFINE_string("env_id", None, "The env ID.", flag_values=flag_values)
flags.DEFINE_string("zone", None, "The zone name.", flag_values=flag_values)
flag_values.project = "test"
flag_values.zone = "test"
flag_values.project = "test-proj"
flag_values.zone = "test-zone"
flag_values.env_id = "test-env-id"

with self.assertRaisesRegex(RuntimeError, expected_regex="fv must be parsed"):
gcp_config.gcp_settings("bucket", required=False, fv=flag_values)

flag_values.mark_as_parsed()

self.assertEqual("test", gcp_config.gcp_settings("project", fv=flag_values))
self.assertEqual("test", gcp_config.gcp_settings("zone", fv=flag_values))
self.assertEqual("test-proj", gcp_config.gcp_settings("project", fv=flag_values))
self.assertEqual("test-env-id", gcp_config.gcp_settings("env_id", fv=flag_values))
self.assertEqual("test-zone", gcp_config.gcp_settings("zone", fv=flag_values))

# By default, should fail because no config file exists.
with self.assertRaises(SystemExit):
gcp_config.gcp_settings("bucket", fv=flag_values)

# Should not fail if not required.
self.assertIsNone(gcp_config.gcp_settings("bucket", required=False, fv=flag_values))

# Should not fail if a default exists.
self.assertEqual(
"default",
gcp_config.gcp_settings("bucket", required=True, default="default", fv=flag_values),
)

# Create a default config, which should get picked up.
default_config = create_default_config(temp_dir)

# Should fail because no config for --project and --env_id.
with self.assertRaises(SystemExit):
gcp_config.gcp_settings("bucket", fv=flag_values)

# Should not fail if not required.
self.assertIsNone(gcp_config.gcp_settings("bucket", required=False, fv=flag_values))

# Write some values to the config.
config.write_configs_with_header(
str(default_config),
{
gcp_config.CONFIG_NAMESPACE: {
"test-proj:test-env-id": {
"project": "test-proj",
"env_id": "test-env-id",
"zone": "test-zone",
"bucket": "test-bucket",
}
}
},
)

# Should fail because key cannot be found.
with self.assertRaises(SystemExit):
gcp_config.gcp_settings("unknown_key", fv=flag_values)

# Should not fail if not required.
self.assertIsNone(gcp_config.gcp_settings("unknown_key", fv=flag_values, required=False))

# Should succeed.
self.assertEqual("test-bucket", gcp_config.gcp_settings("bucket", fv=flag_values))

def test_gcp_settings_when_env_id_not_set(self):
"""Test the backwards compatibility for env_id.
If env_id is not set, fall back to use zone as part of the gcp_settings key."""

temp_dir = os.path.realpath(self._temp_root.name)
_setup_fake_repo(temp_dir)

flag_values = flags.FlagValues()
flags.DEFINE_string("project", None, "The project name.", flag_values=flag_values)
flags.DEFINE_string("zone", None, "The zone name.", flag_values=flag_values)
flag_values.project = "test-proj"
flag_values.zone = "test-zone"

with self.assertRaisesRegex(RuntimeError, expected_regex="fv must be parsed"):
gcp_config.gcp_settings("bucket", required=False, fv=flag_values)

flag_values.mark_as_parsed()

self.assertEqual("test-proj", gcp_config.gcp_settings("project", fv=flag_values))
self.assertEqual("test-zone", gcp_config.gcp_settings("zone", fv=flag_values))
# When env_id is not set, it falls back to zone
self.assertEqual("test-zone", gcp_config.gcp_settings("env_id", fv=flag_values))

# By default, should fail because no config file exists.
with self.assertRaises(SystemExit):
Expand Down Expand Up @@ -61,7 +135,11 @@ def test_gcp_settings(self):
str(default_config),
{
gcp_config.CONFIG_NAMESPACE: {
"test:test": {"project": "test", "zone": "test", "bucket": "test-bucket"}
"test-proj:test-zone": {
"project": "test-proj",
"zone": "test-zone",
"bucket": "test-bucket",
}
}
},
)
Expand All @@ -80,6 +158,51 @@ def test_gcp_settings_with_active_config(self):
temp_dir = os.path.realpath(self._temp_root.name)
_setup_fake_repo(temp_dir)

flag_values = flags.FlagValues()
# Flags do not set project/env_id/zone.
flags.DEFINE_string("project", None, "The project name.", flag_values=flag_values)
flags.DEFINE_string("env_id", None, "The env ID.", flag_values=flag_values)
flags.DEFINE_string("zone", None, "The zone name.", flag_values=flag_values)
flag_values.mark_as_parsed()

self.assertIsNone(gcp_config.default_project())
self.assertIsNone(gcp_config.default_zone())
self.assertIsNone(gcp_config.default_env_id())

# Create a default config, which should get picked up.
default_config = create_default_config(temp_dir)
# Write some values to the config.
config.write_configs_with_header(
str(default_config),
{
gcp_config.CONFIG_NAMESPACE: {
"_active": "test-proj:test-env-id",
"test-proj:test-env-id": {
"project": "test-proj",
"env_id": "test-env-id",
"zone": "test-zone",
"bucket": "test-bucket",
},
}
},
)
self.assertEqual("test-proj", gcp_config.default_project())
self.assertEqual("test-env-id", gcp_config.default_env_id())
self.assertEqual("test-zone", gcp_config.default_zone())

# We follow the default config.
self.assertEqual("test-proj", gcp_config.gcp_settings("project", fv=flag_values))
self.assertEqual("test-env-id", gcp_config.gcp_settings("env_id", fv=flag_values))
self.assertEqual("test-zone", gcp_config.gcp_settings("zone", fv=flag_values))
self.assertEqual("test-bucket", gcp_config.gcp_settings("bucket", fv=flag_values))

def test_gcp_settings_with_active_config_when_env_id_not_set(self):
"""Test the backwards compatibility for env_id.
If env_id is not set, fall back to use zone as part of the gcp_settings key."""

temp_dir = os.path.realpath(self._temp_root.name)
_setup_fake_repo(temp_dir)

flag_values = flags.FlagValues()
# Flags do not set project/zone.
flags.DEFINE_string("project", None, "The project name.", flag_values=flag_values)
Expand All @@ -88,6 +211,7 @@ def test_gcp_settings_with_active_config(self):

self.assertIsNone(gcp_config.default_project())
self.assertIsNone(gcp_config.default_zone())
self.assertIsNone(gcp_config.default_env_id())

# Create a default config, which should get picked up.
default_config = create_default_config(temp_dir)
Expand All @@ -107,8 +231,12 @@ def test_gcp_settings_with_active_config(self):
)
self.assertEqual("test-proj", gcp_config.default_project())
self.assertEqual("test-zone", gcp_config.default_zone())
# When env_id is not set, it falls back to zone
self.assertEqual("test-zone", gcp_config.default_env_id())

# We follow the default config.
self.assertEqual("test-proj", gcp_config.gcp_settings("project", fv=flag_values))
self.assertEqual("test-zone", gcp_config.gcp_settings("zone", fv=flag_values))
# When env_id is not set, it falls back to zone
self.assertEqual("test-zone", gcp_config.gcp_settings("env_id", fv=flag_values))
self.assertEqual("test-bucket", gcp_config.gcp_settings("bucket", fv=flag_values))
10 changes: 9 additions & 1 deletion axlearn/cloud/gcp/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from axlearn.cloud.common.bundler import BaseDockerBundler
from axlearn.cloud.common.job import Job
from axlearn.cloud.common.utils import parse_kv_flags, subprocess_run
from axlearn.cloud.gcp.config import default_project, default_zone, gcp_settings
from axlearn.cloud.gcp.config import default_env_id, default_project, default_zone, gcp_settings
from axlearn.cloud.gcp.node_pool import PRE_PROVISIONER_LABEL
from axlearn.cloud.gcp.scopes import DEFAULT_TPU_SCOPES
from axlearn.cloud.gcp.system_characteristics import (
Expand Down Expand Up @@ -75,6 +75,8 @@ class Config(Job.Config):
project: Required[str] = REQUIRED
# GCP zone.
zone: Required[str] = REQUIRED
# GCP env_id.
env_id: Optional[str] = None
# If not none, the current job will be executed as the service account.
service_account: Optional[str] = None

Expand All @@ -84,6 +86,12 @@ def define_flags(cls, fv: flags.FlagValues):
common_kwargs = dict(flag_values=fv, allow_override=True)
flags.DEFINE_string("project", default_project(), "The GCP project name.", **common_kwargs)
flags.DEFINE_string("zone", default_zone(), "The GCP zone name.", **common_kwargs)
flags.DEFINE_string(
"env_id",
default_env_id(),
"The env_id, used along with project to identify `gcp_settings`.",
**common_kwargs,
)
flags.DEFINE_string(
"service_account",
None,
Expand Down
16 changes: 9 additions & 7 deletions axlearn/cloud/gcp/jobs/bastion_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# Notes:
# - Only docker bundler_type is supported.
# - We assume the image is tagged with the same name as the bastion.
# - Unless configured in the settings, the default bastion name is <zone>-shared-bastion.
# - Unless configured in the settings, the default bastion name is <env_id>-shared-bastion.
#
axlearn gcp bastion create --name=shared-bastion

Expand Down Expand Up @@ -134,7 +134,7 @@
from axlearn.cloud.common.scheduler import JobScheduler
from axlearn.cloud.common.uploader import Uploader, with_interval
from axlearn.cloud.common.utils import configure_logging, parse_action
from axlearn.cloud.gcp.config import default_project, default_zone, gcp_settings
from axlearn.cloud.gcp.config import default_env_id, default_project, default_zone, gcp_settings
from axlearn.cloud.gcp.event_queue import event_queue_from_config
from axlearn.cloud.gcp.job import CPUJob, docker_command
from axlearn.cloud.gcp.tpu_cleaner import TPUCleaner
Expand All @@ -153,6 +153,7 @@ def _private_flags(flag_values: flags.FlagValues = FLAGS):
bastion_job_flags(flag_values=flag_values)
flag_values.set_default("project", default_project())
flag_values.set_default("zone", default_zone())
flag_values.set_default("env_id", default_env_id())

flags.DEFINE_string(
"vm_type", "n2-highmem-128", "Machine spec to boot for VM.", flag_values=flag_values
Expand Down Expand Up @@ -208,15 +209,16 @@ def _validate_name(name: str):
def shared_bastion_name(
fv: Optional[flags.FlagValues], gcp_api: Optional[str] = None
) -> Optional[str]:
# The zone-namespacing is necessary because of quirks with compute API. Specifically, even if
# The env_id-namespacing is necessary because of quirks with compute API. Specifically, even if
# creating VMs within a specific zone, names are global. On the other hand, the list API only
# returns VMs within a zone, so there's no easy way to check if a shared bastion already exists
# in another zone.
zone = gcp_settings("zone", fv=fv)
# If env_id is not set, fall back to "zone" for backwards compatibility.
env_id = gcp_settings("env_id", fv=fv, required=False) or gcp_settings("zone", fv=fv)
if gcp_api is not None and gcp_api.lower() == GCPAPI.GKE.lower():
default = f"{zone}-gke-bastion"
default = f"{env_id}-gke-bastion"
else:
default = f"{zone}-shared-bastion"
default = f"{env_id}-shared-bastion"
bastion_name = gcp_settings( # pytype: disable=bad-return-type
"bastion_name",
default=default,
Expand Down Expand Up @@ -320,7 +322,7 @@ def _execute(self):
# flagfile, and reading that.
run_cmd = docker_command(
f"python3 -m axlearn.cloud.gcp.jobs.bastion_vm --name={cfg.name} "
f"--project={cfg.project} --zone={cfg.zone} start 2>&1 | {output_cmd}",
f"--project={cfg.project} --env_id={cfg.env_id} start 2>&1 | {output_cmd}",
image=image,
volumes={"/var/tmp": "/var/tmp"},
detached_session=cfg.name,
Expand Down
Loading