Skip to content

Commit d7f3ebd

Browse files
committed
feat: add readme + examples
1 parent 8101a58 commit d7f3ebd

File tree

6 files changed

+342
-2
lines changed

6 files changed

+342
-2
lines changed

README.md

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# TPUtils
2+
3+
Babysit your preemptible TPUs - in python
4+
5+
## Usage
6+
7+
### Long-running preemptible training
8+
9+
For example, the following code can be used to create a production-ready v3-256 using
10+
the [HomebrewNLP-Jax](https://github.com/HomebrewNLP/HomebrewNLP-Jax) codebase (
11+
see [examples/pod.py](https://github.com/clashluke/tputils/blob/main/examples/pod.py) for an executable version):
12+
13+
```PYTHON
14+
import dataclasses
15+
import typing
16+
from netrc import netrc
17+
18+
import wandb
19+
import yaml
20+
21+
from tputils import exec_command, exec_on_tpu, send_to_tpu, start_single
22+
23+
_, _, wandb_key = netrc().authenticators("api.wandb.ai")
24+
25+
26+
@dataclasses.dataclass
27+
class Context:
28+
retry: int
29+
30+
31+
ZONE = "europe-west4-a"
32+
HOST = "big-pod"
33+
RUN_NAME = "256-core-tpu"
34+
35+
36+
def load_config(ctx: Context):
37+
with open("config.yaml", 'r') as f:
38+
config = f.read()
39+
config = yaml.safe_load(config)
40+
41+
wandb_api = wandb.Api()
42+
config["training"]["do_checkpoint"] = True
43+
base_checkpoint_path = config["training"]["checkpoint_path"]
44+
45+
start_step = 0
46+
for run in wandb_api.runs(f"{config['wandb']['entity']}/{config['wandb']['project']}"):
47+
if run.name == config['wandb']['name']:
48+
start_step = run.summary["_step"]
49+
break
50+
start_step -= start_step % config["training"]["checkpoint_interval"]
51+
52+
config["training"]["start_step"] = start_step
53+
config["wandb"]["name"] = f"{RUN_NAME}-{ctx.retry}"
54+
if ctx.retry > 0:
55+
config["training"]["checkpoint_load_path"] = config["training"]["checkpoint_path"]
56+
config["training"]["checkpoint_path"] = f"{base_checkpoint_path}-{ctx.retry}"
57+
return yaml.dump(config)
58+
59+
60+
def start_fn(ctx: Context, worker: int):
61+
"""
62+
This function gets executed in threads to start a run on a new TPU. It receives the context object returned by
63+
`creation_callback` as well as the worker id which corresponds to the slice id this code was executed on in a
64+
multi-host setup. For single-host setups, such as v3-8s, the "worker" will always be set to 0.
65+
Ideally, it'd copy necessary files to the TPU and then run those. Here, `exec_command` can be used to create an
66+
execution command that automatically spawns a `screen` session which persists even when the SSH connection gets cut.
67+
"""
68+
send_to_tpu(ZONE, HOST, "config.yaml", load_config(ctx), worker)
69+
cmd = exec_command(repository="https://github.com/HomebrewNLP/HomebrewNLP-Jax", wandb_key=wandb_key)
70+
send_to_tpu(ZONE, HOST, "setup.sh", cmd, worker)
71+
exec_on_tpu(ZONE, HOST, "bash setup.sh", worker)
72+
73+
74+
def creation_callback(host: str, ctx: typing.Optional[Context]) -> Context:
75+
"""
76+
The `creation_callback` is called once whenever a new TPU gets created and can be used to persist state
77+
(such as retry counters) across multiple invocations.
78+
"""
79+
if ctx is None: # first invocation
80+
return Context(0)
81+
ctx.retry += 1
82+
return ctx
83+
84+
85+
def main(service_account: str, tpu_version: int = 3, slices: int = 32, preemptible: bool = True):
86+
start_single(host=HOST, tpu_version=tpu_version, zone=ZONE, preemptible=preemptible,
87+
service_account=service_account, slices=slices, start_fn=start_fn,
88+
creation_callback=creation_callback)
89+
```
90+
91+
### Sweeps
92+
93+
Similarly, large swarms of instances can be launched trivially using TPUtils. Here, we largely do the same setup as
94+
above, but call `start_multiple` instead of `start_single`, which takes the additional argument `tpus` specifying the
95+
number of TPUs that should be launched and babysat. Depending on capacity and quota, the actual number of TPUs you get
96+
might be lower than the number of TPUs specified.
97+
98+
```PYTHON
99+
def main(service_account: str, tpus: int, tpu_version: int = 3, slices: int = 32, preemptible: bool = True):
100+
start_multiple(prefix=HOST, tpu_version=tpu_version, zone=ZONE, preemptible=preemptible,
101+
service_account=service_account, slices=slices, start_fn=start_fn,
102+
creation_callback=creation_callback, tpus=tpus)
103+
```
104+
105+
However, this would simply launch the same run many times. If you instead plan to register them with a
106+
[WandB Sweep](https://docs.wandb.ai/guides/sweeps/configuration), you need to modify the `start_fn` to join the wandb
107+
sweep.\
108+
By patching in the code below, TPUtils will start and maintain a large swarm of TPUs all working towards the same
109+
hyperparameter optimization problem.
110+
111+
```PYTHON
112+
with open("sweep.yaml", 'r') as f: # sweep config passed straight to wandb
113+
config = yaml.safe_load(f.read())
114+
sweep_id = wandb.sweep(config, entity="homebrewnlp", project="gpt")
115+
116+
117+
def start_fn(ctx: Context, worker: int):
118+
cmd = exec_command(repository="https://github.com/HomebrewNLP/HomebrewNLP-Jax", wandb_key=wandb_key,
119+
run_command=f"/home/ubuntu/.local/bin/wandb agent {sweep_id}")
120+
send_to_tpu(ZONE, HOST, "setup.sh", cmd, worker)
121+
exec_on_tpu(ZONE, HOST, "bash setup.sh", worker)
122+
```
123+
124+
The full executable code can be found in [examples/sweep.py](https://github.com/clashluke/tputils/blob/main/examples/sweep.py).
125+
126+
Similarly, the `start_fn` could be adapted to start an inference server
127+
for [HomebrewNLP](https://github.com/HomebrewNLP/HomebrewNLP-Jax/)
128+
or [Craiyon](https://huggingface.co/spaces/dalle-mini/dalle-mini) or even execute machine learning unit-tests in
129+
parallel.

build.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#!/bin/sh
# Build the source and wheel distributions and upload them with twine.
# Abort on the first failing step so that a failed build never reaches the upload.
set -eu

# Remove artifacts of previous builds; `twine upload dist/*` would otherwise pick them up again.
rm -rf dist/*
python3 setup.py sdist bdist_wheel
twine upload dist/*

examples/pod.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import argparse
2+
import dataclasses
3+
import typing
4+
from netrc import netrc
5+
6+
import wandb
7+
import yaml
8+
9+
from tputils import exec_command, exec_on_tpu, send_to_tpu, start_single, synchronous_deletion
10+
11+
_, _, wandb_key = netrc().authenticators("api.wandb.ai")
12+
13+
14+
@dataclasses.dataclass
class Context:
    """
    State that `creation_callback` builds on its first invocation and that every `start_fn` call receives.
    """
    retry: int  # incremented by `creation_callback` on every invocation after the first
    zone: str  # zone passed to `send_to_tpu` / `exec_on_tpu`
    host: str  # TPU name passed to `send_to_tpu` / `exec_on_tpu`
    branch: str  # forwarded to `exec_command(branch=...)`
    run_name: str  # prefix of the WandB run name (`<run_name>-<retry>`)
    data_path: str  # written to `config["data"]["path"]` by `load_config`
    config_path: str  # local path of the YAML config read by `load_config`
23+
24+
25+
def load_config(ctx: Context):
    """
    Build the YAML config that is sent to the TPU for the current retry.

    Reads the local config at `ctx.config_path`, enables checkpointing, derives `start_step` from the WandB run whose
    name equals the config's `wandb.name`, and gives every retry its own run name and checkpoint path.

    :param ctx: context object created by `creation_callback`
    :return: the patched config, serialised with `yaml.dump`
    """
    # NOTE(review): nesting was reconstructed from a listing without indentation; the final `checkpoint_path`
    # assignment is assumed to be unconditional - confirm against the repository.
    with open(ctx.config_path, 'r') as f:
        config = yaml.safe_load(f)

    wandb_api = wandb.Api()
    training = config["training"]
    training["do_checkpoint"] = True
    base_checkpoint_path = training["checkpoint_path"]

    start_step = 0
    for run in wandb_api.runs(f"{config['wandb']['entity']}/{config['wandb']['project']}"):
        if run.name == config['wandb']['name']:
            # A run that never logged anything has no "_step" entry; resume from step 0 instead of raising KeyError.
            start_step = run.summary.get("_step", 0)
            break
    # Round down to a multiple of the checkpoint interval.
    start_step -= start_step % training["checkpoint_interval"]

    training["start_step"] = start_step
    config["data"]["path"] = ctx.data_path
    config["wandb"]["name"] = f"{ctx.run_name}-{ctx.retry}"
    if ctx.retry > 0:
        # NOTE(review): this restores from the unsuffixed path of the config file, while every attempt saves to
        # "<path>-<retry>" below - confirm that this is the intended restore location.
        training["checkpoint_load_path"] = training["checkpoint_path"]
    training["checkpoint_path"] = f"{base_checkpoint_path}-{ctx.retry}"
    return yaml.dump(config)
48+
49+
50+
def start_fn(ctx: Context, worker: int):
    """
    This function gets executed in threads to start a run on a new TPU. It receives the context object returned by
    `creation_callback` as well as the worker id which corresponds to the slice id this code was executed on in a
    multi-host setup. For single-host setups, such as v3-8s, the "worker" will always be set to 0.
    Ideally, it'd copy necessary files to the TPU and then run those. Here, `exec_command` can be used to create an
    execution command that automatically spawns a `screen` session which persists even when the SSH connection gets cut.

    :param ctx: context object returned by `creation_callback`
    :param worker: index of the worker the files are sent to and the command is executed on
    """
    # Ship the per-retry config first so that it is in place before setup.sh starts the run.
    send_to_tpu(ctx.zone, ctx.host, "config.yaml", load_config(ctx), worker)
    cmd = exec_command(repository="https://github.com/HomebrewNLP/HomebrewNLP-Jax", wandb_key=wandb_key,
                       branch=ctx.branch)
    send_to_tpu(ctx.zone, ctx.host, "setup.sh", cmd, worker)
    exec_on_tpu(ctx.zone, ctx.host, "bash setup.sh", worker)
63+
64+
65+
def parse_args():
    """
    Parse the command line of the pod example.

    `--host` is always required. `--run-name` and `--config-path` are required unless `--cleanup` is set, so that a
    missing value is reported immediately instead of failing later inside the worker threads.

    :return: the parsed `argparse.Namespace`
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, required=True, help="Name of the TPU")
    parser.add_argument("--tpu-version", type=int, default=3, help="Which TPU version to create (v2-8 or v3-8)")
    parser.add_argument("--zone", type=str, default="europe-west4-a", help="GCP Zone TPUs get created in")
    parser.add_argument("--data-path", type=str, default="gs://ggpt4/the-char-pile/",
                        help="Where the data is stored. Should be changed to a bucket in the correct region")
    parser.add_argument("--preemptible", default=1, type=int,
                        help="Whether to create preemptible or non-preemptible TPUs")
    parser.add_argument("--service-account", type=str,
                        help="Service account that controls permissions of TPU (for example, to ensure EU TPUs "
                             "won't use US data)")
    parser.add_argument("--branch", type=str, default="main", help="Branch on github to use")
    parser.add_argument("--slices", default=1, type=int,
                        help="How many TPU slices each TPU should have (1=>vX-8, 4=>vX-32)")
    parser.add_argument("--run-name", type=str, help="Prefix to use for all runs on WandB")
    parser.add_argument("--config-path", type=str, help="Path to config.yaml")
    parser.add_argument("--cleanup", default=0, type=int,
                        help="Instead of running something new, kill all tpus. 1 or 0 for y/n")
    args = parser.parse_args()
    # Cleanup only needs host and zone; launching additionally needs a run name and a config file.
    if not args.cleanup and (args.run_name is None or args.config_path is None):
        parser.error("--run-name and --config-path are required unless --cleanup is set")
    return args
87+
88+
89+
def main():
    """
    Entry point of the pod example.

    With `--cleanup` set, the TPU is deleted via `synchronous_deletion` and nothing is started. Otherwise the parsed
    arguments are handed to `start_single` together with `start_fn` and a `creation_callback` that tracks retries.
    """
    args = parse_args()
    if args.cleanup:
        synchronous_deletion("", args.host, args.zone)
        return

    def creation_callback(host: str, ctx: typing.Optional[Context]) -> Context:
        """
        Create the context on the first invocation and count retries afterwards.

        `start_single` calls this as `creation_callback(host, ctx)`, so it has to accept the host as first argument.
        """
        if ctx is None:  # first invocation
            return Context(retry=0, zone=args.zone, host=host, branch=args.branch, run_name=args.run_name,
                           data_path=args.data_path, config_path=args.config_path)
        ctx.retry += 1
        return ctx

    # `--preemptible` is parsed as int (1/0) while `start_single` annotates the parameter as bool.
    return start_single(args.host, args.tpu_version, args.zone, bool(args.preemptible), args.service_account,
                        args.slices, start_fn, creation_callback)
104+
105+
106+
if __name__ == '__main__':
107+
main()

examples/sweep.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import argparse
2+
import dataclasses
3+
import typing
4+
from netrc import netrc
5+
6+
import wandb
7+
import yaml
8+
9+
from tputils import delete_all, exec_command, exec_on_tpu, send_to_tpu, start_multiple
10+
11+
_, _, wandb_key = netrc().authenticators("api.wandb.ai")
12+
13+
14+
@dataclasses.dataclass
class Context:
    """
    State that `creation_callback` builds for a TPU and that every `start_fn` call for that TPU receives.
    """
    zone: str  # zone passed to `send_to_tpu` / `exec_on_tpu`
    host: str  # TPU name handed to `creation_callback`
    sweep_id: str  # value returned by `wandb.sweep`, used in the `wandb agent` command
19+
20+
21+
def start_fn(ctx: Context, worker: int):
    """
    Start a WandB sweep agent on one worker of a TPU.

    :param ctx: context object returned by `creation_callback`
    :param worker: index of the worker the setup script is sent to and executed on
    """
    # Instead of the repository's default run command, join the sweep created in `main`.
    cmd = exec_command(repository="https://github.com/HomebrewNLP/HomebrewNLP-Jax", wandb_key=wandb_key,
                       run_command=f"/home/ubuntu/.local/bin/wandb agent {ctx.sweep_id}")
    send_to_tpu(ctx.zone, ctx.host, "setup.sh", cmd, worker)
    exec_on_tpu(ctx.zone, ctx.host, "bash setup.sh", worker)
26+
27+
28+
def parse_args():
    """
    Parse the command line of the sweep example.

    `--prefix`, `--config-path` and `--tpus` are required unless `--cleanup` is set, so that a missing value is
    reported immediately instead of surfacing as an error deep inside `main`.

    :return: the parsed `argparse.Namespace`
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--prefix", type=str, help="Prefix used to identify TPUs")
    parser.add_argument("--tpus", type=int, help="Number of TPUs that should be launched")
    parser.add_argument("--tpu-version", type=int, default=3, help="Which TPU version to create (v2-8 or v3-8)")
    parser.add_argument("--zone", type=str, default="europe-west4-a", help="GCP Zone TPUs get created in")
    parser.add_argument("--preemptible", default=1, type=int,
                        help="Whether to create preemptible or non-preemptible TPUs")
    parser.add_argument("--service-account", type=str,
                        help="Service account that controls permissions of TPU (for example, to ensure EU TPUs "
                             "won't use US data)")
    parser.add_argument("--branch", type=str, default="main", help="Branch on github to use")
    parser.add_argument("--slices", default=1, type=int,
                        help="How many TPU slices each TPU should have (1=>vX-8, 4=>vX-32)")
    parser.add_argument("--config-path", type=str, help="Path to sweep's config.yaml")
    parser.add_argument("--cleanup", default=0, type=int,
                        help="Instead of running something new, kill all tpus. 1 or 0 for y/n")
    args = parser.parse_args()
    if not args.cleanup and None in (args.prefix, args.config_path, args.tpus):
        parser.error("--prefix, --config-path and --tpus are required unless --cleanup is set")
    return args


def main():
    """
    Entry point of the sweep example.

    With `--cleanup` set, the TPUs are removed via `delete_all` and nothing is started. Otherwise a WandB sweep is
    created from the YAML file at `--config-path` and `start_multiple` is asked for `--tpus` TPUs that run
    `start_fn`.
    """
    args = parse_args()

    if args.cleanup:
        return delete_all(args.prefix, args.zone)

    with open(args.config_path, 'r') as f:  # sweep config passed straight to wandb
        config = yaml.safe_load(f)
    sweep_id = wandb.sweep(config, entity="homebrewnlp", project="gpt")

    def creation_callback(host: str, ctx: typing.Optional[Context]) -> Context:
        """Create the context for `host` on the first invocation and reuse it afterwards."""
        if ctx is None:
            return Context(zone=args.zone, host=host, sweep_id=sweep_id)
        return ctx

    # The parser defines `--prefix` (there is no `--host`), and `--tpus` sets the size of the swarm.
    # `--preemptible` is parsed as int (1/0); it is passed on as a bool.
    return start_multiple(args.prefix, args.tpu_version, args.zone, bool(args.preemptible), args.service_account,
                          args.slices, start_fn, creation_callback, args.tpus)
65+
66+
67+
if __name__ == '__main__':
68+
main()

setup.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import setuptools

# The README is used as the long description; read it as UTF-8 regardless of the platform's default encoding.
with open('README.md', encoding='utf-8') as f:
    README = f.read()

setuptools.setup(
    author="Lucas Nestler",
    author_email="[email protected]",
    # Name and URL previously pointed at a different project ("revlib"); this package is imported as `tputils`
    # and the README links to github.com/clashluke/tputils.
    name='tputils',
    license='BSD',
    description='Babysit your preemptible TPUs - in python',
    version='0.0.1',
    long_description=README,
    url='https://github.com/clashluke/tputils',
    packages=setuptools.find_packages(),
    python_requires=">=3.7",
    long_description_content_type="text/markdown",
    install_requires=[],
    classifiers=[
        # Trove classifiers
        # (https://pypi.python.org/pypi?%3Aaction=list_classifiers)
        'Development Status :: 5 - Production/Stable',
        'License :: OSI Approved :: BSD License',
        'Programming Language :: Python',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
        'Topic :: Software Development :: Libraries',
        'Topic :: Software Development :: Libraries :: Python Modules',
        'Intended Audience :: Developers',
    ],
)

tputils/main.py renamed to tputils/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def recreate(host: str, zone: str, tpu_version: int, preemptible: bool, service_
124124

125125
def start_single(host: str, tpu_version: int, zone: str, preemptible: bool, service_account: str, slices: int,
126126
start_fn: typing.Callable[[typing.Any, int], None],
127-
created_callback: typing.Callable[[typing.Any], typing.Any],
127+
creation_callback: typing.Callable[[str, typing.Any], typing.Any],
128128
creation_semaphore: typing.Optional[typing.ContextManager] = None):
129129
_, _, wandb_key = netrc.netrc().authenticators("api.wandb.ai")
130130

@@ -136,7 +136,7 @@ def start_single(host: str, tpu_version: int, zone: str, preemptible: bool, serv
136136
try:
137137
with creation_semaphore:
138138
recreate(host, zone, tpu_version, preemptible, service_account, slices)
139-
ctx = created_callback(ctx)
139+
ctx = creation_callback(host, ctx)
140140
threads = [threading.Thread(target=start_fn, args=(ctx, i)) for i in range(slices)]
141141
for t in threads:
142142
t.start()

0 commit comments

Comments
 (0)