-
Notifications
You must be signed in to change notification settings - Fork 13
[llama4 -> llama4_jax] Refactor to be a proper installable Python package #9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…kage ; [llama4_jax/pyproject.toml] Add missing dependency ; [llama4_jax/README.md] Document new usage
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this PR! I left some comments
@@ -44,9 +37,9 @@ def main(model_path: str | Path, ckpt_path: str | Path): | |||
|
|||
additional_files = ["config.json", "tokenizer.json", "tokenizer_config.json"] | |||
for additional_file in additional_files: | |||
full_paths = list(model_path.glob(f"**/{additional_file}")) | |||
full_paths = list(model_path.glob(os.path.join("**", additional_file))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we stick with pathlib.Path throughout here, it largely replaces os.path.*
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After 11 years maybe I can rely on it being available!
os.path.join
is actually used by it internally IIRC; but sure can move to pathlib
for my jax-ml/jax-llm-examples
contributions.
) | ||
parser.add_argument( | ||
"--dest-root-path", | ||
required=True, | ||
default="~/", | ||
default=os.path.join(os.path.expanduser("~"), ""), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Path("~/").expanduser()
@@ -21,7 +16,8 @@ def main(path: str | Path, suffix: str): | |||
if __name__ == "__main__": | |||
parser = ArgumentParser() | |||
parser.add_argument( | |||
"--path", default="~/DeepSeek-R1-Distill-Llama-70B", required=True, help="Existing JAX model checkpoint path" | |||
"--path", default=os.path.join(os.path.expanduser("~"), "DeepSeek-R1-Distill-Llama-70B"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Path("~/DeepSeek-R1-Distill-Llama-70B").expanduser()
@@ -1,7 +1,7 @@ | |||
[project] | |||
name = "llama4_jax" | |||
version = "0.1.0" | |||
description = "" | |||
description = "Pure JAX implementation of Llama 4 inference, including a checkpoint converter for the weights." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice!
@@ -11,8 +11,8 @@ This is a pure JAX implementation of Llama 4 inference, including a checkpoint | |||
converter for the weights. It currently runs on TPU. Support for GPU is | |||
in-progress. | |||
|
|||
The entire model is defined in [model.py](llama4_jax/model.py) and invoked | |||
via [main.py](main.py). Among other things, the model code demonstrates: | |||
The entire model is defined in [__main__.py](llama4_jax/__main__.py) and invoked |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we keep main.py
separate, I think an explicit script might be more clear that it's just a starting point for a larger program rather than default behavior for the llama4 module
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@rdyro It's your call. But it is very nonstandard Python, the __main__.py
properly integrates… seeing itself as within Python module; being relocatable; and installable.
One other idea is to create a hierarchy:
.
├── jax_examples_cli
├── deepseek_r1_jax
└── llama4_jax
All packages are installable. There are no main.py
s. Outside of jax_examples_cli
there are no __main__.py
s. In the Command Line Interface it finds installed packages and/or packages within a specific location (e.g., os.getcwd()
or JAX_LLM_EXAMPLES
env var); indicating where to find packages compatible with the jax_examples_cli/__main__.py
.
Already your main.py
s are very similar, it wouldn't be hard to hoist them up. Usage would be: jax_examples_cli --model <deepseek-r1-jax | llama4_jax> --ckpt_path …
Also I know you have
path = Path(path).expanduser().absolute()
but that doesn't provide nice--help
text and should be expanded earlier anyway. I can remove yourPath(path).expanduser()
if you give the go-ahead.PS: If you merge this PR I can send you a new PR for deepseek_r1_jax