diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py
index 22a93b675..1cba5503d 100644
--- a/axlearn/cloud/gcp/job.py
+++ b/axlearn/cloud/gcp/job.py
@@ -5,6 +5,7 @@ Note that these utilities do not handle resource management.
 """
 import atexit
+import importlib
 import io
 import logging
 import math
@@ -455,6 +456,7 @@ class Config(GKEJob.Config):
         enable_tpu_ici_resiliency: Optional[bool] = None
         location_hint: Optional[str] = None
         enable_tpu_smart_repair: bool = False
+        use_pathways: Optional[bool] = False
 
     @classmethod
     def define_flags(cls, fv: flags.FlagValues):
@@ -469,6 +471,9 @@ def define_flags(cls, fv: flags.FlagValues):
             "not all TPU types support this flag.",
             **common_kwargs,
         )
+        flags.DEFINE_boolean(
+            "use_pathways", False, "Whether the workload is pathways-enabled.", **common_kwargs
+        )
 
     @classmethod
     def from_flags(cls, fv: flags.FlagValues, **kwargs) -> Config:
@@ -492,7 +497,7 @@ def __init__(self, cfg: Config):
             raise NotImplementedError(f"Missing system characteristics for {self._tpu_type}")
         super().__init__(cfg)
         self._output_volume_mount = dict(name="shared-output", mountPath="/output")
-
+
     def _maybe_add_volume_mount(self, volume_mounts: list[dict], *, spec: Optional[VolumeMount]):
         if spec:
             volume_mounts.append(
                 dict(
                     name=spec.name,
                     mountPath=spec.mount_path,
                     readOnly=spec.read_only,
                 ),
             )
@@ -503,7 +508,17 @@ def _maybe_add_volume_mount(self, volume_mounts: list[dict], *, spec: Optional[V
 
-    def _build_container(self) -> Nested[Any]:
+    def _get_pathways_tpu_type(self, device: str) -> str:
+        pathways_tpu_devices = {
+            "v6e": "tpuv6e",
+            "v5p": "tpuv5",
+            "v5litepod": "tpuv5e",
+            "v4": "tpuv4",
+            "v3": "tpuv3",
+        }
+        return pathways_tpu_devices[device.split("-")[0].lower()]
+
+    def _build_container(self, job_type: Optional[str] = None) -> Nested[Any]:
         """Builds a config for a single container.
 
         Returns:
@@ -512,6 +527,7 @@ def _build_container(self) -> Nested[Any]:
         cfg: TPUGKEJob.Config = self.config
         system = USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS[self._tpu_type]
         volume_mounts = [self._output_volume_mount]
+        resources = {"limits": {}}
 
         self._maybe_add_volume_mount(volume_mounts, spec=cfg.gcsfuse_mount)
         if cfg.host_mounts:
@@ -522,15 +538,96 @@ def _build_container(self) -> Nested[Any]:
         if cfg.enable_tpu_ici_resiliency is not None:
             env_vars["ENABLE_ICI_RESILIENCY"] = str(cfg.enable_tpu_ici_resiliency).lower()
 
-        resources = {"limits": {"google.com/tpu": system.chips_per_vm}}
-        # Set request memory by host machine type.
-        machine_memory_gi = GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS.get(
-            system.gce_machine_type, None
-        )
-        if machine_memory_gi is not None:
-            request_memory_gi = machine_memory_gi * _MEMORY_REQUEST_PERCENTAGE
-            resources["limits"]["memory"] = f"{machine_memory_gi}Gi"
-            resources["requests"] = {"memory": f"{math.floor(request_memory_gi)}Gi"}
+        if not cfg.use_pathways:
+            resources = {"limits": {"google.com/tpu": system.chips_per_vm}}
+            # Set request memory by host machine type.
+            machine_memory_gi = GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS.get(
+                system.gce_machine_type, None
+            )
+            if machine_memory_gi is not None:
+                request_memory_gi = machine_memory_gi * _MEMORY_REQUEST_PERCENTAGE
+                resources["limits"]["memory"] = f"{machine_memory_gi}Gi"
+                resources["requests"] = {"memory": f"{math.floor(request_memory_gi)}Gi"}
+
+        container_name = cfg.name
+        args = []
+        image = self._bundler.id(cfg.name)
+        ports = [
+            dict(containerPort=8471),  # Port using which TPU VMs communicate.
+            dict(containerPort=8080),  # Port for MXLA coordinator.
+            dict(containerPort=8431),  # Port to export TPU runtime metrics.
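+            # NOTE: Pathways job types append their own gRPC port below
+            # (38677 for the worker and resource manager, 38676 for the proxy).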
+        ]
+
+        if cfg.use_pathways:
+            container_name = f"{cfg.name}-{job_type}"
+            volume_mounts.append(
+                dict(
+                    name="shared-tmp",
+                    mountPath="/tmp",
+                ),
+            )
+            staging_location = "gs://cloud-pathways-staging/tmp"
+            cluster = gcp_settings("gke_cluster")
+            rm_address = f"{cfg.name}-rm-0-0.{cfg.name}.default.svc.{cluster}-domain.:38677"
+
+            if job_type == "worker":
+                args.extend(
+                    [
+                        "--server_port=38677",
+                        f"--resource_manager_address={rm_address}",
+                        f"--gcs_scratch_location={staging_location}",
+                    ]
+                )
+                image = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest"
+                ports.append(dict(containerPort=38677))
+                resources = {"limits": {"google.com/tpu": system.chips_per_vm}}
+
+            elif job_type == "rm":
+                tpu_type = self._get_pathways_tpu_type(system.device_type)
+                args.extend(
+                    [
+                        "--server_port=38677",
+                        "--node_type=resource_manager",
+                        f"--gcs_scratch_location={staging_location}",
+                        f"--instance_count={system.vms_per_slice}",
+                        f"--instance_type={tpu_type}:{system.topology}",
+                    ]
+                )
+                env_vars.update(
+                    TPU_SKIP_MDS_QUERY="true",
+                    HOST_ADDRESS="$(JOBSET_NAME)-$(REPLICATED_JOB_NAME)-0-0.$(JOBSET_NAME)",
+                    REPLICATED_JOB_NAME=job_type,
+                    JOBSET_NAME=cfg.name,
+                )
+                image = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:latest"
+                resources["limits"]["memory"] = "8Gi"
+                resources["limits"]["cpu"] = "4"
+                ports.append(dict(containerPort=38677))
+
+            elif job_type == "proxy":
+                args.extend(
+                    [
+                        "--server_port=38676",
+                        f"--resource_manager_address={rm_address}",
+                        f"--gcs_scratch_location={staging_location}",
+                    ]
+                )
+                image = "us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:latest"
+                resources["limits"]["memory"] = "100Gi"
+                resources["limits"]["cpu"] = "24"
+                ports.append(dict(containerPort=38676))
+
+            elif job_type == "user":
+                resources["limits"]["memory"] = "100Gi"
+                resources["limits"]["cpu"] = "24"
+                proxy = (
+                    f"grpc://{cfg.name}-proxy-0-0.{cfg.name}.default.svc.{cluster}-domain.:38676"
+                )
+                env_vars.update(
+                    JAX_BACKEND_TARGET=proxy,
+                    XCLOUD_ENVIRONMENT="GCP",
+                    JOBSET_NAME=cfg.name,
+                )
 
         k8s_env_vars = [dict(name=k, value=str(v)) for k, v in env_vars.items()]
         k8s_env_vars.append(
@@ -555,18 +652,16 @@ def _build_container(self) -> Nested[Any]:
         )
 
         return dict(
-            name=cfg.name,
-            image=self._bundler.id(cfg.name),
+            name=container_name,
+            image=image,
             # https://cloud.google.com/kubernetes-engine/docs/how-to/tpus#tpu-chips-node-pool
             # https://cloud.google.com/kubernetes-engine/docs/how-to/tpu-multislice#run_workload
-            ports=[
-                dict(containerPort=8471),  # Port using which TPU VMs communicate.
-                dict(containerPort=8080),  # Port for MXLA coordinator.
-                dict(containerPort=8431),  # Port to export TPU runtime metrics.
-            ],
+            ports=ports,
             securityContext=dict(privileged=True),
             # TODO(markblee): Improve SIGTERM behavior for command.
-            command=["bash", "-c", cfg.command],
+            command=["bash", "-c", cfg.command]
+            if not cfg.use_pathways or job_type == "user"
+            else None,
+            args=args,
             resources=resources,
             # Env var values should always be strings.
             env=k8s_env_vars,
@@ -609,9 +704,10 @@ def _build_uploader_container(self) -> Nested[Any]:
             args=[sync_command],
             resources=resources,
             volumeMounts=volume_mounts,
         )
 
-    def _build_pod(self) -> Nested[Any]:
+    def _build_pod(self, job_type: Optional[str] = None) -> Nested[Any]:
         """Builds a config for a single Pod, which is a set of containers.
 
         https://kubernetes.io/docs/concepts/workloads/pods
@@ -665,22 +761,24 @@
         # Tier "0" corresponds to reserved; otherwise we use preemptible.
         tier = os.environ.get("BASTION_TIER", None)
 
-        if tier == "0" and cfg.reservation is not None:
-            logging.info("Found tier=%s in env. Using reservation=%s", tier, cfg.reservation)
-            selector.update({"cloud.google.com/reservation-name": cfg.reservation})
-            labels.update({"bastion-tier": "reserved"})
-        else:
-            logging.info("Found tier=%s in env. Using spot quota", tier)
-            selector.update({"cloud.google.com/gke-spot": "true"})
-            tolerations.append(
-                {
-                    "key": "cloud.google.com/gke-spot",
-                    "operator": "Equal",
-                    "value": "true",
-                    "effect": "NoSchedule",
-                }
-            )
-            labels.update({"bastion-tier": "spot"})
+        # Skip reservation/spot node selection for Pathways CPU jobs.
+        if job_type not in ("rm", "proxy", "user"):
+            if tier == "0" and cfg.reservation is not None:
+                logging.info("Found tier=%s in env. Using reservation=%s", tier, cfg.reservation)
+                selector.update({"cloud.google.com/reservation-name": cfg.reservation})
+                labels.update({"bastion-tier": "reserved"})
+            else:
+                logging.info("Found tier=%s in env. Using spot quota", tier)
+                selector.update({"cloud.google.com/gke-spot": "true"})
+                tolerations.append(
+                    {
+                        "key": "cloud.google.com/gke-spot",
+                        "operator": "Equal",
+                        "value": "true",
+                        "effect": "NoSchedule",
+                    }
+                )
+                labels.update({"bastion-tier": "spot"})
 
         if cfg.enable_tpu_ici_resiliency is not None:
             selector.update(
@@ -701,7 +799,7 @@ def _build_pod(self) -> Nested[Any]:
                     PRE_PROVISIONER_LABEL: cfg.name,
                 }
             )
-        else:
+        elif not cfg.use_pathways:
             # Used by GCP auto-provisioner.
             selector.update(
                 {
@@ -751,6 +849,27 @@ def _build_pod(self) -> Nested[Any]:
                 }
             )
 
+        if cfg.use_pathways:
+            volumes.append(
+                dict(
+                    hostPath=dict(
+                        path="/tmp",
+                        type="DirectoryOrCreate",
+                    ),
+                    name="shared-tmp",
+                )
+            )
+
+        if job_type in ("rm", "proxy", "user"):
+            selector.update({"cloud.google.com/gke-nodepool": f"cpu-{job_type}-np"})
+        else:
+            selector.update(
+                {
+                    "cloud.google.com/gke-tpu-accelerator": system.gke_accelerator,
+                    "cloud.google.com/gke-tpu-topology": system.topology,
+                }
+            )
+
         # Hardcode metadata.google.internal ip address to avoid transient DNS resolution issue.
         metadata_host_alias = dict(
             ip=_METADATA_GOOGLE_INTERNAL_IP,
@@ -765,15 +884,15 @@ def _build_pod(self) -> Nested[Any]:
             # https://kubernetes.io/docs/tasks/network/customize-hosts-file-for-pods/#adding-additional-entries-with-hostaliases
             hostAliases=[metadata_host_alias],
             nodeSelector={
-                "cloud.google.com/gke-tpu-accelerator": system.gke_accelerator,
-                "cloud.google.com/gke-tpu-topology": system.topology,
                 **selector,
             },
             tolerations=tolerations,
-            containers=[self._build_container()],
+            containers=[self._build_container(job_type)],
             initContainers=[self._build_uploader_container()],
             serviceAccountName=cfg.service_account,
             volumes=volumes,
+            hostNetwork=True,
+            dnsPolicy="ClusterFirstWithHostNet",
         )
 
         if cfg.priority_class:
@@ -784,7 +903,7 @@ def _build_pod(self) -> Nested[Any]:
             spec=spec,
         )
 
-    def _build_job(self) -> Nested[Any]:
+    def _build_job(self, job_type: Optional[str] = None) -> Nested[Any]:
         """Builds a config for a single Job, which is a set of Pods.
 
         https://kubernetes.io/docs/concepts/workloads/controllers/job/
@@ -793,6 +912,32 @@ def _build_job(self) -> Nested[Any]:
             A nested dict corresponding to a k8s Job config, including the job metadata and spec.
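+            For Pathways job types ("worker", "rm", "proxy", "user"), the spec is
+            specialized per job_type; with job_type=None the default TPU job is built.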
""" system = USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS[self._tpu_type] + + if job_type == "worker": + return dict( + metadata=dict( + annotations={ + # pylint: disable=line-too-long + "alpha.jobset.sigs.k8s.io/exclusive-topology": "cloud.google.com/gke-nodepool" + } + ), + spec=dict( + parallelism=system.vms_per_slice, + completions=system.vms_per_slice, + backoffLimit=system.vms_per_slice * 4, + template=self._build_pod(job_type), + ), + ) + elif job_type in ("rm", "proxy", "user"): + return dict( + spec=dict( + parallelism=1, + completions=1, + backoffLimit=0, + template=self._build_pod(job_type), + ), + ) + return dict( spec=dict( parallelism=system.vms_per_slice, @@ -812,31 +957,72 @@ def _build_jobset(self) -> Nested[Any]: """ cfg: TPUGKEJob.Config = self.config - annotations = { - # The exclusive topology annotation will ensure that all Pods will have affinity - # rules added that will ensure that they are fully scheduled on the same - # pod-slice node-pools. - "alpha.jobset.sigs.k8s.io/exclusive-topology": "cloud.google.com/gke-nodepool", - } + annotations, labels = {}, {} + + if not cfg.use_pathways: + annotations.update( + { + # The exclusive topology annotation will ensure that all Pods will have affinity + # rules added that will ensure that they are fully scheduled on the same + # pod-slice node-pools. + "alpha.jobset.sigs.k8s.io/exclusive-topology": "cloud.google.com/gke-nodepool", + } + ) + if cfg.queue: - annotations["kueue.x-k8s.io/queue-name"] = cfg.queue + if cfg.use_pathways: + labels["kueue.x-k8s.io/queue-name"] = cfg.queue + else: + annotations["kueue.x-k8s.io/queue-name"] = cfg.queue - return dict( - metadata=dict( - name=cfg.name, - annotations=annotations, - ), - spec=dict( + spec = dict( + failurePolicy=dict(maxRestarts=cfg.max_tries - 1), + replicatedJobs=[ + # NOTE: the suffix here impacts how long job names can be. + dict( + name="job", + replicas=cfg.accelerator.num_replicas, + template=self._build_job(), + ), + ], + ) + + if cfg.use_pathways: + logging.info("Building pathways jobset.") + spec = dict( failurePolicy=dict(maxRestarts=cfg.max_tries - 1), + successPolicy=dict(operator="All", targetReplicatedJobs=["user"]), replicatedJobs=[ - # NOTE: the suffix here impacts how long job names can be. dict( - name="job", + name="worker", replicas=cfg.accelerator.num_replicas, - template=self._build_job(), + template=self._build_job("worker"), + ), + dict( + name="rm", + replicas=1, + template=self._build_job("rm"), + ), + dict( + name="proxy", + replicas=1, + template=self._build_job("proxy"), + ), + dict( + name="user", + replicas=1, + template=self._build_job("user"), ), ], + ) + + return dict( + metadata=dict( + name=cfg.name, + annotations=annotations, + labels=labels, ), + spec=spec, ) def _delete(self):