Skip to content

Commit 4eee9a8

Browse files
committed
refactor
1 parent c9777e7 commit 4eee9a8

File tree

7 files changed

+106
-45
lines changed

7 files changed

+106
-45
lines changed

client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.lang.reflect.Constructor;
2222
import java.lang.reflect.InvocationTargetException;
2323
import java.util.Arrays;
24+
import java.util.Collections;
2425
import java.util.HashSet;
2526
import java.util.List;
2627
import java.util.Map;
@@ -33,7 +34,9 @@
3334
import org.apache.hadoop.conf.Configuration;
3435
import org.apache.spark.SparkConf;
3536
import org.apache.spark.SparkContext;
37+
import org.apache.spark.SparkEnv;
3638
import org.apache.spark.TaskContext;
39+
import org.apache.spark.TaskContext$;
3740
import org.apache.spark.broadcast.Broadcast;
3841
import org.apache.spark.deploy.SparkHadoopUtil;
3942
import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
@@ -45,9 +48,7 @@
4548
import org.apache.uniffle.client.api.ShuffleManagerClient;
4649
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
4750
import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
48-
import org.apache.uniffle.client.request.RssReassignServersRequest;
4951
import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
50-
import org.apache.uniffle.client.response.RssReassignServersResponse;
5152
import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
5253
import org.apache.uniffle.client.util.ClientUtils;
5354
import org.apache.uniffle.client.util.RssClientConfig;
@@ -363,6 +364,7 @@ public static RssException reportRssFetchFailedException(
363364
try (ShuffleManagerClient client =
364365
ShuffleManagerClientFactory.getInstance()
365366
.createShuffleManagerClient(ClientType.GRPC, driver, port)) {
367+
TaskContext taskContext = TaskContext$.MODULE$.get();
366368
// todo: Create a new rpc interface to report failures in batch.
367369
for (int partitionId : failedPartitions) {
368370
RssReportShuffleFetchFailureRequest req =
@@ -371,22 +373,14 @@ public static RssException reportRssFetchFailedException(
371373
shuffleId,
372374
stageAttemptId,
373375
partitionId,
374-
rssFetchFailedException.getMessage());
376+
rssFetchFailedException.getMessage(),
377+
Collections.emptyList(),
378+
taskContext.stageId(),
379+
taskContext.taskAttemptId(),
380+
taskContext.attemptNumber(),
381+
SparkEnv.get().executorId());
375382
RssReportShuffleFetchFailureResponse response = client.reportShuffleFetchFailure(req);
376383
if (response.getReSubmitWholeStage()) {
377-
TaskContext taskContext = TaskContext.get();
378-
RssReassignServersRequest rssReassignServersRequest =
379-
new RssReassignServersRequest(
380-
taskContext.stageId(),
381-
taskContext.stageAttemptNumber(),
382-
shuffleId,
383-
taskContext.numPartitions());
384-
RssReassignServersResponse reassignServersResponse =
385-
client.reassignShuffleServers(rssReassignServersRequest);
386-
LOG.info(
387-
"Reassign servers for stage retry due to the fetch failure, result: {}",
388-
reassignServersResponse.isNeedReassign());
389-
390384
// since we are going to resubmit the whole stage, mapIndex shouldn't matter, hence -1
391385
// is provided.
392386
FetchFailedException ffe =

client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssFetchFailedIterator.java

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,24 @@
1818
package org.apache.spark.shuffle.reader;
1919

2020
import java.io.IOException;
21+
import java.util.Collections;
2122
import java.util.Objects;
2223

2324
import scala.Product2;
2425
import scala.collection.AbstractIterator;
2526
import scala.collection.Iterator;
2627

28+
import org.apache.spark.SparkEnv;
2729
import org.apache.spark.TaskContext;
30+
import org.apache.spark.TaskContext$;
2831
import org.apache.spark.shuffle.FetchFailedException;
2932
import org.apache.spark.shuffle.RssSparkShuffleUtils;
3033
import org.slf4j.Logger;
3134
import org.slf4j.LoggerFactory;
3235

3336
import org.apache.uniffle.client.api.ShuffleManagerClient;
3437
import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
35-
import org.apache.uniffle.client.request.RssReassignServersRequest;
3638
import org.apache.uniffle.client.request.RssReportShuffleFetchFailureRequest;
37-
import org.apache.uniffle.client.response.RssReassignServersResponse;
3839
import org.apache.uniffle.client.response.RssReportShuffleFetchFailureResponse;
3940
import org.apache.uniffle.common.ClientType;
4041
import org.apache.uniffle.common.exception.RssException;
@@ -114,28 +115,21 @@ private RssException generateFetchFailedIfNecessary(RssFetchFailedException e) {
114115
int port = builder.reportServerPort;
115116
// todo: reuse this manager client if this is a bottleneck.
116117
try (ShuffleManagerClient client = createShuffleManagerClient(driver, port)) {
118+
TaskContext taskContext = TaskContext$.MODULE$.get();
117119
RssReportShuffleFetchFailureRequest req =
118120
new RssReportShuffleFetchFailureRequest(
119121
builder.appId,
120122
builder.shuffleId,
121123
builder.stageAttemptId,
122124
builder.partitionId,
123-
e.getMessage());
125+
e.getMessage(),
126+
Collections.emptyList(),
127+
taskContext.stageId(),
128+
taskContext.taskAttemptId(),
129+
taskContext.attemptNumber(),
130+
SparkEnv.get().executorId());
124131
RssReportShuffleFetchFailureResponse response = client.reportShuffleFetchFailure(req);
125132
if (response.getReSubmitWholeStage()) {
126-
TaskContext taskContext = TaskContext.get();
127-
RssReassignServersRequest rssReassignServersRequest =
128-
new RssReassignServersRequest(
129-
taskContext.stageId(),
130-
taskContext.stageAttemptNumber(),
131-
builder.shuffleId,
132-
taskContext.numPartitions());
133-
RssReassignServersResponse reassignServersResponse =
134-
client.reassignShuffleServers(rssReassignServersRequest);
135-
LOG.info(
136-
"Reassign servers for stage retry due to the fetch failure, result: {}",
137-
reassignServersResponse.isNeedReassign());
138-
139133
// since we are going to resubmit the whole stage, mapIndex shouldn't matter, hence -1 is
140134
// provided.
141135
FetchFailedException ffe =

common/src/main/java/org/apache/uniffle/common/exception/RssFetchFailedException.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,26 @@
1717

1818
package org.apache.uniffle.common.exception;
1919

20+
import org.apache.uniffle.common.ShuffleServerInfo;
21+
2022
/** Dedicated exception for rss client's shuffle failed related exception. */
2123
public class RssFetchFailedException extends RssException {
24+
private ShuffleServerInfo fetchFailureServerId;
25+
26+
public RssFetchFailedException(String message, ShuffleServerInfo fetchFailureServerId) {
27+
super(message);
28+
this.fetchFailureServerId = fetchFailureServerId;
29+
}
30+
2231
public RssFetchFailedException(String message) {
2332
super(message);
2433
}
2534

2635
public RssFetchFailedException(String message, Throwable e) {
2736
super(message, e);
2837
}
38+
39+
public ShuffleServerInfo getFetchFailureServerId() {
40+
return fetchFailureServerId;
41+
}
2942
}

internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
import org.apache.uniffle.common.RemoteStorageInfo;
6767
import org.apache.uniffle.common.ShuffleBlockInfo;
6868
import org.apache.uniffle.common.ShuffleDataDistributionType;
69+
import org.apache.uniffle.common.ShuffleServerInfo;
6970
import org.apache.uniffle.common.config.RssClientConf;
7071
import org.apache.uniffle.common.config.RssConf;
7172
import org.apache.uniffle.common.exception.NotRetryException;
@@ -827,12 +828,16 @@ public RssGetShuffleResultResponse getShuffleResult(RssGetShuffleResultRequest r
827828
+ ", errorMsg:"
828829
+ rpcResponse.getRetMsg();
829830
LOG.error(msg);
830-
throw new RssFetchFailedException(msg);
831+
throw new RssFetchFailedException(msg, getShuffleServerInfo());
831832
}
832833

833834
return response;
834835
}
835836

837+
private ShuffleServerInfo getShuffleServerInfo() {
838+
return new ShuffleServerInfo(host, port);
839+
}
840+
836841
@Override
837842
public RssGetShuffleResultResponse getShuffleResultForMultiPart(
838843
RssGetShuffleResultForMultiPartRequest request) {
@@ -876,7 +881,7 @@ public RssGetShuffleResultResponse getShuffleResultForMultiPart(
876881
+ ", errorMsg:"
877882
+ rpcResponse.getRetMsg();
878883
LOG.error(msg);
879-
throw new RssFetchFailedException(msg);
884+
throw new RssFetchFailedException(msg, getShuffleServerInfo());
880885
}
881886

882887
return response;
@@ -939,7 +944,7 @@ public RssGetShuffleDataResponse getShuffleData(RssGetShuffleDataRequest request
939944
+ ", errorMsg:"
940945
+ rpcResponse.getRetMsg();
941946
LOG.error(msg);
942-
throw new RssFetchFailedException(msg);
947+
throw new RssFetchFailedException(msg, getShuffleServerInfo());
943948
}
944949
return response;
945950
}
@@ -1002,7 +1007,7 @@ public RssGetShuffleIndexResponse getShuffleIndex(RssGetShuffleIndexRequest requ
10021007
+ ", errorMsg:"
10031008
+ rpcResponse.getRetMsg();
10041009
LOG.error(msg);
1005-
throw new RssFetchFailedException(msg);
1010+
throw new RssFetchFailedException(msg, getShuffleServerInfo());
10061011
}
10071012
return response;
10081013
}
@@ -1078,7 +1083,7 @@ public RssGetInMemoryShuffleDataResponse getInMemoryShuffleData(
10781083
+ ", errorMsg:"
10791084
+ rpcResponse.getRetMsg();
10801085
LOG.error(msg);
1081-
throw new RssFetchFailedException(msg);
1086+
throw new RssFetchFailedException(msg, getShuffleServerInfo());
10821087
}
10831088
return response;
10841089
}
@@ -1101,7 +1106,7 @@ protected void waitOrThrow(
11011106
request.getRetryMax(),
11021107
System.currentTimeMillis() - start);
11031108
LOG.error(msg);
1104-
throw new RssFetchFailedException(msg);
1109+
throw new RssFetchFailedException(msg, getShuffleServerInfo());
11051110
}
11061111
try {
11071112
long backoffTime =

internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.apache.uniffle.client.response.RssGetShuffleIndexResponse;
3838
import org.apache.uniffle.client.response.RssSendShuffleDataResponse;
3939
import org.apache.uniffle.common.ShuffleBlockInfo;
40+
import org.apache.uniffle.common.ShuffleServerInfo;
4041
import org.apache.uniffle.common.config.RssClientConf;
4142
import org.apache.uniffle.common.config.RssConf;
4243
import org.apache.uniffle.common.exception.NotRetryException;
@@ -303,10 +304,14 @@ public RssGetInMemoryShuffleDataResponse getInMemoryShuffleData(
303304
+ ", errorMsg:"
304305
+ getMemoryShuffleDataResponse.getRetMessage();
305306
LOG.error(msg);
306-
throw new RssFetchFailedException(msg);
307+
throw new RssFetchFailedException(msg, getShuffleServerInfo());
307308
}
308309
}
309310

311+
private ShuffleServerInfo getShuffleServerInfo() {
312+
return new ShuffleServerInfo(host, port, nettyPort);
313+
}
314+
310315
@Override
311316
public RssGetShuffleIndexResponse getShuffleIndex(RssGetShuffleIndexRequest request) {
312317
TransportClient transportClient = getTransportClient();
@@ -361,7 +366,7 @@ public RssGetShuffleIndexResponse getShuffleIndex(RssGetShuffleIndexRequest requ
361366
+ ", errorMsg:"
362367
+ getLocalShuffleIndexResponse.getRetMessage();
363368
LOG.error(msg);
364-
throw new RssFetchFailedException(msg);
369+
throw new RssFetchFailedException(msg, getShuffleServerInfo());
365370
}
366371
}
367372

@@ -421,7 +426,7 @@ public RssGetShuffleDataResponse getShuffleData(RssGetShuffleDataRequest request
421426
+ ", errorMsg:"
422427
+ getLocalShuffleDataResponse.getRetMessage();
423428
LOG.error(msg);
424-
throw new RssFetchFailedException(msg);
429+
throw new RssFetchFailedException(msg, getShuffleServerInfo());
425430
}
426431
}
427432

internal-client/src/main/java/org/apache/uniffle/client/request/RssReportShuffleFetchFailureRequest.java

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@
1717

1818
package org.apache.uniffle.client.request;
1919

20+
import java.util.Collections;
21+
import java.util.List;
22+
23+
import com.google.common.annotations.VisibleForTesting;
24+
25+
import org.apache.uniffle.common.ShuffleServerInfo;
2026
import org.apache.uniffle.proto.RssProtos.ReportShuffleFetchFailureRequest;
2127

2228
public class RssReportShuffleFetchFailureRequest {
@@ -25,14 +31,50 @@ public class RssReportShuffleFetchFailureRequest {
2531
private int stageAttemptId;
2632
private int partitionId;
2733
private String exception;
34+
private List<ShuffleServerInfo> fetchFailureServerInfos;
35+
private int stageId;
36+
private long taskAttemptId;
37+
private int taskAttemptNumber;
38+
private String executorId;
2839

2940
public RssReportShuffleFetchFailureRequest(
30-
String appId, int shuffleId, int stageAttemptId, int partitionId, String exception) {
41+
String appId,
42+
int shuffleId,
43+
int stageAttemptId,
44+
int partitionId,
45+
String exception,
46+
List<ShuffleServerInfo> fetchFailureServerInfos,
47+
int stageId,
48+
long taskAttemptId,
49+
int taskAttemptNumber,
50+
String executorId) {
3151
this.appId = appId;
3252
this.shuffleId = shuffleId;
3353
this.stageAttemptId = stageAttemptId;
3454
this.partitionId = partitionId;
3555
this.exception = exception;
56+
this.fetchFailureServerInfos = fetchFailureServerInfos;
57+
this.stageId = stageId;
58+
this.taskAttemptId = taskAttemptId;
59+
this.taskAttemptNumber = taskAttemptNumber;
60+
this.executorId = executorId;
61+
}
62+
63+
// Only for tests
64+
@VisibleForTesting
65+
public RssReportShuffleFetchFailureRequest(
66+
String appId, int shuffleId, int stageAttemptId, int partitionId, String exception) {
67+
this(
68+
appId,
69+
shuffleId,
70+
stageAttemptId,
71+
partitionId,
72+
exception,
73+
Collections.emptyList(),
74+
0,
75+
0L,
76+
0,
77+
"executor1");
3678
}
3779

3880
public ReportShuffleFetchFailureRequest toProto() {
@@ -42,7 +84,12 @@ public ReportShuffleFetchFailureRequest toProto() {
4284
.setAppId(appId)
4385
.setShuffleId(shuffleId)
4486
.setStageAttemptId(stageAttemptId)
45-
.setPartitionId(partitionId);
87+
.setPartitionId(partitionId)
88+
.setStageId(stageId)
89+
.setTaskAttemptId(taskAttemptId)
90+
.setTaskAttemptNumber(taskAttemptNumber)
91+
.setExecutorId(executorId)
92+
.addAllFetchFailureServerId(ShuffleServerInfo.toProto(fetchFailureServerInfos));
4693
if (exception != null) {
4794
builder.setException(exception);
4895
}

proto/src/main/proto/Rss.proto

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -554,8 +554,11 @@ message ReportShuffleFetchFailureRequest {
554554
int32 stageAttemptId = 3;
555555
int32 partitionId = 4;
556556
string exception = 5;
557-
// todo: report ShuffleServerId if needed
558-
// ShuffleServerId serverId = 6;
557+
repeated ShuffleServerId fetchFailureServerId = 6;
558+
int32 stageId = 7;
559+
int64 taskAttemptId = 8;
560+
int32 taskAttemptNumber = 9;
561+
string executorId = 10;
559562
}
560563

561564
message ReportShuffleFetchFailureResponse {

0 commit comments

Comments
 (0)