|
| 1 | +from megatron.core.tensor_parallel import mappings |
| 2 | +from tests.test_utilities import Utils |
| 3 | +import torch |
| 4 | + |
| 5 | +def test_CopyToModelParallelRegion(): |
| 6 | + Utils.initialize_model_parallel(4,2) |
| 7 | + input_data = torch.ones((1)).cuda()*Utils.rank |
| 8 | + output_data = mappings._CopyToModelParallelRegion.backward(None, input_data) |
| 9 | + result = torch.ones(1).cuda() |
| 10 | + result = result * 22 if Utils.rank >= 4 else result * 6 |
| 11 | + assert(torch.equal(output_data, result)) |
| 12 | + assert(torch.equal(input_data, mappings.copy_to_tensor_model_parallel_region(input_data))) |
| 13 | + assert(torch.equal(input_data, mappings._CopyToModelParallelRegion.symbolic(None, input_data))) |
| 14 | + Utils.destroy_model_parallel() |
| 15 | + |
| 16 | +def test_ReduceFromModelParallelRegion(): |
| 17 | + Utils.initialize_model_parallel(4,2) |
| 18 | + input_data = torch.ones((1)).cuda()*Utils.rank |
| 19 | + output_data = mappings._ReduceFromModelParallelRegion.symbolic(None, input_data) |
| 20 | + result = torch.ones(1).cuda() |
| 21 | + result = result * 22 if Utils.rank >= 4 else result * 6 |
| 22 | + assert(torch.equal(output_data, result)) |
| 23 | + input_data = torch.ones((1)).cuda()*Utils.rank |
| 24 | + assert(torch.equal(mappings.reduce_from_tensor_model_parallel_region(input_data), result)) |
| 25 | + assert(torch.equal(input_data, mappings._ReduceFromModelParallelRegion.backward(None, input_data))) |
| 26 | + Utils.destroy_model_parallel() |
| 27 | + |
| 28 | +def test_ScatterToModelParallelRegion(): |
| 29 | + Utils.initialize_model_parallel(4,2) |
| 30 | + input_data = torch.rand((8,4)).cuda() |
| 31 | + output_data = mappings.scatter_to_tensor_model_parallel_region(input_data) |
| 32 | + req_dim = int(Utils.rank%(Utils.world_size/2)) |
| 33 | + assert(torch.equal(output_data, input_data[:,req_dim].reshape((8,1)))) |
| 34 | + output_data = mappings._ScatterToModelParallelRegion.symbolic(None, input_data) |
| 35 | + assert(torch.equal(output_data, input_data[:, req_dim].reshape((8,1)))) |
| 36 | + |
| 37 | + input_data = torch.ones(8).cuda() * Utils.rank |
| 38 | + actual_output_data = mappings._ScatterToModelParallelRegion.backward(None, input_data) |
| 39 | + expected_output = torch.cat(( |
| 40 | + torch.ones(8)*0, |
| 41 | + torch.ones(8)*1, |
| 42 | + torch.ones(8)*2, |
| 43 | + torch.ones(8)*3)).cuda() |
| 44 | + if (Utils.rank >= 4): |
| 45 | + expected_output = expected_output + 4 |
| 46 | + assert(torch.equal(actual_output_data, expected_output)) |
| 47 | + Utils.destroy_model_parallel() |
| 48 | + |
| 49 | +def test_GatherFromModelParallelRegion(): |
| 50 | + Utils.initialize_model_parallel(4,2) |
| 51 | + input_data = torch.rand((8,4)).cuda() |
| 52 | + req_dim = int(Utils.rank%(Utils.world_size/2)) |
| 53 | + output_data = mappings._GatherFromModelParallelRegion.backward(None, input_data) |
| 54 | + assert(torch.equal(output_data, input_data[:, req_dim].reshape((8,1)))) |
| 55 | + input_data = torch.ones(8).cuda() * Utils.rank |
| 56 | + actual_output_data = mappings.gather_from_tensor_model_parallel_region(input_data) |
| 57 | + expected_output = torch.cat(( |
| 58 | + torch.ones(8)*0, |
| 59 | + torch.ones(8)*1, |
| 60 | + torch.ones(8)*2, |
| 61 | + torch.ones(8)*3)).cuda() |
| 62 | + if (Utils.rank >= 4): |
| 63 | + expected_output = expected_output + 4 |
| 64 | + assert(torch.equal(actual_output_data, expected_output)) |
| 65 | + assert(torch.equal(mappings._GatherFromModelParallelRegion.symbolic(None, input_data), expected_output)) |
| 66 | + Utils.destroy_model_parallel() |
| 67 | + |
| 68 | +def test_ScatterToSequenceParallelRegion(): |
| 69 | + Utils.initialize_model_parallel(4,2) |
| 70 | + input_data = torch.rand((8,4)).cuda() |
| 71 | + req_dim = int(Utils.rank%(Utils.world_size/2))*2 |
| 72 | + output_data = mappings._ScatterToSequenceParallelRegion.symbolic(None, input_data) |
| 73 | + assert(torch.equal(output_data, input_data[req_dim:req_dim+2, :])) |
| 74 | + output_data = mappings.scatter_to_sequence_parallel_region(input_data) |
| 75 | + assert(torch.equal(output_data, input_data[req_dim:req_dim+2, :])) |
| 76 | + input_data = torch.ones(4).cuda() * Utils.rank |
| 77 | + output_data = mappings._ScatterToModelParallelRegion.backward(None, input_data) |
| 78 | + expected_output = torch.concat(( |
| 79 | + torch.ones(4)*0, |
| 80 | + torch.ones(4)*1, |
| 81 | + torch.ones(4)*2, |
| 82 | + torch.ones(4)*3)).cuda() |
| 83 | + if (Utils.rank >= 4): |
| 84 | + expected_output = expected_output + 4 |
| 85 | + assert(torch.equal(output_data, expected_output)) |
| 86 | + Utils.destroy_model_parallel() |
| 87 | + |
| 88 | +def test_GatherFromSequenceParallelRegion(): |
| 89 | + Utils.initialize_model_parallel(4,2) |
| 90 | + input_data = torch.ones(4).cuda() * Utils.rank |
| 91 | + output_data = mappings.gather_from_sequence_parallel_region(input_data) |
| 92 | + expected_output = torch.concat(( |
| 93 | + torch.ones(4)*0, |
| 94 | + torch.ones(4)*1, |
| 95 | + torch.ones(4)*2, |
| 96 | + torch.ones(4)*3)).cuda() |
| 97 | + if (Utils.rank >= 4): |
| 98 | + expected_output = expected_output + 4 |
| 99 | + assert(torch.equal(output_data, expected_output)) |
| 100 | + assert(torch.equal(mappings._GatherFromSequenceParallelRegion.symbolic(None, input_data), expected_output)) |
| 101 | + input_data = torch.vstack(( |
| 102 | + torch.ones(4)*0, |
| 103 | + torch.ones(4)*1, |
| 104 | + torch.ones(4)*2, |
| 105 | + torch.ones(4)*3)).cuda() |
| 106 | + class Ctx: |
| 107 | + tensor_parallel_output_grad = True |
| 108 | + output_data = mappings._GatherFromSequenceParallelRegion.backward(Ctx(), input_data) |
| 109 | + expected_output = torch.ones((1,4)).cuda() * 4 * int(Utils.rank % 4) |
| 110 | + assert(torch.equal(output_data[0], expected_output)) |
| 111 | + Utils.destroy_model_parallel() |
| 112 | + |
| 113 | +def test_ReduceScatterToSequenceParallelRegion(): |
| 114 | + Utils.initialize_model_parallel(4,2) |
| 115 | + input_data = torch.vstack(( |
| 116 | + torch.ones(4)*0, |
| 117 | + torch.ones(4)*1, |
| 118 | + torch.ones(4)*2, |
| 119 | + torch.ones(4)*3)).cuda() |
| 120 | + output_data = mappings.reduce_scatter_to_sequence_parallel_region(input_data) |
| 121 | + expected_output = torch.ones(4).cuda() * 4 * int(Utils.rank % 4) |
| 122 | + assert(torch.equal(output_data[0], expected_output)) |
| 123 | + assert(torch.equal(mappings._ReduceScatterToSequenceParallelRegion.symbolic(None, input_data) , expected_output.reshape((1,4)))) |
| 124 | + input_data = torch.ones(4).cuda() * Utils.rank |
| 125 | + output_data = mappings._ReduceScatterToSequenceParallelRegion.backward(None,input_data) |
| 126 | + expected_output = torch.concat(( |
| 127 | + torch.ones(4)*0, |
| 128 | + torch.ones(4)*1, |
| 129 | + torch.ones(4)*2, |
| 130 | + torch.ones(4)*3)).cuda() |
| 131 | + if (Utils.rank >= 4): |
| 132 | + expected_output = expected_output + 4 |
| 133 | + assert(torch.equal(output_data, expected_output)) |
| 134 | + Utils.destroy_model_parallel() |
| 135 | + |
0 commit comments