# TPUtils

Babysit your preemptible TPUs - in Python

## Usage

### Long-running preemptible training

For example, the following code can be used to create a production-ready v3-256 using
the [HomebrewNLP-Jax](https://github.com/HomebrewNLP/HomebrewNLP-Jax) codebase
(see [examples/pod.py](https://github.com/clashluke/tputils/blob/main/examples/pod.py) for an executable version):

```PYTHON
import dataclasses
import typing
from netrc import netrc

import wandb
import yaml

from tputils import exec_command, exec_on_tpu, send_to_tpu, start_single

# Read the WandB API key from ~/.netrc so it never has to be hardcoded.
_, _, wandb_key = netrc().authenticators("api.wandb.ai")


@dataclasses.dataclass
class Context:
    retry: int


ZONE = "europe-west4-a"
HOST = "big-pod"
RUN_NAME = "256-core-tpu"


def load_config(ctx: Context):
    with open("config.yaml", 'r') as f:
        config = yaml.safe_load(f.read())

    wandb_api = wandb.Api()
    config["training"]["do_checkpoint"] = True
    base_checkpoint_path = config["training"]["checkpoint_path"]

    # Find the last logged step of the previous run and round it down to the last written checkpoint.
    start_step = 0
    for run in wandb_api.runs(f"{config['wandb']['entity']}/{config['wandb']['project']}"):
        if run.name == config['wandb']['name']:
            start_step = run.summary["_step"]
            break
    start_step -= start_step % config["training"]["checkpoint_interval"]

    config["training"]["start_step"] = start_step
    config["wandb"]["name"] = f"{RUN_NAME}-{ctx.retry}"
    if ctx.retry > 0:  # resume from the previous checkpoint and write to a fresh path
        config["training"]["checkpoint_load_path"] = config["training"]["checkpoint_path"]
        config["training"]["checkpoint_path"] = f"{base_checkpoint_path}-{ctx.retry}"
    return yaml.dump(config)


def start_fn(ctx: Context, worker: int):
    """
    This function gets executed in threads to start a run on a new TPU. It receives the context object returned by
    `creation_callback` as well as the worker id, which corresponds to the slice this code gets executed on in a
    multi-host setup. For single-host setups, such as v3-8s, the worker will always be 0.
    Ideally, it'd copy the necessary files to the TPU and then run them. Here, `exec_command` can be used to create
    an execution command that automatically spawns a `screen` session which persists even when the SSH connection
    gets cut.
    """
    send_to_tpu(ZONE, HOST, "config.yaml", load_config(ctx), worker)
    cmd = exec_command(repository="https://github.com/HomebrewNLP/HomebrewNLP-Jax", wandb_key=wandb_key)
    send_to_tpu(ZONE, HOST, "setup.sh", cmd, worker)
    exec_on_tpu(ZONE, HOST, "bash setup.sh", worker)


def creation_callback(host: str, ctx: typing.Optional[Context]) -> Context:
    """
    The `creation_callback` is called once whenever a new TPU gets created and can be used to persist state
    (such as retry counters) across multiple invocations.
    """
    if ctx is None:  # first invocation
        return Context(0)
    ctx.retry += 1
    return ctx


def main(service_account: str, tpu_version: int = 3, slices: int = 32, preemptible: bool = True):
    start_single(host=HOST, tpu_version=tpu_version, zone=ZONE, preemptible=preemptible,
                 service_account=service_account, slices=slices, start_fn=start_fn,
                 creation_callback=creation_callback)
```
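
From there, `main` only needs to be called with a service account that is allowed to create TPUs. A minimal sketch of
an entry point, with a placeholder service account:

```PYTHON
if __name__ == "__main__":
    # Placeholder service account; use one with TPU-creation permissions in your project.
    main(service_account="tpu-runner@my-project.iam.gserviceaccount.com")
```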

### Sweeps

Similarly, large swarms of instances can be launched trivially using TPUtils. Here, we largely do the same setup as
above, but call `start_multiple` instead of `start_single`, which takes the additional argument `tpus`, specifying the
number of TPUs that should be launched and babysat. Depending on capacity and quota, the actual number of TPUs you get
might be lower than the number specified.

```PYTHON
def main(service_account: str, tpus: int, tpu_version: int = 3, slices: int = 32, preemptible: bool = True):
    start_multiple(prefix=HOST, tpu_version=tpu_version, zone=ZONE, preemptible=preemptible,
                   service_account=service_account, slices=slices, start_fn=start_fn,
                   creation_callback=creation_callback, tpus=tpus)
```

However, this would simply launch the same run many times. If you instead plan to register the runs with a
[WandB Sweep](https://docs.wandb.ai/guides/sweeps/configuration), you need to modify the `start_fn` to join the wandb
sweep.\
By patching in the code below, TPUtils will start and maintain a large swarm of TPUs, all working on the same
hyperparameter-optimization problem.

```PYTHON
with open("sweep.yaml", 'r') as f:  # sweep config passed straight to wandb
    config = yaml.safe_load(f.read())
sweep_id = wandb.sweep(config, entity="homebrewnlp", project="gpt")


def start_fn(ctx: Context, worker: int):
    # Instead of launching the training script directly, each worker joins the sweep as a wandb agent.
    cmd = exec_command(repository="https://github.com/HomebrewNLP/HomebrewNLP-Jax", wandb_key=wandb_key,
                       run_command=f"/home/ubuntu/.local/bin/wandb agent {sweep_id}")
    send_to_tpu(ZONE, HOST, "setup.sh", cmd, worker)
    exec_on_tpu(ZONE, HOST, "bash setup.sh", worker)
```
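
For reference, `wandb.sweep` also accepts a plain dict in place of the YAML file. A minimal sketch with placeholder
metric and parameter names, not the actual HomebrewNLP-Jax hyperparameters:

```PYTHON
# Hypothetical sweep config; "loss" and "learning_rate" are placeholders for the project's real names.
sweep_config = {"method": "bayes",
                "metric": {"name": "loss", "goal": "minimize"},
                "parameters": {"learning_rate": {"min": 1e-5, "max": 1e-3}}}
sweep_id = wandb.sweep(sweep_config, entity="homebrewnlp", project="gpt")
```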

The full executable code can be found in
[examples/sweep.py](https://github.com/clashluke/tputils/blob/main/examples/sweep.py).

Similarly, the `start_fn` could be adapted to start an inference server for
[HomebrewNLP](https://github.com/HomebrewNLP/HomebrewNLP-Jax/) or
[Craiyon](https://huggingface.co/spaces/dalle-mini/dalle-mini), or even to execute machine-learning unit tests in
parallel, as sketched below.
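
Swapping the `run_command` is enough to turn the swarm into a parallel test runner. A minimal sketch, in which the
test entry point is an assumption about the target repository:

```PYTHON
def start_fn(ctx: Context, worker: int):
    # Hypothetical test entry point; adjust the command to the repository's actual test suite.
    cmd = exec_command(repository="https://github.com/HomebrewNLP/HomebrewNLP-Jax", wandb_key=wandb_key,
                       run_command="python3 -m pytest -x unit_tests.py")
    send_to_tpu(ZONE, HOST, "setup.sh", cmd, worker)
    exec_on_tpu(ZONE, HOST, "bash setup.sh", worker)
```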