Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lora] Support long context lora #4787

Merged
merged 29 commits into from May 18, 2024

Conversation

rkooo567
Copy link
Collaborator

@rkooo567 rkooo567 commented May 13, 2024

Currently we need to call rotary embedding kernel for each LoRA, which makes it hard to serve multiple long context length LoRA. Add batched rotary embedding kernel and pipe it through.

It replaces the rotary embedding layer to the one that is aware of multiple cos-sin-cache per scaling factors.

Follow up of https://github.com/vllm-project/vllm/pull/3095/files


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@@ -1193,3 +1203,204 @@ def can_replace_layer(cls, source_layer: nn.Module,
model_config: Optional[PretrainedConfig]) -> bool:
# Special handling for the LogitsProcessor.
return False


class LoRALinearScalingRotaryEmbedding(LinearScalingRotaryEmbedding):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed

return query, key


class LoRAPagedAttentionWithRoPE(LoRALayer):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed

@Yard1 Yard1 self-requested a review May 13, 2024 17:18
@rkooo567 rkooo567 changed the title [WIP] Support long context lora [Lora] Support long context lora May 15, 2024
@@ -21,6 +21,41 @@
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader import get_model

LONG_LORA_INFOS = [
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will be moved to hf hub

)
return lora_llm

def test_batched_rope_kernel(self, long_context_infos):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently this test has illegal memory access

@rkooo567
Copy link
Collaborator Author

The PR is ready to be reviewed

@richardliaw
Copy link
Collaborator

tests failing?

@rkooo567
Copy link
Collaborator Author

that' expected because we need to move fine tuned lora model to hf hub so that CI can access it.

will be done soon. The test succeeds locally.

@rkooo567
Copy link
Collaborator Author

Use hf instead of s3 now

Copy link
Collaborator Author

@rkooo567 rkooo567 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Yard1 compared to internal version;

  • clean up on some of variables and logic
  • Added a test to test_layers.py
  • Replaced rotary embedding instead of self.attn
  • fixed a bug where rotary emb is only replacing the first layer
  • download lora fine tuned models from hf hub I created

Copy link
Collaborator

@Yard1 Yard1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, looks good! Some small comments

Comment on lines 653 to 656
prompt_limit = (seq_group.lora_request.long_lora_max_len
if seq_group.lora_request
and seq_group.lora_request.long_lora_max_len else
self.prompt_limit)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's put it in a method/function

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed self.prompt_limit btw because in this case, it doesn't make much sense to have it.

vllm/engine/arg_utils.py Outdated Show resolved Hide resolved
vllm/engine/output_processor/stop_checker.py Outdated Show resolved Hide resolved
vllm/lora/layers.py Outdated Show resolved Hide resolved
@@ -28,6 +28,7 @@ def test_load_checkpoints(
# and the test should pass.
LoRAModel.from_local_checkpoint(
baichuan_lora_files,
4096,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use kwarg instead

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Default None == meaning it will use scaling factor 1. Also added docstring

tests/lora/test_lora_manager.py Outdated Show resolved Hide resolved
@rkooo567
Copy link
Collaborator Author

@Yard1 everything addressed

@rkooo567 rkooo567 merged commit 2e9a222 into vllm-project:main May 18, 2024
61 checks passed
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 19, 2024
Currently we need to call rotary embedding kernel for each LoRA, which makes it hard to serve multiple long context length LoRA. Add batched rotary embedding kernel and pipe it through.

It replaces the rotary embedding layer to the one that is aware of multiple cos-sin-cache per scaling factors.

Follow up of https://github.com/vllm-project/vllm/pull/3095/files
dtrifiro pushed a commit to dtrifiro/vllm that referenced this pull request May 21, 2024
Currently we need to call rotary embedding kernel for each LoRA, which makes it hard to serve multiple long context length LoRA. Add batched rotary embedding kernel and pipe it through.

It replaces the rotary embedding layer to the one that is aware of multiple cos-sin-cache per scaling factors.

Follow up of https://github.com/vllm-project/vllm/pull/3095/files
@AllenDou
Copy link
Contributor

May I ask a question? Only LinearScalingRotaryEmbedding requires the 'withlora' version. Do DynamicNTKScalingRotaryEmbedding/YaRNScalingRotaryEmbedding/Phi3SuScaledRotaryEmbedding also require it?

@rkooo567
Copy link
Collaborator Author

I think those are not working with long context multi lora. In order to get it working, I think other rotary embedding should also support multi scaling factors like we are doing for LinearScalingRotaryEmbedding .

@AllenDou
Copy link
Contributor

ok, thanks for your response :)

tybalex pushed a commit to tybalex/vllm-function-call that referenced this pull request May 25, 2024
Currently we need to call rotary embedding kernel for each LoRA, which makes it hard to serve multiple long context length LoRA. Add batched rotary embedding kernel and pipe it through.

It replaces the rotary embedding layer to the one that is aware of multiple cos-sin-cache per scaling factors.

Follow up of https://github.com/vllm-project/vllm/pull/3095/files
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants