diff --git a/torchrec/distributed/test_utils/test_input.py b/torchrec/distributed/test_utils/test_input.py index 03609317b..c52235b90 100644 --- a/torchrec/distributed/test_utils/test_input.py +++ b/torchrec/distributed/test_utils/test_input.py @@ -8,13 +8,13 @@ # pyre-strict from dataclasses import dataclass -from typing import cast, List, Optional, Tuple, Union +from typing import cast, Dict, List, Optional, Tuple, Union import torch from tensordict import TensorDict from torchrec.distributed.embedding_types import EmbeddingTableConfig from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor from torchrec.streamable import Pipelineable @@ -535,13 +535,121 @@ def _create_batched_standard_kjts( return global_kjt, local_kjts -# @dataclass -# class VbModelInput(ModelInput): -# pass +@dataclass +class VbModelInput(ModelInput): + + @staticmethod + def _create_variable_batch_kjt( + keys: List[str], + world_size: int, + global_constant_batch: bool, + values_per_rank_per_feature: Dict[int, Dict[str, torch.Tensor]], + lengths_per_rank_per_feature: Dict[int, Dict[str, torch.Tensor]], + strides_per_rank_per_feature: Dict[int, Dict[str, int]], + inverse_indices_per_rank_per_feature: Dict[int, Dict[str, torch.Tensor]], + weights_per_rank_per_feature: Optional[Dict[int, Dict[str, torch.Tensor]]], + use_offsets: bool, + indices_dtype: torch.dtype, + offsets_dtype: torch.dtype, + lengths_dtype: torch.dtype, + ) -> KeyedJaggedTensor: + global_values = [] + global_lengths = [] + global_stride_per_key_per_rank = [] + inverse_indices_per_feature_per_rank = [] + global_weights = [] if weights_per_rank_per_feature is not None else None + + for key in keys: + sum_stride = 0 + for rank in range(world_size): + global_values.append(values_per_rank_per_feature[rank][key]) + global_lengths.append(lengths_per_rank_per_feature[rank][key]) + if weights_per_rank_per_feature is not None: + assert global_weights is not None + global_weights.append(weights_per_rank_per_feature[rank][key]) + sum_stride += strides_per_rank_per_feature[rank][key] + inverse_indices_per_feature_per_rank.append( + inverse_indices_per_rank_per_feature[rank][key] + ) + + global_stride_per_key_per_rank.append([sum_stride]) + + inverse_indices_list: List[torch.Tensor] = [] + + for key in keys: + accum_batch_size = 0 + inverse_indices = [] + + for rank in range(world_size): + inverse_indices.append( + inverse_indices_per_rank_per_feature[rank][key] + accum_batch_size + ) + accum_batch_size += strides_per_rank_per_feature[rank][key] + + inverse_indices_list.append(torch.cat(inverse_indices)) + + global_inverse_indices = (keys, torch.stack(inverse_indices_list)) + + if global_constant_batch: + global_offsets = [] + + for length in global_lengths: + global_offsets.append(_to_offsets(length)) + + reindexed_lengths = [] + + for length, indices in zip( + global_lengths, inverse_indices_per_feature_per_rank + ): + reindexed_lengths.append(torch.index_select(length, 0, indices)) + + lengths = torch.cat(reindexed_lengths) + reindexed_values, reindexed_weights = [], [] + + for i, (values, offsets, indices) in enumerate( + zip(global_values, global_offsets, inverse_indices_per_feature_per_rank) + ): + for idx in indices: + reindexed_values.append(values[offsets[idx] : offsets[idx + 1]]) + if global_weights is not None: + reindexed_weights.append( + global_weights[i][offsets[idx] : offsets[idx + 1]] + ) + + values = torch.cat(reindexed_values) + weights = ( + torch.cat(reindexed_weights) if global_weights is not None else None + ) + global_stride_per_key_per_rank = None + global_inverse_indices = None + + else: + values = torch.cat(global_values) + lengths = torch.cat(global_lengths) + weights = torch.cat(global_weights) if global_weights is not None else None + + if use_offsets: + offsets = torch.cat( + [torch.tensor([0], dtype=offsets_dtype), lengths.cumsum(0)] + ) + return KeyedJaggedTensor( + keys=keys, + values=values, + offsets=offsets, + weights=weights, + stride_per_key_per_rank=global_stride_per_key_per_rank, + inverse_indices=global_inverse_indices, + ) + else: + return KeyedJaggedTensor( + keys=keys, + values=values, + lengths=lengths, + weights=weights, + stride_per_key_per_rank=global_stride_per_key_per_rank, + inverse_indices=global_inverse_indices, + ) -# @staticmethod -# def _create_variable_batch_kjt() -> KeyedJaggedTensor: -# pass # @staticmethod # def _merge_variable_batch_kjts(kjts: List[KeyedJaggedTensor]) -> KeyedJaggedTensor: