From 44f94cc28e379847e157ee650aa2d42d90047584 Mon Sep 17 00:00:00 2001 From: Ethan Li Date: Fri, 24 Jan 2025 17:12:44 -0800 Subject: [PATCH 1/3] Use env_id to replace zone as gcp_settings key to support multiple env under the same zone --- .axlearn/axlearn.default.config | 1 + axlearn/cloud/gcp/config.py | 30 ++++-- axlearn/cloud/gcp/config_test.py | 138 ++++++++++++++++++++++++++- axlearn/cloud/gcp/job.py | 10 +- axlearn/cloud/gcp/jobs/bastion_vm.py | 15 +-- axlearn/cloud/gcp/jobs/launch.py | 10 +- axlearn/cloud/gcp/test_utils.py | 2 +- axlearn/cloud/gcp/utils.py | 3 + 8 files changed, 185 insertions(+), 24 deletions(-) diff --git a/.axlearn/axlearn.default.config b/.axlearn/axlearn.default.config index 4aee480b4..ee2287037 100644 --- a/.axlearn/axlearn.default.config +++ b/.axlearn/axlearn.default.config @@ -4,6 +4,7 @@ # Project, zone, bucket, and network. project = "my-gcp-project" +env_id = "us-central2-b" zone = "us-central2-b" network = "projects/my-gcp-project/global/networks/default" subnetwork = "projects/my-gcp-project/regions/us-central2/subnetworks/default" diff --git a/axlearn/cloud/gcp/config.py b/axlearn/cloud/gcp/config.py index b8b9c3e8d..ca71b1fcb 100644 --- a/axlearn/cloud/gcp/config.py +++ b/axlearn/cloud/gcp/config.py @@ -73,6 +73,11 @@ def default_zone(): return _gcp_settings_from_active_config("zone") +def default_env_id(): + # When env_id is not set, fall back to zone for backwards compatibility. + return _gcp_settings_from_active_config("env_id") or _gcp_settings_from_active_config("zone") + + def gcp_settings( key: str, *, @@ -106,10 +111,16 @@ def gcp_settings( zone = flag_values.get("zone", None) if key == "zone" and zone: return zone + + # For backwards compatibility, env_id defaults to zone if not specified. 
+ env_id = flag_values.get("env_id", zone) + if key == "env_id" and env_id: + return env_id + required = required and default is None config_file, configs = config.load_configs(CONFIG_NAMESPACE, required=required) - if project and zone: - config_name = _project_config_key(project, zone) + if project and env_id: + config_name = _project_config_key(project, env_id) else: # Try to infer from active config. config_name = configs.get("_active", None) @@ -118,16 +129,21 @@ def gcp_settings( if required and not project_configs: # TODO(markblee): Link to docs once available. logging.error( - "Unknown settings for project=%s and zone=%s; " + "Unknown settings for project=%s and env_id=%s; " "You may want to configure this project first; Please refer to the docs for details.", project, - zone, + env_id, ) sys.exit(1) # Only set the default value if the field is omitted. Explicitly falsey values should not be # defaulted. value = project_configs.get(key, default) + + if key == "env_id" and value is None: + # Fall back to "zone" for backwards compatibility. 
+ value = project_configs.get("zone") + if required and value is None: logging.error("Could not find key %s in settings.", key) logging.error( @@ -140,9 +156,9 @@ def gcp_settings( return value -def _project_config_key(project: str, zone: str) -> str: - """Constructs a toml-friendly name uniquely identified by project, zone.""" - return f"{project}:{zone}" +def _project_config_key(project: str, env_id: str) -> str: + """Constructs a toml-friendly name uniquely identified by project, env_id.""" + return f"{project}:{env_id}" def main(argv: Sequence[str], *, namespace: str = CONFIG_NAMESPACE, fv: flags.FlagValues = FLAGS): diff --git a/axlearn/cloud/gcp/config_test.py b/axlearn/cloud/gcp/config_test.py index d83e5e3b3..0614a3109 100644 --- a/axlearn/cloud/gcp/config_test.py +++ b/axlearn/cloud/gcp/config_test.py @@ -21,17 +21,91 @@ def test_gcp_settings(self): flag_values = flags.FlagValues() flags.DEFINE_string("project", None, "The project name.", flag_values=flag_values) + flags.DEFINE_string("env_id", None, "The env ID.", flag_values=flag_values) flags.DEFINE_string("zone", None, "The zone name.", flag_values=flag_values) - flag_values.project = "test" - flag_values.zone = "test" + flag_values.project = "test-proj" + flag_values.zone = "test-zone" + flag_values.env_id = "test-env-id" with self.assertRaisesRegex(RuntimeError, expected_regex="fv must be parsed"): gcp_config.gcp_settings("bucket", required=False, fv=flag_values) flag_values.mark_as_parsed() - self.assertEqual("test", gcp_config.gcp_settings("project", fv=flag_values)) - self.assertEqual("test", gcp_config.gcp_settings("zone", fv=flag_values)) + self.assertEqual("test-proj", gcp_config.gcp_settings("project", fv=flag_values)) + self.assertEqual("test-env-id", gcp_config.gcp_settings("env_id", fv=flag_values)) + self.assertEqual("test-zone", gcp_config.gcp_settings("zone", fv=flag_values)) + + # By default, should fail because no config file exists. 
+ with self.assertRaises(SystemExit): + gcp_config.gcp_settings("bucket", fv=flag_values) + + # Should not fail if not required. + self.assertIsNone(gcp_config.gcp_settings("bucket", required=False, fv=flag_values)) + + # Should not fail if a default exists. + self.assertEqual( + "default", + gcp_config.gcp_settings("bucket", required=True, default="default", fv=flag_values), + ) + + # Create a default config, which should get picked up. + default_config = create_default_config(temp_dir) + + # Should fail because no config for --project and --env_id. + with self.assertRaises(SystemExit): + gcp_config.gcp_settings("bucket", fv=flag_values) + + # Should not fail if not required. + self.assertIsNone(gcp_config.gcp_settings("bucket", required=False, fv=flag_values)) + + # Write some values to the config. + config.write_configs_with_header( + str(default_config), + { + gcp_config.CONFIG_NAMESPACE: { + "test-proj:test-env-id": { + "project": "test-proj", + "env_id": "test-env-id", + "zone": "test-zone", + "bucket": "test-bucket", + } + } + }, + ) + + # Should fail because key cannot be found. + with self.assertRaises(SystemExit): + gcp_config.gcp_settings("unknown_key", fv=flag_values) + + # Should not fail if not required. + self.assertIsNone(gcp_config.gcp_settings("unknown_key", fv=flag_values, required=False)) + + # Should succeed. + self.assertEqual("test-bucket", gcp_config.gcp_settings("bucket", fv=flag_values)) + + def test_gcp_settings_when_env_id_not_set(self): + """Test the backwards compatibility for env_id. 
+ If env_id is not set, fall back to use zone as part of the gcp_settings key.""" + + temp_dir = os.path.realpath(self._temp_root.name) + _setup_fake_repo(temp_dir) + + flag_values = flags.FlagValues() + flags.DEFINE_string("project", None, "The project name.", flag_values=flag_values) + flags.DEFINE_string("zone", None, "The zone name.", flag_values=flag_values) + flag_values.project = "test-proj" + flag_values.zone = "test-zone" + + with self.assertRaisesRegex(RuntimeError, expected_regex="fv must be parsed"): + gcp_config.gcp_settings("bucket", required=False, fv=flag_values) + + flag_values.mark_as_parsed() + + self.assertEqual("test-proj", gcp_config.gcp_settings("project", fv=flag_values)) + self.assertEqual("test-zone", gcp_config.gcp_settings("zone", fv=flag_values)) + # When env_id is not set, it falls back to zone + self.assertEqual("test-zone", gcp_config.gcp_settings("env_id", fv=flag_values)) # By default, should fail because no config file exists. with self.assertRaises(SystemExit): @@ -61,7 +135,11 @@ def test_gcp_settings(self): str(default_config), { gcp_config.CONFIG_NAMESPACE: { - "test:test": {"project": "test", "zone": "test", "bucket": "test-bucket"} + "test-proj:test-zone": { + "project": "test-proj", + "zone": "test-zone", + "bucket": "test-bucket", + } } }, ) @@ -80,6 +158,51 @@ def test_gcp_settings_with_active_config(self): temp_dir = os.path.realpath(self._temp_root.name) _setup_fake_repo(temp_dir) + flag_values = flags.FlagValues() + # Flags do not set project/zone. + flags.DEFINE_string("project", None, "The project name.", flag_values=flag_values) + flags.DEFINE_string("env_id", None, "The env ID.", flag_values=flag_values) + flags.DEFINE_string("zone", None, "The zone name.", flag_values=flag_values) + flag_values.mark_as_parsed() + + self.assertIsNone(gcp_config.default_project()) + self.assertIsNone(gcp_config.default_zone()) + self.assertIsNone(gcp_config.default_env_id()) + + # Create a default config, which should get picked up. 
+ default_config = create_default_config(temp_dir) + # Write some values to the config. + config.write_configs_with_header( + str(default_config), + { + gcp_config.CONFIG_NAMESPACE: { + "_active": "test-proj:test-env-id", + "test-proj:test-env-id": { + "project": "test-proj", + "env_id": "test-env-id", + "zone": "test-zone", + "bucket": "test-bucket", + }, + } + }, + ) + self.assertEqual("test-proj", gcp_config.default_project()) + self.assertEqual("test-env-id", gcp_config.default_env_id()) + self.assertEqual("test-zone", gcp_config.default_zone()) + + # We follow the default config. + self.assertEqual("test-proj", gcp_config.gcp_settings("project", fv=flag_values)) + self.assertEqual("test-env-id", gcp_config.gcp_settings("env_id", fv=flag_values)) + self.assertEqual("test-zone", gcp_config.gcp_settings("zone", fv=flag_values)) + self.assertEqual("test-bucket", gcp_config.gcp_settings("bucket", fv=flag_values)) + + def test_gcp_settings_with_active_config_when_env_id_not_set(self): + """Test the backwards compatibility for env_id. + If env_id is not set, fall back to use zone as part of the gcp_settings key.""" + + temp_dir = os.path.realpath(self._temp_root.name) + _setup_fake_repo(temp_dir) + flag_values = flags.FlagValues() # Flags do not set project/zone. flags.DEFINE_string("project", None, "The project name.", flag_values=flag_values) @@ -88,6 +211,7 @@ def test_gcp_settings_with_active_config(self): self.assertIsNone(gcp_config.default_project()) self.assertIsNone(gcp_config.default_zone()) + self.assertIsNone(gcp_config.default_env_id()) # Create a default config, which should get picked up. 
default_config = create_default_config(temp_dir) @@ -107,8 +231,12 @@ def test_gcp_settings_with_active_config(self): ) self.assertEqual("test-proj", gcp_config.default_project()) self.assertEqual("test-zone", gcp_config.default_zone()) + # When env_id is not set, it falls back to zone + self.assertEqual("test-zone", gcp_config.default_env_id()) # We follow the default config. self.assertEqual("test-proj", gcp_config.gcp_settings("project", fv=flag_values)) self.assertEqual("test-zone", gcp_config.gcp_settings("zone", fv=flag_values)) + # When env_id is not set, it falls back to zone + self.assertEqual("test-zone", gcp_config.gcp_settings("env_id", fv=flag_values)) self.assertEqual("test-bucket", gcp_config.gcp_settings("bucket", fv=flag_values)) diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py index b1ad5e357..85e4bda21 100644 --- a/axlearn/cloud/gcp/job.py +++ b/axlearn/cloud/gcp/job.py @@ -30,7 +30,7 @@ from axlearn.cloud.common.bundler import BaseDockerBundler from axlearn.cloud.common.job import Job from axlearn.cloud.common.utils import parse_kv_flags, subprocess_run -from axlearn.cloud.gcp.config import default_project, default_zone, gcp_settings +from axlearn.cloud.gcp.config import default_env_id, default_project, default_zone, gcp_settings from axlearn.cloud.gcp.node_pool import PRE_PROVISIONER_LABEL from axlearn.cloud.gcp.scopes import DEFAULT_TPU_SCOPES from axlearn.cloud.gcp.system_characteristics import ( @@ -75,6 +75,8 @@ class Config(Job.Config): project: Required[str] = REQUIRED # GCP zone. zone: Required[str] = REQUIRED + # GCP env_id. + env_id: Optional[str] = None # If not none, the current job will be executed as the service account. 
service_account: Optional[str] = None @@ -84,6 +86,12 @@ def define_flags(cls, fv: flags.FlagValues): common_kwargs = dict(flag_values=fv, allow_override=True) flags.DEFINE_string("project", default_project(), "The GCP project name.", **common_kwargs) flags.DEFINE_string("zone", default_zone(), "The GCP zone name.", **common_kwargs) + flags.DEFINE_string( + "env_id", + default_env_id(), + "The env_id, used along with project to identify gcp settings", + **common_kwargs, + ) flags.DEFINE_string( "service_account", None, diff --git a/axlearn/cloud/gcp/jobs/bastion_vm.py b/axlearn/cloud/gcp/jobs/bastion_vm.py index a917bcd06..d3f8db798 100644 --- a/axlearn/cloud/gcp/jobs/bastion_vm.py +++ b/axlearn/cloud/gcp/jobs/bastion_vm.py @@ -21,7 +21,7 @@ # Notes: # - Only docker bundler_type is supported. # - We assume the image is tagged with the same name as the bastion. - # - Unless configured in the settings, the default bastion name is <zone>-shared-bastion. + # - Unless configured in the settings, the default bastion name is <env_id>-shared-bastion. 
# axlearn gcp bastion create --name=shared-bastion @@ -134,7 +134,7 @@ from axlearn.cloud.common.scheduler import JobScheduler from axlearn.cloud.common.uploader import Uploader, with_interval from axlearn.cloud.common.utils import configure_logging, parse_action -from axlearn.cloud.gcp.config import default_project, default_zone, gcp_settings +from axlearn.cloud.gcp.config import default_env_id, default_project, default_zone, gcp_settings from axlearn.cloud.gcp.event_queue import event_queue_from_config from axlearn.cloud.gcp.job import CPUJob, docker_command from axlearn.cloud.gcp.tpu_cleaner import TPUCleaner @@ -153,6 +153,7 @@ def _private_flags(flag_values: flags.FlagValues = FLAGS): bastion_job_flags(flag_values=flag_values) flag_values.set_default("project", default_project()) flag_values.set_default("zone", default_zone()) + flag_values.set_default("env_id", default_env_id()) flags.DEFINE_string( "vm_type", "n2-highmem-128", "Machine spec to boot for VM.", flag_values=flag_values @@ -208,15 +209,15 @@ def _validate_name(name: str): def shared_bastion_name( fv: Optional[flags.FlagValues], gcp_api: Optional[str] = None ) -> Optional[str]: - # The zone-namespacing is necessary because of quirks with compute API. Specifically, even if + # The env_id-namespacing is necessary because of quirks with compute API. Specifically, even if # creating VMs within a specific zone, names are global. On the other hand, the list API only # returns VMs within a zone, so there's no easy way to check if a shared bastion already exists # in another zone. 
- zone = gcp_settings("zone", fv=fv) + env_id = gcp_settings("env_id", fv=fv) if gcp_api is not None and gcp_api.lower() == GCPAPI.GKE.lower(): - default = f"{zone}-gke-bastion" + default = f"{env_id}-gke-bastion" else: - default = f"{zone}-shared-bastion" + default = f"{env_id}-shared-bastion" bastion_name = gcp_settings( # pytype: disable=bad-return-type "bastion_name", default=default, @@ -320,7 +321,7 @@ def _execute(self): # flagfile, and reading that. run_cmd = docker_command( f"python3 -m axlearn.cloud.gcp.jobs.bastion_vm --name={cfg.name} " - f"--project={cfg.project} --zone={cfg.zone} start 2>&1 | {output_cmd}", + f"--project={cfg.project} --env_id={cfg.env_id} start 2>&1 | {output_cmd}", image=image, volumes={"/var/tmp": "/var/tmp"}, detached_session=cfg.name, diff --git a/axlearn/cloud/gcp/jobs/launch.py b/axlearn/cloud/gcp/jobs/launch.py index 48d7cb3ce..d0bb7e656 100644 --- a/axlearn/cloud/gcp/jobs/launch.py +++ b/axlearn/cloud/gcp/jobs/launch.py @@ -181,6 +181,8 @@ class BaseBastionManagedJob(Job): class Config(Job.Config): """Configures BaseBastionManagedJob.""" + # Used along with project to identify gcp settings + env_id: Optional[str] = None # Where to run the remote job. zone: Required[str] = REQUIRED # Instance type to launch. 
@@ -363,12 +365,14 @@ def _execute(self) -> JobSpec: f"\nStop/cancel the job with:\n" f"{infer_cli_name()} gcp launch stop " f"--name={cfg.name} --bastion={cfg.bastion_name} --instance_type={cfg.instance_type} " - f"--zone={cfg.zone} --gcp_api={gcp_api}\n" + f"--env_id={cfg.env_id} --gcp_api={gcp_api}\n" "\nCheck job history with:\n" - f"{infer_cli_name()} gcp bastion history --name={cfg.bastion_name} --zone={cfg.zone} " + f"{infer_cli_name()} gcp bastion history " + f"--name={cfg.bastion_name} --env_id={cfg.env_id} " f"--job_name={cfg.name}" "\nCheck project history with:\n" - f"{infer_cli_name()} gcp bastion history --name={cfg.bastion_name} --zone={cfg.zone} " + f"{infer_cli_name()} gcp bastion history " + f"--name={cfg.bastion_name} --env_id={cfg.env_id} " f"{cfg.project_id or ''}" ) return jobspec diff --git a/axlearn/cloud/gcp/test_utils.py b/axlearn/cloud/gcp/test_utils.py index e4dfc2f72..c2400b664 100644 --- a/axlearn/cloud/gcp/test_utils.py +++ b/axlearn/cloud/gcp/test_utils.py @@ -28,7 +28,7 @@ def gcp_settings( return value def gcp_settings_from_active_config(project_or_zone: str): - return settings[project_or_zone] + return settings.get(project_or_zone, None) if isinstance(module_name, str): module_name = [module_name] diff --git a/axlearn/cloud/gcp/utils.py b/axlearn/cloud/gcp/utils.py index d49290811..7606ec833 100644 --- a/axlearn/cloud/gcp/utils.py +++ b/axlearn/cloud/gcp/utils.py @@ -26,6 +26,9 @@ def common_flags(**kwargs): """Defines common GCP flags. 
Keyword args will be forwarded to flag definitions.""" flags.DEFINE_string("project", None, "The GCP project name.", **kwargs) flags.DEFINE_string("zone", None, "The GCP zone name.", **kwargs) + flags.DEFINE_string( + "env_id", None, "The env_id, used along with project to identify gcp settings", **kwargs + ) def get_credentials( From 6f721f8f0b88a23998589b46f018d0154ff68924 Mon Sep 17 00:00:00 2001 From: Ethan Li Date: Sun, 26 Jan 2025 16:04:12 -0800 Subject: [PATCH 2/3] fall back to zone --- axlearn/cloud/gcp/jobs/bastion_vm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/axlearn/cloud/gcp/jobs/bastion_vm.py b/axlearn/cloud/gcp/jobs/bastion_vm.py index d3f8db798..b9eae7cf4 100644 --- a/axlearn/cloud/gcp/jobs/bastion_vm.py +++ b/axlearn/cloud/gcp/jobs/bastion_vm.py @@ -213,7 +213,8 @@ def shared_bastion_name( # creating VMs within a specific zone, names are global. On the other hand, the list API only # returns VMs within a zone, so there's no easy way to check if a shared bastion already exists # in another zone. - env_id = gcp_settings("env_id", fv=fv) + # If env_id is not set, fall back to "zone" for backwards compatibility. 
+ env_id = gcp_settings("env_id", fv=fv, required=False) or gcp_settings("zone", fv=fv) if gcp_api is not None and gcp_api.lower() == GCPAPI.GKE.lower(): default = f"{env_id}-gke-bastion" else: From cb6f143f31bd1e658e925431fe0ec565afd1fb00 Mon Sep 17 00:00:00 2001 From: Ethan Li Date: Tue, 28 Jan 2025 17:36:49 -0800 Subject: [PATCH 3/3] address comments --- axlearn/cloud/gcp/config.py | 32 ++++++++++++++++++++++++++++---- axlearn/cloud/gcp/job.py | 2 +- axlearn/cloud/gcp/jobs/launch.py | 2 +- axlearn/cloud/gcp/utils.py | 2 +- 4 files changed, 31 insertions(+), 7 deletions(-) diff --git a/axlearn/cloud/gcp/config.py b/axlearn/cloud/gcp/config.py index ca71b1fcb..1578c4e28 100644 --- a/axlearn/cloud/gcp/config.py +++ b/axlearn/cloud/gcp/config.py @@ -65,16 +65,40 @@ def _gcp_settings_from_active_config(key: str) -> Optional[str]: return project_configs.get(key, None) -def default_project(): +def default_project() -> Optional[str]: + """Default project from active `gcp_settings`. + + Project is used along with env_id to identify `gcp_settings`. + + Returns: the project in active `gcp_settings` config. + """ + return _gcp_settings_from_active_config("project") -def default_zone(): +def default_zone() -> Optional[str]: + """Default zone from active `gcp_settings`. + + Besides specifying the GCP zone, this value was also used + along with project to identify `gcp_settings`. It is being replaced by + env_id. See `default_env_id`. + + Returns: the zone in active `gcp_settings` config. + """ + return _gcp_settings_from_active_config("zone") -def default_env_id(): - # When env_id is not set, fall back to zone for backwards compatibility. +def default_env_id() -> Optional[str]: + """Default env_id value from active `gcp_settings`. + + Env_id is used along with project to identify `gcp_settings`. + + When env_id is None, fall back to zone for backwards compatibility. + + Returns: the env_id in active `gcp_settings` config; if it doesn't exist, returns the zone. 
+ """ + return _gcp_settings_from_active_config("env_id") or _gcp_settings_from_active_config("zone") diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py index 85e4bda21..3aec404dd 100644 --- a/axlearn/cloud/gcp/job.py +++ b/axlearn/cloud/gcp/job.py @@ -89,7 +89,7 @@ def define_flags(cls, fv: flags.FlagValues): flags.DEFINE_string( "env_id", default_env_id(), - "The env_id, used along with project to identify gcp settings", + "The env_id, used along with project to identify `gcp_settings`.", **common_kwargs, ) flags.DEFINE_string( diff --git a/axlearn/cloud/gcp/jobs/launch.py b/axlearn/cloud/gcp/jobs/launch.py index d0bb7e656..8b1a76b92 100644 --- a/axlearn/cloud/gcp/jobs/launch.py +++ b/axlearn/cloud/gcp/jobs/launch.py @@ -181,7 +181,7 @@ class BaseBastionManagedJob(Job): class Config(Job.Config): """Configures BaseBastionManagedJob.""" - # Used along with project to identify gcp settings + # Used along with project to identify `gcp_settings`. env_id: Optional[str] = None # Where to run the remote job. zone: Required[str] = REQUIRED diff --git a/axlearn/cloud/gcp/utils.py b/axlearn/cloud/gcp/utils.py index 7606ec833..e85a5f88d 100644 --- a/axlearn/cloud/gcp/utils.py +++ b/axlearn/cloud/gcp/utils.py @@ -27,7 +27,7 @@ def common_flags(**kwargs): flags.DEFINE_string("project", None, "The GCP project name.", **kwargs) flags.DEFINE_string("zone", None, "The GCP zone name.", **kwargs) flags.DEFINE_string( - "env_id", None, "The env_id, used along with project to identify gcp settings", **kwargs + "env_id", None, "The env_id, used along with project to identify `gcp_settings`.", **kwargs )