Add tensor parallel support to T5 via NxD #697
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
This is an awesome first step. I need to investigate a bit more on the Llama side to see if this is actually compatible. The main differences I see for now are that the modeling is explicitly redefined instead of patched (because of several optimized layers/operations), and that the export/compilation uses the new `ModelBuilder` (I think this will eventually be mandatory).
@@ -112,6 +112,12 @@ def parse_args_neuronx(parser: "ArgumentParser"):
        choices=["bf16", "fp16", "tf32"],
        help='The data type to cast FP32 operations to when auto-cast mode is enabled. Can be `"bf16"`, `"fp16"` or `"tf32"`.',
    )
    optional_group.add_argument(
        "--tensor_parallel_size",
They tend to use `tp_degree` in the AWS Neuron SDK documentation and APIs. I used `num_cores` in the decoder, but now I regret it: can't we just align on their terminology?
All three terms were used before this PR; I tried to use the same term as the API I reuse. Indeed, we should agree on one term to avoid confusion.
cc. @michaelbenayoun
@@ -174,6 +175,7 @@ def __init__(
        self._config = config
        self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
        self.mandatory_axes = ()
        self.tp_degree = tensor_parallel_size
Hmmm, see my comment above.
# Start trace
if tp_degree > 1:
    # 1. use NxD to trace for parallel
    neuron_model = neuronx_distributed.trace.parallel_model_trace(
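The core idea behind the tensor-parallel trace is that each core holds only a shard of the weights. A column-sharded matrix multiply can be sketched in plain Python; this illustrates the technique only and is not NxD's implementation:

```python
def matmul(x, w):
    # x: m x k (list of rows), w: k x n; returns m x n.
    return [[sum(x[i][t] * w[t][j] for t in range(len(w)))
             for j in range(len(w[0]))] for i in range(len(x))]

def column_shards(w, tp_degree):
    # Split the weight's output columns evenly across tp_degree workers.
    n = len(w[0])
    assert n % tp_degree == 0, "output dim must divide evenly across cores"
    step = n // tp_degree
    return [[row[r * step:(r + 1) * step] for row in w] for r in range(tp_degree)]

x = [[1.0, 2.0]]
w = [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0]]

# Each "core" computes its own partial output; concatenating the
# partials recovers the full, unsharded result.
partial = [matmul(x, shard) for shard in column_shards(w, tp_degree=2)]
full = [sum((p[i] for p in partial), []) for i in range(len(x))]
assert full == matmul(x, w)
print(full)  # → [[11.0, 14.0, 17.0, 20.0]]
```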
It is OK as a first step, but for the Llama example they are not using this anymore; instead they use the `ModelBuilder` class, which wraps the model into `NxDModel` classes containing several sub-models with different input shapes (bucketing).
Is the use of bucketing mature and justified? I think we can start with `parallel_model_trace` anyway.
It goes a bit beyond that: prefill and decode already use two different input shapes, not even mentioning bucketing, and using the builder allows sharing the same weights between all the alternate graphs.
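To make the bucketing point concrete: the idea is to compile one graph per input-shape bucket and, at runtime, route a request to the smallest bucket that fits, padding the input to that shape. A simplified sketch of the idea; the bucket sizes and pad token are illustrative, not NxD's actual values:

```python
BUCKETS = [128, 512, 2048]  # compiled sequence-length buckets (illustrative)

def pick_bucket(seq_len, buckets=BUCKETS):
    """Return the smallest compiled bucket that fits seq_len."""
    for b in sorted(buckets):
        if seq_len <= b:
            return b
    raise ValueError(f"sequence length {seq_len} exceeds largest bucket")

def pad_to_bucket(tokens, pad_id=0, buckets=BUCKETS):
    """Pad a token list up to its bucket so it matches a compiled graph's shape."""
    bucket = pick_bucket(len(tokens), buckets)
    return tokens + [pad_id] * (bucket - len(tokens))

print(pick_bucket(100))               # → 128
print(len(pad_to_bucket([1] * 300)))  # → 512
```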
What does this PR do?
Fixes #317
Fixes #479
Adds tensor parallel support for large T5 models. Example export commands:
optimum-cli export neuron --model google-t5/t5-small --tensor_parallel_size 2 --task text2text-generation --batch_size 1 --sequence_length 128 --num_beams 4 t5_neuronx_tp2/
optimum-cli export neuron --model google/flan-t5-xl --tensor_parallel_size 8 --task text2text-generation --batch_size 1 --sequence_length 128 --num_beams 4 flan_t5_xl_neuronx_tp8/
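One practical constraint behind choosing these `--tensor_parallel_size` values: attention heads are typically sharded across cores, so the model's head count should be divisible by the TP degree. A minimal validation sketch; the helper name is ours, not optimum's API:

```python
def validate_tp_degree(num_attention_heads, tensor_parallel_size):
    """Check that attention heads can be sharded evenly across cores."""
    if num_attention_heads % tensor_parallel_size != 0:
        raise ValueError(
            f"{num_attention_heads} heads are not divisible by "
            f"tensor_parallel_size={tensor_parallel_size}"
        )
    return num_attention_heads // tensor_parallel_size

# e.g. a model with 32 attention heads sharded over 8 cores → 4 heads per core
print(validate_tp_degree(32, 8))  # → 4
```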
Before submitting