-
Notifications
You must be signed in to change notification settings - Fork 216
Add Tensor Stream component for efficient safetensors-based model tensor streaming #3741
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
e02fe7a
to
b9a16c4
Compare
…treaming Introduce new components for both server and clients to enable streaming of model tensors using safetensors serialization. When configured, these components replace the standard tensor serialization and transfer mechanisms in NVFlare with a streaming approach, optimizing the handling, transfer, and reinjection of tensors during distributed training rounds. This enhances memory efficiency and scalability for applications dealing with large model states.
b9a16c4
to
3472c82
Compare
47f93a6
to
c87fa64
Compare
c87fa64
to
14070a7
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR introduces a new "Tensor Stream" component for NVFlare that optimizes tensor transfer between server and clients by streaming tensors separately from task payloads using safetensors serialization.
Key Changes
- Implements streaming-based tensor transfer that replaces embedding tensors in main payload
- Uses safetensors for efficient serialization and streaming of individual tensors
- Provides separate components for server-side and client-side tensor streaming with automatic cleanup
Reviewed Changes
Copilot reviewed 19 out of 19 changed files in this pull request and generated 4 comments.
Show a summary per file
File | Description |
---|---|
nvflare/app_opt/tensor_stream/server.py | Server-side streamer that sends task data tensors to clients and receives task results |
nvflare/app_opt/tensor_stream/client.py | Client-side streamer that receives task data from server and sends results back |
nvflare/app_opt/tensor_stream/sender.py | Component for sending tensors using StreamableEngine with torch/numpy conversion |
nvflare/app_opt/tensor_stream/receiver.py | Component for receiving tensors and reconstructing them into FL context |
nvflare/app_opt/tensor_stream/producer.py | Produces stream data from torch tensors using safetensors serialization |
nvflare/app_opt/tensor_stream/consumer.py | Consumes streamed tensor data and reconstructs tensors from safetensors |
nvflare/app_opt/tensor_stream/utils.py | Utility functions for cleaning task data/results and managing topics/targets |
nvflare/app_opt/tensor_stream/types.py | Type definitions and constants for tensor streaming |
nvflare/app_opt/tensor_stream/executors.py | Example executor showing tensor stream usage |
tests/unit_test/app_opt/tensor_stream/ | Comprehensive test suite covering all components |
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
Copilot reviewed 19 out of 19 changed files in this pull request and generated 8 comments.
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
0c81313
to
a1bcd45
Compare
@rfilgueiras one question since the tensor is streamed separately from the task data (payload), do we have a race condition where the tensor might be arrived the earlier than the task data ? How do we handle this ? |
This is a great question @chesterxgchen :) I had this issue and solved it by waiting for all transfers to be done before cleaning the context and releasing the tasks: Each client's task must not only wait for the transfer to the same client, but also until the transfer to all clients is done. I will introduce a timeout in the wait logic, which should be the maximum time it should wait for all the tensors to be transferred to all clients. Note that in this line, after each client is done, it will increment |
All producers share the same tensors dict and should treat it as read-only data.
replaced the | operator in type definitions with typing.Union.
I don't understand this part. Each client should only need to wait for its own tensor. It the tensor is delivered, then it can process. This is independent to all other clients. Some client can be fast or slow, depending on their network or computing resources. Waiting for all tensors from all clients seems not necessary. Are you afraid that the server task data is cleaned too earlier and before it it broadcasted to the individual site ? |
I think I understand what you trying to do here. You want to make sure the data are actually sent to all clients before you clean the task data.You wait for data "sent" ( not necessarily delivered) to all clients, then call clean data on server side. send() is non blocking call. This is ok. But what I was asking is different question. All above is on the server side, on the client side (Client Executor), is it possible to have a race condition that Task data and Tensor Steam arrived in different sequence |
- add parameter to define which tasks the tensor stream should be enabled - improved TensorServerStreamer syncronization before releasing tasks to clients - add root_keys auto-detection: removed parameter from Client, Server and Sender
feat: add tasks parameter, default "train" - add a parameter to define which tasks the tensor stream should be enabled for - improved TensorServerStreamer synchronization before releasing tasks to clients - add root_keys auto-detection: removed parameter from Client, Server, and Sender
- receiver: handle receiving multiple tensor maps (when multiple non-empty root keys are present) - server and client: clean-up received tensors after updating task
…e-dict fix: tensor received with root keys
Yes. I am reading the tensors from the fl_ctx without copying them. If I clean the context before they finish, it impacts the ones that are still processing the tensors. This was strange to me at first, since they might have different filters applied, etc. |
In this case, the send is synchronous and blocks on each thread (for each client) on the server side until all data is produced by the server and consumed by the client. I updated the logic to be clearer about what is happening. The send operation is triggered by the AFTER_TASK_DATA_FILTER event:
This is fine, the server sends the tensors, and they must arrive before the task data. When the task data arrives, the client streamer handles the BEFORE_TASK_DATA_FILTER event: It will ensure the task data is updated with the received tensors. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a very clever solution to use event handlers for tensor streaming. I wonder whether it could be generalized for other types of streaming.
With this solution, the streaming is between the server and the client. This works well only if the client is the one that does training. But this is not always the case. For example, in case of Client API based applications, the client does not do training. It is the external app that actually does training.
A better solution (for future) might be using the Object Downloader. Instead of pushing the tensors to the client, the server will wait for the trainer to pull tensors from it. The trainer could be the client, or the external app.
Yes, the main advantage of using event handlers is that neither the controller nor executors need to be modified. What other types of streaming services do you think of?
Yes, the current solution is not compatible with the ClientAPI for many reasons, even when it is running in-memory.
This would work in the opposite way than I do now. From Server -> Client (Task Data):
From Client -> Server (Task Result):
I would love to discuss how to improve these interfaces to facilitate such patterns:
One possibility would be:
|
I understand that using the |
If the task data or task results cannot be found in the FLContext, the send operation will not be performed. In this case, a warming will logged but task will not be aborted. It will fallback to the default serialization and transmission process. On the receiver side, the same applies.
Improve performance while keeping memory usage low. By leveraging generators data flow is faster. When doing IO the CPU bound code and prepare the next batch. Combined with chunking, it reduces CPU usage and latency. When testing with 5000 tensors (2.2GB) it reduced transfer times from 70s to 30s.
feat: add chunking with smart generators
- avoid possible memory leaks - reduce duplication when converting to and from numpy
…ation-when-handling-numpy-conversions fix: ensure memory is released as soon as possible
fix: use TASK_RESULT_RECEIVED instead of BEFORE_TASK_RESULT_FILTER. It ensures that for each task we will add the tensors back into the context.
Keep track of seens tasks and ensure received tensors are cleaned-up when the round is over.
…-of-peer-name feat: use task id to store and retrieve received tensors
Added new method to ensure the tensors arrived before trying to set them back into the task data/result.
When setting enable_request_task_data_tensors=True, the client will send a federated event to the server requesting the tensors to be sent. It means that the server can first send the task data without tensors and wait for the client to requet the tensors to be sent. It should allow supporing the ClientAPI in the future.
Sender: Added new method to parse and store the tensors only. This should be done before calling send.
It is common to have model params with mixed data, not only tensors. This is now handled ensuring other data will be preserved. A new approach to recursive dicts was implemented, ensuring now that all tensors will be serialized with one producer and consumer.
- Client: add support for requesting tensors after the task data has arrived - Server: update event handling to use `Sender.store_tensors` and `Receiver.wait_for_tensors` - Sender: decouple storing the model params reference (store_tensors method) from sending the tensors - Receive: improve logic to wait for tensors, since task data or result can arrive before the tensors are received - Producer: use the new tensor chunking logic from model params and update the data model to pass the location of tensors when they are part of nested dicts - Consumer: add logic to reconstruct the original params using received tensors and the new parent_keys field - Utils: add new functions to handle tensor chunking, update params, and extract non-tensor values from params
…data-from-context fix: clean task data and task result data should
Description
This PR introduces a new component for NVFlare called Tensor Stream, which optimizes the transfer of model tensors between server and clients by streaming them separately from the main task payload.
Instead of embedding tensors directly into task data/results, Tensor Stream serializes and streams each tensor individually using safetensors, ensuring efficient and scalable communication. On the receiving side, tensors are deserialized, stored in memory, and reinjected back into the task payload to preserve the original structure.
Server → Clients (Task Data)
Clients → Server (Task Results)
This approach replaces the standard tensor transfer in NVFlare, avoiding unnecessary copies and enabling more efficient handling of large model states.
Types of changes
./runtest.sh
.