Skip to content

Commit 8b94a16

Browse files
shanmugamr1992 and jaredcasper
authored and committed
Adding proper test cases
1 parent f861467 commit 8b94a16

File tree

13 files changed

+330
-108
lines changed

13 files changed

+330
-108
lines changed

.coverage

-52 KB
Binary file not shown.

.coveragerc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
[html]
2-
directory = coverage
2+
directory = coverage
3+
4+
[run]
5+
data_file = .coverage_$LOCAL_RANK

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
__pycache__
22
*.so
33
build
4+
.coverage_*
45
*.egg-info

.gitlab-ci.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ test:
44
tags:
55
- docker_gpu_enabled
66
script:
7-
- nvidia-smi
8-
- torchrun --nproc_per_node=2 -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/
7+
- torchrun --nproc_per_node=8 -m pytest --cov-report=term --cov-report=html --cov=megatron/core tests/
98
coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/'
109
artifacts:
1110
paths:

megatron/core/tensor_parallel/random.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
gather_split_1d_tensor,
2323
)
2424

25+
from megatron.core.utils import safely_set_viewless_tensor_data
26+
2527
# Default name for the model parallel rng tracker.
2628
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
2729

tests/__init__.py

Whitespace-only changes.
tests/tensor_parallel/test_cross_entropy.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from megatron.core.tensor_parallel.cross_entropy import vocab_parallel_cross_entropy
2+
import torch
3+
from tests.test_utilities import Utils
4+
import numpy as np
5+
6+
def test_vocab_parallel_cross_entropy():
    """Check vocab_parallel_cross_entropy against hand-computed reference losses.

    Runs with tensor-parallel size 4 and pipeline-parallel size 2. Every row of
    the logits is the fixed pattern [0.0 .. 7.0], so the per-row loss is known
    up to rounding.
    """
    Utils.initialize_model_parallel(4,2)
    # torch.arange replaces the deprecated/removed torch.range; same values
    # (0.0 .. 7.0 inclusive, float dtype, 8 elements).
    vocab_parallel_logits = torch.arange(0., 8.).repeat(16,4).cuda()
    target = torch.arange(0,32,2).cuda()
    output = vocab_parallel_cross_entropy(vocab_parallel_logits, target)
    expected_output = torch.tensor([10.2309, 8.2309, 6.2309, 4.2309, 10.2309, 8.2309, 6.2309, 4.2309,
        10.2309, 8.2309, 6.2309, 4.2309, 10.2309, 8.2309, 6.2309, 4.2309]).cuda()
    # Round before comparing: the reference values are only written to 4 decimals.
    assert(torch.equal(torch.round(expected_output), torch.round(output)))
    Utils.destroy_model_parallel()

tests/tensor_parallel/test_data.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from megatron.core.tensor_parallel.data import broadcast_data
2+
import torch
3+
from tests.test_utilities import Utils
4+
5+
def test_broadcast_data():
    """broadcast_data must hand back the requested keys unchanged on every rank."""
    Utils.initialize_model_parallel(2, 4)
    # Eight distinct 8x8 constant tensors keyed 0..7 (value == key).
    input_data = {key: torch.ones((8, 8)).cuda() * float(key) for key in range(8)}
    actual_output = broadcast_data([0, 1], input_data, torch.float32)
    for key in (0, 1):
        assert torch.equal(actual_output[key], input_data[key])
    Utils.destroy_model_parallel()
tests/tensor_parallel/test_mappings.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
from megatron.core.tensor_parallel import mappings
2+
from tests.test_utilities import Utils
3+
import torch
4+
5+
def test_CopyToModelParallelRegion():
    """Copy-to-TP-region: backward all-reduces the grad; forward/symbolic are identity."""
    Utils.initialize_model_parallel(4, 2)
    grad = torch.ones((1)).cuda() * Utils.rank
    reduced = mappings._CopyToModelParallelRegion.backward(None, grad)
    # Group sum of ranks: 0+1+2+3 = 6 for the first TP group, 4+5+6+7 = 22 for the second.
    expected = torch.ones(1).cuda() * (22 if Utils.rank >= 4 else 6)
    assert torch.equal(reduced, expected)
    assert torch.equal(grad, mappings.copy_to_tensor_model_parallel_region(grad))
    assert torch.equal(grad, mappings._CopyToModelParallelRegion.symbolic(None, grad))
    Utils.destroy_model_parallel()
15+
16+
def test_ReduceFromModelParallelRegion():
    """Reduce-from-TP-region: forward/symbolic all-reduce; backward is identity."""
    Utils.initialize_model_parallel(4, 2)
    tensor = torch.ones((1)).cuda() * Utils.rank
    reduced = mappings._ReduceFromModelParallelRegion.symbolic(None, tensor)
    # Group sum of ranks: 6 for ranks 0-3, 22 for ranks 4-7.
    expected = torch.ones(1).cuda() * (22 if Utils.rank >= 4 else 6)
    assert torch.equal(reduced, expected)
    tensor = torch.ones((1)).cuda() * Utils.rank
    assert torch.equal(mappings.reduce_from_tensor_model_parallel_region(tensor), expected)
    assert torch.equal(tensor, mappings._ReduceFromModelParallelRegion.backward(None, tensor))
    Utils.destroy_model_parallel()
27+
28+
def test_ScatterToModelParallelRegion():
    """Scatter-to-TP-region: forward splits the last dim; backward all-gathers it."""
    Utils.initialize_model_parallel(4, 2)
    full = torch.rand((8, 4)).cuda()
    # Column owned by this rank inside its TP group (group size = world_size / 2).
    col = int(Utils.rank % (Utils.world_size / 2))
    expected_shard = full[:, col].reshape((8, 1))
    assert torch.equal(mappings.scatter_to_tensor_model_parallel_region(full), expected_shard)
    assert torch.equal(mappings._ScatterToModelParallelRegion.symbolic(None, full), expected_shard)

    shard = torch.ones(8).cuda() * Utils.rank
    gathered = mappings._ScatterToModelParallelRegion.backward(None, shard)
    # All-gather over the TP group: ranks 0-3 yield [0,1,2,3], ranks 4-7 yield [4,5,6,7].
    base = 4 if Utils.rank >= 4 else 0
    expected = torch.cat([torch.ones(8) * r for r in range(base, base + 4)]).cuda()
    assert torch.equal(gathered, expected)
    Utils.destroy_model_parallel()
48+
49+
def test_GatherFromModelParallelRegion():
    """Gather-from-TP-region: backward splits the last dim; forward/symbolic all-gather."""
    Utils.initialize_model_parallel(4, 2)
    full = torch.rand((8, 4)).cuda()
    col = int(Utils.rank % (Utils.world_size / 2))
    assert torch.equal(
        mappings._GatherFromModelParallelRegion.backward(None, full),
        full[:, col].reshape((8, 1)),
    )
    shard = torch.ones(8).cuda() * Utils.rank
    # All-gather over the TP group: ranks 0-3 yield [0,1,2,3], ranks 4-7 yield [4,5,6,7].
    base = 4 if Utils.rank >= 4 else 0
    expected = torch.cat([torch.ones(8) * r for r in range(base, base + 4)]).cuda()
    assert torch.equal(mappings.gather_from_tensor_model_parallel_region(shard), expected)
    assert torch.equal(mappings._GatherFromModelParallelRegion.symbolic(None, shard), expected)
    Utils.destroy_model_parallel()
67+
68+
def test_ScatterToSequenceParallelRegion():
    """Scatter-to-SP-region: forward splits the sequence (first) dim; backward gathers it.

    Runs with tensor-parallel size 4 and pipeline-parallel size 2; each TP rank
    owns a contiguous 2-row slice of the 8-row sequence dimension.
    """
    Utils.initialize_model_parallel(4,2)
    input_data = torch.rand((8,4)).cuda()
    req_dim = int(Utils.rank%(Utils.world_size/2))*2
    output_data = mappings._ScatterToSequenceParallelRegion.symbolic(None, input_data)
    assert(torch.equal(output_data, input_data[req_dim:req_dim+2, :]))
    output_data = mappings.scatter_to_sequence_parallel_region(input_data)
    assert(torch.equal(output_data, input_data[req_dim:req_dim+2, :]))
    input_data = torch.ones(4).cuda() * Utils.rank
    # Fixed copy-paste bug: the original called _ScatterToModelParallelRegion.backward
    # here, but this test targets the sequence-parallel scatter. For a 1-D tensor both
    # gather along the same (only) dim, so the expected value below is unchanged.
    output_data = mappings._ScatterToSequenceParallelRegion.backward(None, input_data)
    expected_output = torch.concat((
        torch.ones(4)*0,
        torch.ones(4)*1,
        torch.ones(4)*2,
        torch.ones(4)*3)).cuda()
    if (Utils.rank >= 4):
        expected_output = expected_output + 4
    assert(torch.equal(output_data, expected_output))
    Utils.destroy_model_parallel()
87+
88+
def test_GatherFromSequenceParallelRegion():
    """Gather-from-SP-region: forward all-gathers the sequence dim; backward reduce-scatters."""
    Utils.initialize_model_parallel(4, 2)
    shard = torch.ones(4).cuda() * Utils.rank
    # All-gather over the TP group: ranks 0-3 yield [0,1,2,3], ranks 4-7 yield [4,5,6,7].
    base = 4 if Utils.rank >= 4 else 0
    expected = torch.concat([torch.ones(4) * r for r in range(base, base + 4)]).cuda()
    assert torch.equal(mappings.gather_from_sequence_parallel_region(shard), expected)
    assert torch.equal(mappings._GatherFromSequenceParallelRegion.symbolic(None, shard), expected)

    stacked = torch.vstack([torch.ones(4) * r for r in range(4)]).cuda()

    class FakeCtx:
        # backward() reads this flag off the autograd ctx to pick reduce-scatter.
        tensor_parallel_output_grad = True

    # backward returns a (grad_input, None) pair; [0] picks the tensor.
    grads = mappings._GatherFromSequenceParallelRegion.backward(FakeCtx(), stacked)
    # Reduce-scatter of rows 0..3: each rank receives the group-sum of its own row.
    expected_grad = torch.ones((1, 4)).cuda() * 4 * int(Utils.rank % 4)
    assert torch.equal(grads[0], expected_grad)
    Utils.destroy_model_parallel()
112+
113+
def test_ReduceScatterToSequenceParallelRegion():
    """Reduce-scatter-to-SP-region: forward reduce-scatters rows; backward all-gathers."""
    Utils.initialize_model_parallel(4, 2)
    stacked = torch.vstack([torch.ones(4) * r for r in range(4)]).cuda()
    scattered = mappings.reduce_scatter_to_sequence_parallel_region(stacked)
    # Each rank keeps the group-sum of its own row: 4 copies of the row index.
    expected = torch.ones(4).cuda() * 4 * int(Utils.rank % 4)
    assert torch.equal(scattered[0], expected)
    assert torch.equal(
        mappings._ReduceScatterToSequenceParallelRegion.symbolic(None, stacked),
        expected.reshape((1, 4)),
    )
    shard = torch.ones(4).cuda() * Utils.rank
    gathered = mappings._ReduceScatterToSequenceParallelRegion.backward(None, shard)
    # All-gather over the TP group: ranks 0-3 yield [0,1,2,3], ranks 4-7 yield [4,5,6,7].
    base = 4 if Utils.rank >= 4 else 0
    expected_grad = torch.concat([torch.ones(4) * r for r in range(base, base + 4)]).cuda()
    assert torch.equal(gathered, expected_grad)
    Utils.destroy_model_parallel()
135+

tests/tensor_parallel/test_random.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from megatron.core.tensor_parallel.random import CudaRNGStatesTracker
2+
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
3+
from megatron.core.tensor_parallel.random import _CUDA_RNG_STATE_TRACKER
4+
from megatron.core.tensor_parallel.random import checkpoint
5+
from tests.test_utilities import Utils
6+
import pytest
7+
import torch
8+
9+
def test_cuda_rng_states_tracker():
    """Exercise CudaRNGStatesTracker: set/get/reset, duplicate rejection, forked state."""
    rng_tracker = CudaRNGStatesTracker()
    rng_tracker.set_states({"state1":1234})
    assert(rng_tracker.get_states()["state1"] == 1234)
    rng_tracker.reset()
    assert(rng_tracker.get_states() == {})
    seed = 1111
    rng_tracker.add("state2",seed)
    # Re-using an existing seed under a new name must be rejected...
    with pytest.raises(Exception):
        rng_tracker.add("state3",seed)
    # ...and so must re-using an existing name with a different seed.
    with pytest.raises(Exception):
        rng_tracker.add("state2",111)
    assert(rng_tracker.get_states()['state2'] is not None)
    # Removed a leftover `with pytest.raises(Exception): assert()` block here:
    # `assert ()` always raises AssertionError on its own, so it verified
    # nothing about the tracker.

    # NOTE(review): fork() looks like a context manager; called without `with`
    # it does not actually enter the forked RNG state -- confirm whether
    # `with rng_tracker.fork("state2"):` was intended.
    rng_tracker.fork("state2")
    torch.cuda.manual_seed(seed)
    rng_state = torch.cuda.get_rng_state()
    assert torch.equal(rng_tracker.get_states()['state2'], rng_state)
29+
30+
def test_model_parallel_cuda_manual_seed():
    """Seeding must register a model-parallel RNG state with the global tracker."""
    Utils.initialize_model_parallel(4, 2)
    model_parallel_cuda_manual_seed(0)
    tracked_states = _CUDA_RNG_STATE_TRACKER.get_states()
    assert tracked_states['model-parallel-rng'] is not None
    Utils.destroy_model_parallel()
35+
36+
def test_checkpoint():
    """checkpoint() must run the wrapped function; with RNG-state distribution the
    first input is expected to end up as a flattened CUDA tensor of ones."""
    def add_pair(*tensors):
        # Toy forward: elementwise sum of the first two inputs.
        # (Renamed from the original's `test_forward(*input)`, which shadowed
        # the `input` builtin.)
        return tensors[0] + tensors[1]

    result = checkpoint(add_pair, None, torch.ones(16), torch.ones(16) * 2)
    assert torch.equal(torch.ones(16) * 3, result)
    Utils.initialize_model_parallel()
    first_input = torch.ones((4, 4))
    checkpoint(add_pair, True, first_input, torch.ones((4, 4)) * 2)
    # With distribute_saved_activations=True the original asserts the input was
    # rewritten in place to a flattened CUDA tensor of ones.
    assert torch.equal(torch.ones(first_input.numel()).cuda(), first_input)
    Utils.destroy_model_parallel()

0 commit comments

Comments
 (0)