support zero collision tables in ssd operator #2919

Open · wants to merge 1 commit into base: main
56 changes: 35 additions & 21 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -660,7 +660,7 @@ def _gen_named_parameters_by_table_ssd_pmt(
name as well as the parameter itself. The embedding table is in the form of
PartiallyMaterializedTensor to support windowed access.
"""
pmts = emb_module.split_embedding_weights()
pmts, _, _ = emb_module.split_embedding_weights()
for table_config, pmt in zip(config.embedding_tables, pmts):
table_name = table_config.name
emb_table = pmt
@@ -963,7 +963,7 @@ def state_dict(
# in the case no_snapshot=False, a flush is required. we rely on the flush operation in
# ShardedEmbeddingBagCollection._pre_state_dict_hook()

emb_tables = self.split_embedding_weights(no_snapshot=no_snapshot)
emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
for emb_table in emb_table_config_copy:
emb_table.local_metadata.placement._device = torch.device("cpu")
@@ -1002,25 +1002,30 @@ def named_split_embedding_weights(
), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
for config, tensor in zip(
self._config.embedding_tables,
self.split_embedding_weights(),
self.split_embedding_weights()[0],
):
key = append_prefix(prefix, f"{config.name}.weight")
yield key, tensor

def get_named_split_embedding_weights_snapshot(
self, prefix: str = ""
) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[
Tuple[
str,
Union[ShardedTensor, PartiallyMaterializedTensor],
Optional[ShardedTensor],
Optional[ShardedTensor],
]
]:
"""
Return an iterator over embedding tables, yielding both the table name as well as the embedding
table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
RocksDB snapshot to support windowed access.
"""
for config, tensor in zip(
self._config.embedding_tables,
self.split_embedding_weights(no_snapshot=False),
self.split_embedding_weights(no_snapshot=False)[0],
):
key = append_prefix(prefix, f"{config.name}")
yield key, tensor
yield key, tensor, None, None

def flush(self) -> None:
"""
@@ -1037,9 +1042,11 @@ def purge(self) -> None:
self.emb_module.lxu_cache_state.fill_(-1)

# pyre-ignore [15]
def split_embedding_weights(
self, no_snapshot: bool = True
) -> List[PartiallyMaterializedTensor]:
def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
List[PartiallyMaterializedTensor],
Optional[List[torch.Tensor]],
Optional[List[torch.Tensor]],
]:
return self.emb_module.split_embedding_weights(no_snapshot)


@@ -1455,7 +1462,7 @@ def state_dict(
# in the case no_snapshot=False, a flush is required. we rely on the flush operation in
# ShardedEmbeddingBagCollection._pre_state_dict_hook()

emb_tables = self.split_embedding_weights(no_snapshot=no_snapshot)
emb_tables, _, _ = self.split_embedding_weights(no_snapshot=no_snapshot)
emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
for emb_table in emb_table_config_copy:
emb_table.local_metadata.placement._device = torch.device("cpu")
@@ -1494,25 +1501,30 @@ def named_split_embedding_weights(
), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
for config, tensor in zip(
self._config.embedding_tables,
self.split_embedding_weights(),
self.split_embedding_weights()[0],
):
key = append_prefix(prefix, f"{config.name}.weight")
yield key, tensor

def get_named_split_embedding_weights_snapshot(
self, prefix: str = ""
) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterator[
Tuple[
str,
Union[ShardedTensor, PartiallyMaterializedTensor],
Optional[ShardedTensor],
Optional[ShardedTensor],
]
]:
"""
Return an iterator over embedding tables, yielding both the table name as well as the embedding
table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
RocksDB snapshot to support windowed access.
"""
for config, tensor in zip(
self._config.embedding_tables,
self.split_embedding_weights(no_snapshot=False),
self.split_embedding_weights(no_snapshot=False)[0],
):
key = append_prefix(prefix, f"{config.name}")
yield key, tensor
yield key, tensor, None, None

def flush(self) -> None:
"""
@@ -1529,9 +1541,11 @@ def purge(self) -> None:
self.emb_module.lxu_cache_state.fill_(-1)

# pyre-ignore [15]
def split_embedding_weights(
self, no_snapshot: bool = True
) -> List[PartiallyMaterializedTensor]:
def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
List[PartiallyMaterializedTensor],
Optional[List[torch.Tensor]],
Optional[List[torch.Tensor]],
]:
return self.emb_module.split_embedding_weights(no_snapshot)


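For context, split_embedding_weights on the SSD kernels now returns a 3-tuple rather than a bare list of PartiallyMaterializedTensor. A minimal sketch of how a caller might unpack it, assuming the last two entries (per-table weight-id and bucket-count tensors) may be None for tables without zero-collision metadata; the helper name and table handling below are illustrative, not part of this PR:

from typing import Dict, List, Optional, Tuple

import torch


def collect_table_weights(
    emb_module,  # an SSD TBE kernel exposing the updated split_embedding_weights()
    table_names: List[str],
) -> Dict[str, Tuple[object, Optional[torch.Tensor], Optional[torch.Tensor]]]:
    # The call now returns (pmts, weight_ids, bucket_counts); the last two
    # lists may be None when the backend carries no zero-collision metadata.
    pmts, weight_ids, bucket_counts = emb_module.split_embedding_weights(no_snapshot=False)
    out = {}
    for i, name in enumerate(table_names):
        out[name] = (
            pmts[i],
            weight_ids[i] if weight_ids is not None else None,
            bucket_counts[i] if bucket_counts is not None else None,
        )
    return out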
8 changes: 6 additions & 2 deletions torchrec/distributed/embedding.py
@@ -952,8 +952,12 @@ def post_state_dict_hook(
module._lookups, module._sharding_type_to_sharding.keys()
):
if sharding_type != ShardingType.DATA_PARALLEL.value:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
for key, v in lookup.get_named_split_embedding_weights_snapshot():
for (
key,
v,
_,
_,
) in lookup.get_named_split_embedding_weights_snapshot(): # pyre-ignore
assert key in sharded_kvtensors_copy
sharded_kvtensors_copy[key].local_shards()[0].tensor = v
for (
18 changes: 16 additions & 2 deletions torchrec/distributed/embedding_lookup.py
@@ -350,7 +350,14 @@ def named_parameters_by_table(

def get_named_split_embedding_weights_snapshot(
self,
) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
) -> Iterator[
Tuple[
str,
Union[ShardedTensor, PartiallyMaterializedTensor],
Optional[ShardedTensor],
Optional[ShardedTensor],
]
]:
"""
Return an iterator over embedding tables, yielding both the table name as well as the embedding
table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
@@ -657,7 +664,14 @@ def named_parameters_by_table(

def get_named_split_embedding_weights_snapshot(
self,
) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
) -> Iterator[
Tuple[
str,
Union[ShardedTensor, PartiallyMaterializedTensor],
Optional[ShardedTensor],
Optional[ShardedTensor],
]
]:
"""
Return an iterator over embedding tables, yielding both the table name as well as the embedding
table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
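Likewise, get_named_split_embedding_weights_snapshot now yields 4-tuples. A hedged sketch of a caller that materializes the snapshot into a dict keyed by table name, mirroring the unpacking used in the post_state_dict_hook changes in embedding.py and embeddingbag.py; the function name is illustrative only:

from typing import Any, Dict, Optional, Tuple


def snapshot_by_table(lookup) -> Dict[str, Tuple[Any, Optional[Any], Optional[Any]]]:
    # Each entry is (table_name, weight, weight_ids, bucket_counts); the last
    # two are expected to be None unless the table is a zero-collision table.
    return {
        name: (weight, weight_ids, bucket_counts)
        for name, weight, weight_ids, bucket_counts in lookup.get_named_split_embedding_weights_snapshot()
    }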
8 changes: 6 additions & 2 deletions torchrec/distributed/embeddingbag.py
@@ -1091,8 +1091,12 @@ def post_state_dict_hook(
sharded_kvtensors_copy = copy.deepcopy(sharded_kvtensors)
for lookup, sharding in zip(module._lookups, module._embedding_shardings):
if not isinstance(sharding, DpPooledEmbeddingSharding):
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
for key, v in lookup.get_named_split_embedding_weights_snapshot():
for (
key,
v,
_,
_,
) in lookup.get_named_split_embedding_weights_snapshot(): # pyre-ignore
assert key in sharded_kvtensors_copy
sharded_kvtensors_copy[key].local_shards()[0].tensor = v
for (
@@ -111,22 +111,25 @@ def _copy_ssd_emb_modules(
"SSDEmbeddingBag or SSDEmbeddingBag."
)

emb1_kv = dict(
emb_module1.get_named_split_embedding_weights_snapshot()
)
emb1_kv = {
t: (w, w_id, bucket_cnt)
for t, w, w_id, bucket_cnt in emb_module1.get_named_split_embedding_weights_snapshot()
}
for (
k,
v,
t,
w,
_,
_,
) in emb_module2.get_named_split_embedding_weights_snapshot():
v1 = emb1_kv.get(k)
v1_full_tensor = v1.full_tensor()
w1 = emb1_kv[t][0]
w1_full_tensor = w1.full_tensor()

# write value into ssd for both emb module for later comparison
v.wrapped.set_range(
0, 0, v1_full_tensor.size(0), v1_full_tensor
w.wrapped.set_range(
0, 0, w1_full_tensor.size(0), w1_full_tensor
)
v1.wrapped.set_range(
0, 0, v1_full_tensor.size(0), v1_full_tensor
w1.wrapped.set_range(
0, 0, w1_full_tensor.size(0), w1_full_tensor
)

# purge after loading. This is needed, since we pass a batch
@@ -682,11 +685,26 @@ def _copy_ssd_emb_modules(
"SSDEmbeddingBag or SSDEmbeddingBag."
)

weights = emb_module1.emb_module.debug_split_embedding_weights()
# need to set emb_module1 as well, since otherwise emb_module1 would
# produce a random debug_split_embedding_weights everytime
_load_split_embedding_weights(emb_module1, weights)
_load_split_embedding_weights(emb_module2, weights)
emb1_kv = {
t: (w, w_id, bucket_cnt)
for t, w, w_id, bucket_cnt in emb_module1.get_named_split_embedding_weights_snapshot()
}
for (
t,
w,
_,
_,
) in emb_module2.get_named_split_embedding_weights_snapshot():
w1 = emb1_kv[t][0]
w1_full_tensor = w1.full_tensor()

# write value into ssd for both emb module for later comparison
w.wrapped.set_range(
0, 0, w1_full_tensor.size(0), w1_full_tensor
)
w1.wrapped.set_range(
0, 0, w1_full_tensor.size(0), w1_full_tensor
)

# purge after loading. This is needed, since we pass a batch
# through dmp when instantiating them.
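The test refactor above pairs tables from two SSD embedding modules by name and writes identical values into both backends for later comparison. A condensed sketch of that pattern, assuming the same snapshot iterator and the set_range call shape used in the test; the function name is illustrative:

def sync_ssd_tables(emb_module1, emb_module2) -> None:
    # Key source-module weights by table name so the two modules can be paired.
    src = {
        name: weight
        for name, weight, _, _ in emb_module1.get_named_split_embedding_weights_snapshot()
    }
    for name, weight, _, _ in emb_module2.get_named_split_embedding_weights_snapshot():
        reference = src[name].full_tensor()
        # Write the same rows into both SSD backends so subsequent reads compare equal.
        weight.wrapped.set_range(0, 0, reference.size(0), reference)
        src[name].wrapped.set_range(0, 0, reference.size(0), reference)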