[SDY] add sdy shard-group unification pass. This cl unifies shard group ids by "merging" groups which contain tensors that belong to more than one group across the module. Additionally it reindexes the sharding group ids to not have any gaps after merging. For example #46

Merged 1 commit into main from test_658712042 on Aug 10, 2024


@copybara-service copybara-service bot commented Aug 7, 2024

[SDY] Add the sdy shard-group unification pass. This CL unifies shard group ids by merging groups that share a tensor: whenever a tensor belongs to more than one group anywhere in the module, those groups are merged into one. It also reindexes the sharding group ids so that no gaps remain after merging. For example,

module {
func.func @test(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) {
  sdy.sharding_group %arg0 group_id = 12 : tensor<4xf32>
  sdy.sharding_group %arg0 group_id = 25 : tensor<4xf32>
  sdy.sharding_group %arg0 group_id = 39 : tensor<4xf32>
  sdy.sharding_group %arg1 group_id = 39 : tensor<4xf32>
  sdy.sharding_group %arg2 group_id = 44 : tensor<4xf32>
  func.return
}
}

would unify to

module {
func.func @test(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) {
  sdy.sharding_group %arg0 group_id = 1 : tensor<4xf32>
  sdy.sharding_group %arg0 group_id = 1 : tensor<4xf32>
  sdy.sharding_group %arg0 group_id = 1 : tensor<4xf32>
  sdy.sharding_group %arg1 group_id = 1 : tensor<4xf32>
  sdy.sharding_group %arg2 group_id = 2 : tensor<4xf32>
  func.return
}
}

Since %arg0 is in sharding groups {12, 25, 39}, all of these groups must be sharded the same way and can share a single group id. %arg2 doesn't cause any group ids to merge, but its group id is still reindexed after merging so that the ids are contiguous and as small as possible.
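
The merge-then-reindex behaviour described above can be sketched with a small union-find. This is only an illustration in Python, not the pass's actual implementation: the function name `unify_group_ids`, the `(tensor, group_id)` input format, and the choice to start numbering at 1 (taken from the example in this description) are all assumptions.

```python
def unify_group_ids(memberships, first_id=1):
    """Map each original group id to a unified, gap-free id.

    memberships: iterable of (tensor_name, group_id) pairs, one per
    sdy.sharding_group op (hypothetical input format).
    first_id: first id to hand out; 1 is assumed to match the example.
    """
    parent = {}

    def find(gid):
        # Register unseen ids as their own root, then walk to the root.
        parent.setdefault(gid, gid)
        while parent[gid] != gid:
            parent[gid] = parent[parent[gid]]  # path halving
            gid = parent[gid]
        return gid

    def union(a, b):
        root_a, root_b = find(a), find(b)
        if root_a != root_b:
            # Keep the smaller id as the root so the result is deterministic.
            parent[max(root_a, root_b)] = min(root_a, root_b)

    # A tensor seen in two groups forces those groups to merge.
    first_group_of = {}
    for tensor, gid in memberships:
        find(gid)
        if tensor in first_group_of:
            union(first_group_of[tensor], gid)
        else:
            first_group_of[tensor] = gid

    # Reindex merged groups with consecutive ids, in ascending order
    # of the original ids.
    new_id_of_root = {}
    mapping = {}
    for gid in sorted(parent):
        root = find(gid)
        if root not in new_id_of_root:
            new_id_of_root[root] = first_id + len(new_id_of_root)
        mapping[gid] = new_id_of_root[root]
    return mapping


example = [
    ("%arg0", 12),
    ("%arg0", 25),
    ("%arg0", 39),
    ("%arg1", 39),
    ("%arg2", 44),
]
print(unify_group_ids(example))  # {12: 1, 25: 1, 39: 1, 44: 2}
```

Merging is transitive in this sketch: if one tensor links groups A and B and another links B and C, all three end up with the same id.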

@copybara-service copybara-service bot force-pushed the test_658712042 branch 2 times, most recently from ce10268 to 976e9fc on August 10, 2024 01:22
@copybara-service copybara-service bot changed the title on Aug 10, 2024 from "[SDY] add sdy shard-group canonicalization pass. This cl canonicalizes shard group ids by "merging" groups which contain tensors that belong to more than one group across the module. For example" to "[SDY] add sdy shard-group unification pass. This cl unifies shard group ids by "merging" groups which contain tensors that belong to more than one group across the module. Additionally it reindexes the sharding group ids to not have any gaps after merging. For example"

PiperOrigin-RevId: 661475723
@copybara-service copybara-service bot merged commit 88e11f6 into main Aug 10, 2024
@copybara-service copybara-service bot deleted the test_658712042 branch August 10, 2024 01:25