
Commit a672bb8

gwenzek, vctrmn, and renerocksai authored
examples: add modernBERT (zml#192)
ModernBERT implementation from Victor (@vctrmn). Tested with https://huggingface.co/answerdotai/ModernBERT-base.

I started from his PR zml#149 and cleaned up the code a bit to make it closer to what I consider "idiomatic ZML", even though that is still somewhat in flux since we are so young.

---------

Co-authored-by: vctrmn <[email protected]>
Co-authored-by: Rene Schallner <[email protected]>
1 parent d712ab7 commit a672bb8

12 files changed, +1030 −10 lines

examples/MODULE.bazel

Lines changed: 47 additions & 0 deletions
@@ -139,6 +139,53 @@ http_file(
     url = "https://github.com/karpathy/llama2.c/raw/c02865df300f3bd9e567ce061000dc23bf785a17/tokenizer.bin",
 )
 
+# ModernBERT
+huggingface.model(
+    name = "ModernBERT-base",
+    build_file_content = """\
+package(default_visibility = ["//visibility:public"])
+filegroup(
+    name = "model",
+    srcs = ["model.safetensors"],
+)
+
+filegroup(
+    name = "tokenizer",
+    srcs = ["tokenizer.json"],
+)
+""",
+    commit = "94032bb66234a691cf6248265170006a7ced4970",
+    includes = [
+        "model.safetensors",
+        "tokenizer.json",
+    ],
+    model = "answerdotai/ModernBERT-base",
+)
+use_repo(huggingface, "ModernBERT-base")
+
+huggingface.model(
+    name = "ModernBERT-large",
+    build_file_content = """\
+package(default_visibility = ["//visibility:public"])
+filegroup(
+    name = "model",
+    srcs = ["model.safetensors"],
+)
+
+filegroup(
+    name = "tokenizer",
+    srcs = ["tokenizer.json"],
+)
+""",
+    commit = "4bbcbf40bed02ce487125bcb3c897ea9bdc88340",
+    includes = [
+        "model.safetensors",
+        "tokenizer.json",
+    ],
+    model = "answerdotai/ModernBERT-large",
+)
+use_repo(huggingface, "ModernBERT-large")
+
 bazel_dep(name = "rules_rust", version = "0.57.1")
 rust = use_extension("@rules_rust//rust:extensions.bzl", "rust")
 rust.toolchain(
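The extension above pins each checkpoint to an exact Hugging Face commit, which keeps the Bazel fetch reproducible. To look at the same files outside Bazel, here is a minimal sketch (not part of this commit, assuming `huggingface_hub` is installed; the revision is the ModernBERT-base commit pinned above):

# Hypothetical helper, not part of this commit: fetch the pinned
# ModernBERT-base weights and tokenizer for manual inspection.
from huggingface_hub import hf_hub_download

REVISION = "94032bb66234a691cf6248265170006a7ced4970"  # same pin as in MODULE.bazel

weights = hf_hub_download(
    repo_id="answerdotai/ModernBERT-base",
    filename="model.safetensors",
    revision=REVISION,
)
tokenizer = hf_hub_download(
    repo_id="answerdotai/ModernBERT-base",
    filename="tokenizer.json",
    revision=REVISION,
)
print(weights)
print(tokenizer)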

examples/modernbert/BUILD.bazel

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+load("@zml//bazel:zig.bzl", "zig_cc_binary")
+
+zig_cc_binary(
+    name = "modernbert",
+    srcs = ["modernbert.zig"],
+    main = "main.zig",
+    deps = [
+        "@com_github_hejsil_clap//:clap",
+        "@zml//async",
+        "@zml//stdx",
+        "@zml//zml",
+    ],
+)
+
+cc_binary(
+    name = "ModernBERT-base",
+    args = [
+        "--model=$(location @ModernBERT-base//:model.safetensors)",
+        "--tokenizer=$(location @ModernBERT-base//:tokenizer)",
+        "--num-attention-heads=12",
+        "--tie-word-embeddings=true",
+    ],
+    data = [
+        "@ModernBERT-base//:model.safetensors",
+        "@ModernBERT-base//:tokenizer",
+    ],
+    deps = [":modernbert_lib"],
+)
+
+cc_binary(
+    name = "ModernBERT-large",
+    args = [
+        "--model=$(location @ModernBERT-large//:model.safetensors)",
+        "--tokenizer=$(location @ModernBERT-large//:tokenizer)",
+        "--num-attention-heads=16",
+        "--tie-word-embeddings=true",
+    ],
+    data = [
+        "@ModernBERT-large//:model.safetensors",
+        "@ModernBERT-large//:tokenizer",
+    ],
+    deps = [":modernbert_lib"],
+)
+
+zig_cc_binary(
+    name = "test-implementation",
+    srcs = ["modernbert.zig"],
+    args = [
+        "--model=$(location @ModernBERT-base//:model.safetensors)",
+    ],
+    data = [
+        "@ModernBERT-base//:model.safetensors",
+    ],
+    main = "test.zig",
+    tags = [
+        "no_ci",
+    ],
+    deps = [
+        "@com_github_hejsil_clap//:clap",
+        "@zml//async",
+        "@zml//zml",
+    ],
+)
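Assuming the standard ZML examples workspace (this file lives at examples/modernbert/BUILD.bazel), the wrapper targets can presumably be run directly, e.g. `bazel run //modernbert:ModernBERT-base` from the examples directory: the args blocks point the shared modernbert binary at the pinned weights and tokenizer and select the per-variant attention-head count (12 for base, 16 for large).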
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+import logging
+import torch
+from transformers import pipeline
+from tools.zml_utils import ActivationCollector
+
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s"
+)
+log = logging.getLogger(__name__)
+
+MODEL_NAME: str = "answerdotai/ModernBERT-base"
+
+
+def main() -> None:
+    try:
+        log.info("Start running main()")
+
+        log.info(f"CPU capability : `{torch.backends.cpu.get_cpu_capability()}`")
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        log.info(f"Loading model : `{MODEL_NAME}`")
+
+        fill_mask_pipeline = pipeline(
+            "fill-mask",
+            model=MODEL_NAME,
+            device_map=device,
+        )
+        model, tokenizer = fill_mask_pipeline.model, fill_mask_pipeline.tokenizer
+        log.info(
+            f"Model loaded successfully {model.config.architectures} - `{model.config.torch_dtype}` - {tokenizer.model_max_length} max tokens"  # noqa: E501
+        )
+
+        # Wrap the pipeline, and extract activations.
+        # Activations files can be huge for big models,
+        # so let's stop collecting after 1000 layers.
+        zml_pipeline = ActivationCollector(
+            fill_mask_pipeline, max_layers=1000, stop_after_first_step=True
+        )
+
+        input_text = "Paris is the [MASK] of France."
+        outputs, activations = zml_pipeline(input_text)
+        log.info(f"outputs : {outputs}")
+
+        filename = MODEL_NAME.split("/")[-1] + ".activations.pt"
+        torch.save(activations, filename)
+        log.info(f"Saved {len(activations)} activations to {filename}")
+
+        log.info("End running main()")
+    except Exception as exception:
+        log.error(exception)
+        raise
+
+
+if __name__ == "__main__":
+    main()
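The script writes the collected activations to ModernBERT-base.activations.pt. A minimal sketch for reloading that file offline, assuming the collector returns a mapping from layer names to tensors (an assumption about tools.zml_utils, not something this script guarantees):

# Hypothetical inspection snippet, not part of this commit.
import torch

activations = torch.load("ModernBERT-base.activations.pt")
for name, tensor in activations.items():
    # Print each captured layer with its shape and dtype.
    print(name, tuple(tensor.shape), tensor.dtype)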
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+torch
+transformers==4.48.1
+accelerate
+numpy==1.26.4
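The transformers pin at 4.48.1 presumably matters because ModernBERT support only landed in the 4.48 release line; the numpy pin keeps the stack on a pre-2.0 version for compatibility.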
