Skip to content

[llama4 -> llama4_jax] Refactor to be a proper installable Python package #9

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

SamuelMarks
Copy link

@SamuelMarks SamuelMarks commented Apr 17, 2025

[llama4 -> llama4_jax] Refactor to be a proper installable Python package ; [llama4_jax/pyproject.toml] Add missing dependency ; [llama4_jax/README.md] Document new usage

Also I know you have path = Path(path).expanduser().absolute() but that doesn't provide nice --help text and should be expanded earlier anyway. I can remove your Path(path).expanduser() if you give the go-ahead.

PS: If you merge this PR I can send you a new PR for deepseek_r1_jax

…kage ; [llama4_jax/pyproject.toml] Add missing dependency ; [llama4_jax/README.md] Document new usage
Copy link
Collaborator

@rdyro rdyro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this PR! I left some comments

@@ -44,9 +37,9 @@ def main(model_path: str | Path, ckpt_path: str | Path):

additional_files = ["config.json", "tokenizer.json", "tokenizer_config.json"]
for additional_file in additional_files:
full_paths = list(model_path.glob(f"**/{additional_file}"))
full_paths = list(model_path.glob(os.path.join("**", additional_file)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we stick with pathlib.Path throughout here, it largely replaces os.path.*

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After 11 years maybe I can rely on it being available!

os.path.join is actually used by it internally IIRC; but sure can move to pathlib for my jax-ml/jax-llm-examples contributions.

)
parser.add_argument(
"--dest-root-path",
required=True,
default="~/",
default=os.path.join(os.path.expanduser("~"), ""),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Path("~/").expanduser()

@@ -21,7 +16,8 @@ def main(path: str | Path, suffix: str):
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument(
"--path", default="~/DeepSeek-R1-Distill-Llama-70B", required=True, help="Existing JAX model checkpoint path"
"--path", default=os.path.join(os.path.expanduser("~"), "DeepSeek-R1-Distill-Llama-70B"),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Path("~/DeepSeek-R1-Distill-Llama-70B").expanduser()

@@ -1,7 +1,7 @@
[project]
name = "llama4_jax"
version = "0.1.0"
description = ""
description = "Pure JAX implementation of Llama 4 inference, including a checkpoint converter for the weights."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

@@ -11,8 +11,8 @@ This is a pure JAX implementation of Llama 4 inference, including a checkpoint
converter for the weights. It currently runs on TPU. Support for GPU is
in-progress.

The entire model is defined in [model.py](llama4_jax/model.py) and invoked
via [main.py](main.py). Among other things, the model code demonstrates:
The entire model is defined in [__main__.py](llama4_jax/__main__.py) and invoked
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we keep main.py separate, I think an explicit script might be more clear that it's just a starting point for a larger program rather than default behavior for the llama4 module

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rdyro It's your call. But it is very nonstandard Python, the __main__.py properly integrates… seeing itself as within Python module; being relocatable; and installable.

One other idea is to create a hierarchy:

.
├── jax_examples_cli
├── deepseek_r1_jax
└── llama4_jax

All packages are installable. There are no main.pys. Outside of jax_examples_cli there are no __main__.pys. In the Command Line Interface it finds installed packages and/or packages within a specific location (e.g., os.getcwd() or JAX_LLM_EXAMPLES env var); indicating where to find packages compatible with the jax_examples_cli/__main__.py.

Already your main.pys are very similar, it wouldn't be hard to hoist them up. Usage would be: jax_examples_cli --model <deepseek-r1-jax | llama4_jax> --ckpt_path …

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants