[Bugfix] Fix device ordinal for multi-node spec decode (#13269)
Signed-off-by: Shangming Cai <[email protected]>
ShangmingCai authored Feb 19, 2025
1 parent 377d10b commit 5ae9f26
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion vllm/spec_decode/spec_decode_worker.py
@@ -10,6 +10,7 @@

 from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
 from vllm.distributed.communication_op import (broadcast_tensor_dict,
+                                               get_tp_group,
                                                tensor_model_parallel_gather)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
@@ -365,7 +366,7 @@ def init_device(self) -> None:
                 target_lm_head_weight)

         self._metrics.init_tensors(self.rank, device_type=self.device)
-        self.spec_decode_sampler.init_tensors(self.rank,
+        self.spec_decode_sampler.init_tensors(get_tp_group().local_rank,
                                               device_type=self.device)

         scorer_cls: Type[SpeculativeScorer]
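
Context (not part of the commit): on a multi-node tensor-parallel run, `self.rank` is the global rank across all nodes, so it can exceed the number of GPUs on any single node and is not a valid CUDA device ordinal there; `get_tp_group().local_rank` is the node-local rank, which is. The sketch below is only an illustration of that distinction under an assumed 2-node x 8-GPU topology; the topology numbers and variable names are assumptions, not values from vLLM.

```python
# Illustrative sketch only: why a node-local rank, not the global rank,
# must be used as a CUDA device ordinal. Topology values are assumed.
GPUS_PER_NODE = 8   # assumed GPUs per node
WORLD_SIZE = 16     # assumed 2 nodes x 8 GPUs

for global_rank in range(WORLD_SIZE):
    # Assuming ranks are packed node by node, the node-local rank is:
    local_rank = global_rank % GPUS_PER_NODE

    # "cuda:<global_rank>" does not exist on the second node (its ordinals
    # are 0..7, not 8..15); "cuda:<local_rank>" always stays in range.
    device_from_global = f"cuda:{global_rank}"  # breaks for global_rank >= 8
    device_from_local = f"cuda:{local_rank}"    # always valid
    print(f"rank {global_rank:2d}: {device_from_global:>7s} -> {device_from_local}")
```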
