Skip to content

Commit be8ca4e

Browse files
committed
reoptimize
1 parent 94e211c commit be8ca4e

File tree

14 files changed

+248
-166
lines changed

14 files changed

+248
-166
lines changed

client-spark/common/src/main/java/org/apache/spark/shuffle/writer/AddBlockEvent.java

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,31 @@
1919

2020
import java.util.ArrayList;
2121
import java.util.List;
22+
import java.util.concurrent.Future;
23+
import java.util.function.Consumer;
2224

2325
import org.apache.uniffle.common.ShuffleBlockInfo;
2426

2527
public class AddBlockEvent {
2628

29+
private Long eventId;
2730
private String taskId;
2831
private int stageAttemptNumber;
2932
private List<ShuffleBlockInfo> shuffleDataInfoList;
3033
private List<Runnable> processedCallbackChain;
3134

35+
private Consumer<Future> prepare;
36+
3237
public AddBlockEvent(String taskId, List<ShuffleBlockInfo> shuffleDataInfoList) {
33-
this(taskId, 0, shuffleDataInfoList);
38+
this(-1L, taskId, 0, shuffleDataInfoList);
3439
}
3540

3641
public AddBlockEvent(
37-
String taskId, int stageAttemptNumber, List<ShuffleBlockInfo> shuffleDataInfoList) {
42+
Long eventId,
43+
String taskId,
44+
int stageAttemptNumber,
45+
List<ShuffleBlockInfo> shuffleDataInfoList) {
46+
this.eventId = eventId;
3847
this.taskId = taskId;
3948
this.stageAttemptNumber = stageAttemptNumber;
4049
this.shuffleDataInfoList = shuffleDataInfoList;
@@ -46,10 +55,24 @@ public void addCallback(Runnable callback) {
4655
processedCallbackChain.add(callback);
4756
}
4857

58+
public void addPrepare(Consumer<Future> prepare) {
59+
this.prepare = prepare;
60+
}
61+
62+
public void doPrepare(Future future) {
63+
if (prepare != null) {
64+
prepare.accept(future);
65+
}
66+
}
67+
4968
public String getTaskId() {
5069
return taskId;
5170
}
5271

72+
public Long getEventId() {
73+
return eventId;
74+
}
75+
5376
public int getStageAttemptNumber() {
5477
return stageAttemptNumber;
5578
}

client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@
2424
import java.util.Map;
2525
import java.util.Optional;
2626
import java.util.Set;
27-
import java.util.concurrent.CompletableFuture;
2827
import java.util.concurrent.ExecutorService;
28+
import java.util.concurrent.Future;
29+
import java.util.concurrent.FutureTask;
2930
import java.util.concurrent.ThreadPoolExecutor;
3031
import java.util.concurrent.TimeUnit;
3132

@@ -80,11 +81,12 @@ public DataPusher(
8081
ThreadUtils.getThreadFactory(this.getClass().getName()));
8182
}
8283

83-
public CompletableFuture<Long> send(AddBlockEvent event) {
84+
public Future<Long> send(AddBlockEvent event) {
8485
if (rssAppId == null) {
8586
throw new RssException("RssAppId should be set.");
8687
}
87-
return CompletableFuture.supplyAsync(
88+
FutureTask<Long> future =
89+
new FutureTask(
8890
() -> {
8991
String taskId = event.getTaskId();
9092
List<ShuffleBlockInfo> shuffleBlockInfoList = event.getShuffleDataInfoList();
@@ -116,14 +118,11 @@ public CompletableFuture<Long> send(AddBlockEvent event) {
116118
.filter(x -> succeedBlockIds.contains(x.getBlockId()))
117119
.map(x -> x.getFreeMemory())
118120
.reduce((a, b) -> a + b)
119-
.get();
120-
},
121-
executorService)
122-
.exceptionally(
123-
ex -> {
124-
LOGGER.error("Unexpected exceptions occurred while sending shuffle data", ex);
125-
return null;
121+
.orElseGet(() -> 0L);
126122
});
123+
event.doPrepare(future);
124+
executorService.submit(future);
125+
return future;
127126
}
128127

129128
private Set<Long> getSucceedBlockIds(SendShuffleDataResult result) {

client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import java.util.List;
2424
import java.util.Map;
2525
import java.util.Optional;
26-
import java.util.concurrent.CompletableFuture;
26+
import java.util.concurrent.Future;
2727
import java.util.concurrent.TimeUnit;
2828
import java.util.concurrent.TimeoutException;
2929
import java.util.concurrent.atomic.AtomicInteger;
@@ -70,6 +70,8 @@ public class WriteBufferManager extends MemoryConsumer {
7070
private AtomicLong recordCounter = new AtomicLong(0);
7171
/** An atomic counter used to keep track of the number of blocks */
7272
private AtomicLong blockCounter = new AtomicLong(0);
73+
74+
private AtomicLong eventIdGenerator = new AtomicLong(0);
7375
// it's part of blockId
7476
private Map<Integer, AtomicInteger> partitionToSeqNo = Maps.newHashMap();
7577
private long askExecutorMemory;
@@ -96,7 +98,7 @@ public class WriteBufferManager extends MemoryConsumer {
9698
private long requireMemoryInterval;
9799
private int requireMemoryRetryMax;
98100
private Optional<Codec> codec;
99-
private Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc;
101+
private Function<List<ShuffleBlockInfo>, List<Future<Long>>> spillFunc;
100102
private long sendSizeLimit;
101103
private boolean memorySpillEnabled;
102104
private int memorySpillTimeoutSec;
@@ -138,7 +140,7 @@ public WriteBufferManager(
138140
TaskMemoryManager taskMemoryManager,
139141
ShuffleWriteMetrics shuffleWriteMetrics,
140142
RssConf rssConf,
141-
Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc,
143+
Function<List<ShuffleBlockInfo>, List<Future<Long>>> spillFunc,
142144
Function<Integer, List<ShuffleServerInfo>> partitionAssignmentRetrieveFunc) {
143145
this(
144146
shuffleId,
@@ -163,7 +165,7 @@ public WriteBufferManager(
163165
TaskMemoryManager taskMemoryManager,
164166
ShuffleWriteMetrics shuffleWriteMetrics,
165167
RssConf rssConf,
166-
Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc,
168+
Function<List<ShuffleBlockInfo>, List<Future<Long>>> spillFunc,
167169
Function<Integer, List<ShuffleServerInfo>> partitionAssignmentRetrieveFunc,
168170
int stageAttemptNumber) {
169171
super(taskMemoryManager, taskMemoryManager.pageSizeBytes(), MemoryMode.ON_HEAP);
@@ -212,7 +214,7 @@ public WriteBufferManager(
212214
TaskMemoryManager taskMemoryManager,
213215
ShuffleWriteMetrics shuffleWriteMetrics,
214216
RssConf rssConf,
215-
Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc,
217+
Function<List<ShuffleBlockInfo>, List<Future<Long>>> spillFunc,
216218
int stageAttemptNumber) {
217219
this(
218220
shuffleId,
@@ -528,7 +530,12 @@ public List<AddBlockEvent> buildBlockEvents(List<ShuffleBlockInfo> shuffleBlockI
528530
+ totalSize
529531
+ " bytes");
530532
}
531-
events.add(new AddBlockEvent(taskId, stageAttemptNumber, shuffleBlockInfosPerEvent));
533+
events.add(
534+
new AddBlockEvent(
535+
eventIdGenerator.incrementAndGet(),
536+
taskId,
537+
stageAttemptNumber,
538+
shuffleBlockInfosPerEvent));
532539
shuffleBlockInfosPerEvent = Lists.newArrayList();
533540
totalSize = 0;
534541
}
@@ -543,7 +550,12 @@ public List<AddBlockEvent> buildBlockEvents(List<ShuffleBlockInfo> shuffleBlockI
543550
+ " bytes");
544551
}
545552
// Use final temporary variables for closures
546-
events.add(new AddBlockEvent(taskId, stageAttemptNumber, shuffleBlockInfosPerEvent));
553+
events.add(
554+
new AddBlockEvent(
555+
eventIdGenerator.incrementAndGet(),
556+
taskId,
557+
stageAttemptNumber,
558+
shuffleBlockInfosPerEvent));
547559
}
548560
return events;
549561
}
@@ -555,15 +567,19 @@ public long spill(long size, MemoryConsumer trigger) {
555567
return 0L;
556568
}
557569

558-
List<CompletableFuture<Long>> futures = spillFunc.apply(clear(bufferSpillRatio));
559-
CompletableFuture<Void> allOfFutures =
560-
CompletableFuture.allOf(futures.toArray(new CompletableFuture[futures.size()]));
570+
List<Future<Long>> futures = spillFunc.apply(clear(bufferSpillRatio));
571+
long end = System.currentTimeMillis() + memorySpillTimeoutSec * 1000;
561572
try {
562-
allOfFutures.get(memorySpillTimeoutSec, TimeUnit.SECONDS);
573+
for (Future f : futures) {
574+
f.get(end - System.currentTimeMillis(), TimeUnit.MILLISECONDS);
575+
}
563576
} catch (TimeoutException timeoutException) {
564577
// A best effort strategy to wait.
565578
// If timeout exception occurs, the underlying tasks won't be cancelled.
566579
LOG.warn("[taskId: {}] Spill tasks timeout after {} seconds", taskId, memorySpillTimeoutSec);
580+
} catch (InterruptedException e) {
581+
Thread.currentThread().interrupt();
582+
LOG.warn("[taskId: {}] Spill interrupted due to kill", taskId);
567583
} catch (Exception e) {
568584
LOG.warn("[taskId: {}] Failed to spill buffers due to ", taskId, e);
569585
} finally {
@@ -608,6 +624,10 @@ public long getBlockCount() {
608624
return blockCounter.get();
609625
}
610626

627+
public Long getLastEventId() {
628+
return eventIdGenerator.get();
629+
}
630+
611631
public void freeAllocatedMemory(long freeMemory) {
612632
freeMemory(freeMemory);
613633
allocatedBytes.addAndGet(-freeMemory);
@@ -671,8 +691,7 @@ public void setTaskId(String taskId) {
671691
}
672692

673693
@VisibleForTesting
674-
public void setSpillFunc(
675-
Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc) {
694+
public void setSpillFunc(Function<List<ShuffleBlockInfo>, List<Future<Long>>> spillFunc) {
676695
this.spillFunc = spillFunc;
677696
}
678697

client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import java.util.Optional;
3131
import java.util.Set;
3232
import java.util.concurrent.CompletableFuture;
33+
import java.util.concurrent.Future;
3334
import java.util.concurrent.ScheduledExecutorService;
3435
import java.util.concurrent.TimeUnit;
3536
import java.util.concurrent.atomic.AtomicBoolean;
@@ -1587,7 +1588,7 @@ public Map<String, FailedBlockSendTracker> getTaskToFailedBlockSendTracker() {
15871588
return taskToFailedBlockSendTracker;
15881589
}
15891590

1590-
public CompletableFuture<Long> sendData(AddBlockEvent event) {
1591+
public Future<Long> sendData(AddBlockEvent event) {
15911592
if (dataPusher != null && event != null) {
15921593
return dataPusher.send(event);
15931594
}

client-spark/common/src/test/java/org/apache/spark/shuffle/writer/DataPusherTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
import java.util.List;
2323
import java.util.Map;
2424
import java.util.Set;
25-
import java.util.concurrent.CompletableFuture;
2625
import java.util.concurrent.ExecutionException;
26+
import java.util.concurrent.Future;
2727
import java.util.function.Supplier;
2828

2929
import com.google.common.collect.Maps;
@@ -119,7 +119,7 @@ public void testSendData() throws ExecutionException, InterruptedException {
119119
new ShuffleBlockInfo(1, 1, 1, 1, 1, new byte[1], null, 1, 100, 1);
120120
AddBlockEvent event = new AddBlockEvent("taskId", Arrays.asList(shuffleBlockInfo));
121121
// sync send
122-
CompletableFuture<Long> future = dataPusher.send(event);
122+
Future<Long> future = dataPusher.send(event);
123123
long memoryFree = future.get();
124124
assertEquals(100, memoryFree);
125125
assertTrue(taskToSuccessBlockIds.get("taskId").contains(1L));

client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.List;
2424
import java.util.Optional;
2525
import java.util.concurrent.CompletableFuture;
26+
import java.util.concurrent.Future;
2627
import java.util.concurrent.TimeUnit;
2728
import java.util.function.Function;
2829
import java.util.stream.Stream;
@@ -370,7 +371,7 @@ public void spillByOwnTest() {
370371
null,
371372
0);
372373

373-
Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc =
374+
Function<List<ShuffleBlockInfo>, List<Future<Long>>> spillFunc =
374375
blocks -> {
375376
long sum = 0L;
376377
List<AddBlockEvent> events = wbm.buildBlockEvents(blocks);
@@ -481,7 +482,7 @@ public void spillPartial() {
481482
null,
482483
0);
483484

484-
Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc =
485+
Function<List<ShuffleBlockInfo>, List<Future<Long>>> spillFunc =
485486
blocks -> {
486487
long sum = 0L;
487488
List<AddBlockEvent> events = wbm.buildBlockEvents(blocks);
@@ -579,7 +580,7 @@ public void spillByOwnWithSparkTaskMemoryManagerTest() {
579580

580581
List<ShuffleBlockInfo> blockList = new ArrayList<>();
581582

582-
Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc =
583+
Function<List<ShuffleBlockInfo>, List<Future<Long>>> spillFunc =
583584
blocks -> {
584585
blockList.addAll(blocks);
585586
long sum = 0L;

client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import java.util.List;
2525
import java.util.Map;
2626
import java.util.Set;
27-
import java.util.concurrent.CompletableFuture;
2827
import java.util.concurrent.ExecutorService;
2928
import java.util.concurrent.Executors;
3029
import java.util.concurrent.Future;
@@ -361,8 +360,7 @@ private void checkSentBlockCount() {
361360
*
362361
* @param shuffleBlockInfoList
363362
*/
364-
private List<CompletableFuture<Long>> processShuffleBlockInfos(
365-
List<ShuffleBlockInfo> shuffleBlockInfoList) {
363+
private List<Future<Long>> processShuffleBlockInfos(List<ShuffleBlockInfo> shuffleBlockInfoList) {
366364
if (shuffleBlockInfoList != null && !shuffleBlockInfoList.isEmpty()) {
367365
shuffleBlockInfoList.stream()
368366
.forEach(
@@ -390,9 +388,8 @@ private List<CompletableFuture<Long>> processShuffleBlockInfos(
390388

391389
// don't send huge block to shuffle server, or there will be OOM if shuffle server receives data
392390
// more than expected
393-
protected List<CompletableFuture<Long>> postBlockEvent(
394-
List<ShuffleBlockInfo> shuffleBlockInfoList) {
395-
List<CompletableFuture<Long>> futures = new ArrayList<>();
391+
protected List<Future<Long>> postBlockEvent(List<ShuffleBlockInfo> shuffleBlockInfoList) {
392+
List<Future<Long>> futures = new ArrayList<>();
396393
for (AddBlockEvent event : bufferManager.buildBlockEvents(shuffleBlockInfoList)) {
397394
futures.add(shuffleManager.sendData(event));
398395
}

0 commit comments

Comments
 (0)