Skip to content

Commit 1aa91b0

Browse files
authored
refactor: Http client should not own task id (#26592)
## Description ``` == NO RELEASE NOTE == ```
1 parent 991d1b9 commit 1aa91b0

File tree

8 files changed

+216
-163
lines changed

8 files changed

+216
-163
lines changed

presto-spark-base/src/main/java/com/facebook/presto/spark/execution/http/PrestoSparkHttpTaskClient.java

Lines changed: 101 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ public class PrestoSparkHttpTaskClient
9898
private static final Logger log = Logger.get(PrestoSparkHttpTaskClient.class);
9999
private final OkHttpClient httpClient;
100100
private final URI location;
101-
private final URI taskUri;
102101
private final JsonCodec<TaskInfo> taskInfoCodec;
103102
private final JsonCodec<PlanFragment> planFragmentCodec;
104103
private final JsonCodec<BatchTaskUpdateRequest> taskUpdateRequestCodec;
@@ -109,7 +108,6 @@ public class PrestoSparkHttpTaskClient
109108

110109
public PrestoSparkHttpTaskClient(
111110
OkHttpClient httpClient,
112-
TaskId taskId,
113111
URI location,
114112
JsonCodec<TaskInfo> taskInfoCodec,
115113
JsonCodec<PlanFragment> planFragmentCodec,
@@ -124,7 +122,6 @@ public PrestoSparkHttpTaskClient(
124122
this.taskInfoCodec = requireNonNull(taskInfoCodec, "taskInfoCodec is null");
125123
this.planFragmentCodec = requireNonNull(planFragmentCodec, "planFragmentCodec is null");
126124
this.taskUpdateRequestCodec = requireNonNull(taskUpdateRequestCodec, "taskUpdateRequestCodec is null");
127-
this.taskUri = createTaskUri(location, taskId);
128125
this.infoRefreshMaxWait = requireNonNull(infoRefreshMaxWait, "infoRefreshMaxWait is null");
129126
this.executor = requireNonNull(executor, "executor is null");
130127
this.scheduledExecutorService = requireNonNull(scheduledExecutorService, "scheduledExecutorService is null");
@@ -134,7 +131,7 @@ public PrestoSparkHttpTaskClient(
134131
/**
135132
* Get results from a native engine task that ends with none shuffle operator. It always fetches from a single buffer.
136133
*/
137-
public ListenableFuture<PagesResponse> getResults(long token, DataSize maxResponseSize)
134+
public ListenableFuture<PagesResponse> getResults(TaskId taskId, long token, DataSize maxResponseSize)
138135
{
139136
RequestErrorTracker errorTracker = new RequestErrorTracker(
140137
"NativeExecution",
@@ -145,100 +142,13 @@ public ListenableFuture<PagesResponse> getResults(long token, DataSize maxRespon
145142
scheduledExecutorService,
146143
"sending update request to native process");
147144
SettableFuture<PagesResponse> result = SettableFuture.create();
148-
scheduleGetResultsRequest(prepareGetResultsRequest(token, maxResponseSize), errorTracker, result);
145+
scheduleGetResultsRequest(prepareGetResultsRequest(taskId, token, maxResponseSize), errorTracker, result);
149146
return result;
150147
}
151148

152-
private void scheduleGetResultsRequest(
153-
Request request,
154-
RequestErrorTracker errorTracker,
155-
SettableFuture<PagesResponse> result)
156-
{
157-
ListenableFuture<Void> permitFuture = (ListenableFuture<Void>) errorTracker.acquireRequestPermit();
158-
addCallback(permitFuture, new FutureCallback<Void>() {
159-
@Override
160-
public void onSuccess(Void ignored)
161-
{
162-
errorTracker.startRequest();
163-
httpClient.newCall(request).enqueue(new Callback() {
164-
@Override
165-
public void onFailure(Call call, IOException e)
166-
{
167-
handleGetResultsFailure(e, errorTracker, request, result);
168-
}
169-
170-
@Override
171-
public void onResponse(Call call, Response response)
172-
{
173-
try {
174-
BaseResponse<PagesResponse> baseResponse = new PageResponseHandler().handle(request, response);
175-
if (baseResponse.hasValue()) {
176-
errorTracker.requestSucceeded();
177-
result.set(baseResponse.getValue());
178-
}
179-
else {
180-
Exception exception = baseResponse.getException();
181-
if (exception != null) {
182-
handleGetResultsFailure(exception, errorTracker, request, result);
183-
}
184-
else {
185-
handleGetResultsFailure(new RuntimeException("Empty response without exception"), errorTracker, request, result);
186-
}
187-
}
188-
}
189-
catch (Exception e) {
190-
handleGetResultsFailure(e, errorTracker, request, result);
191-
}
192-
finally {
193-
response.close();
194-
}
195-
}
196-
});
197-
}
198-
199-
@Override
200-
public void onFailure(Throwable t)
201-
{
202-
result.setException(t);
203-
}
204-
}, executor);
205-
}
206-
207-
private void handleGetResultsFailure(Throwable failure, RequestErrorTracker errorTracker,
208-
Request request, SettableFuture<PagesResponse> result)
209-
{
210-
log.info("Received failure response with exception %s", failure);
211-
if (Arrays.stream(failure.getSuppressed()).anyMatch(t -> t instanceof PrestoException)) {
212-
result.setException(failure);
213-
return;
214-
}
215-
try {
216-
errorTracker.requestFailed(failure);
217-
scheduleGetResultsRequest(request, errorTracker, result);
218-
}
219-
catch (Throwable t) {
220-
result.setException(t);
221-
}
222-
}
223-
224-
private Request prepareGetResultsRequest(long token, DataSize maxResponseSize)
149+
public void acknowledgeResultsAsync(TaskId taskId, long nextToken)
225150
{
226-
HttpUrl url = HttpUrl.get(taskUri).newBuilder()
227-
.addPathSegment("results")
228-
.addPathSegment("0")
229-
.addPathSegment(String.valueOf(token))
230-
.build();
231-
232-
return new Request.Builder()
233-
.url(url)
234-
.get()
235-
.addHeader(PRESTO_MAX_SIZE, maxResponseSize.toString())
236-
.build();
237-
}
238-
239-
public void acknowledgeResultsAsync(long nextToken)
240-
{
241-
HttpUrl url = HttpUrl.get(taskUri).newBuilder()
151+
HttpUrl url = HttpUrl.get(getTaskUri(taskId)).newBuilder()
242152
.addPathSegment("results")
243153
.addPathSegment("0")
244154
.addPathSegment(String.valueOf(nextToken))
@@ -259,9 +169,9 @@ public void acknowledgeResultsAsync(long nextToken)
259169
scheduleVoidRequest(request, new BytesResponseHandler(), errorTracker, result);
260170
}
261171

262-
public ListenableFuture<Void> abortResultsAsync()
172+
public ListenableFuture<Void> abortResultsAsync(TaskId taskId)
263173
{
264-
HttpUrl url = HttpUrl.get(taskUri).newBuilder()
174+
HttpUrl url = HttpUrl.get(getTaskUri(taskId)).newBuilder()
265175
.addPathSegment("results")
266176
.addPathSegment("0")
267177
.build();
@@ -280,11 +190,11 @@ public ListenableFuture<Void> abortResultsAsync()
280190
return result;
281191
}
282192

283-
public TaskInfo getTaskInfo()
193+
public TaskInfo getTaskInfo(TaskId taskId)
284194
{
285195
Request request = setContentTypeHeaders(new Request.Builder())
286196
.addHeader(PRESTO_MAX_WAIT, infoRefreshMaxWait.toString())
287-
.url(taskUri.toString())
197+
.url(getTaskUri(taskId).toString())
288198
.get()
289199
.build();
290200
ListenableFuture<TaskInfo> future = executeWithRetries(
@@ -296,6 +206,7 @@ public TaskInfo getTaskInfo()
296206
}
297207

298208
public TaskInfo updateTask(
209+
TaskId taskId,
299210
List<TaskSource> sources,
300211
PlanFragment planFragment,
301212
TableWriteInfo tableWriteInfo,
@@ -315,7 +226,7 @@ public TaskInfo updateTask(
315226
writeInfo);
316227
BatchTaskUpdateRequest batchTaskUpdateRequest = new BatchTaskUpdateRequest(updateRequest, shuffleWriteInfo, broadcastBasePath);
317228

318-
HttpUrl url = HttpUrl.get(taskUri).newBuilder()
229+
HttpUrl url = HttpUrl.get(getTaskUri(taskId)).newBuilder()
319230
.addPathSegment("batch")
320231
.build();
321232
byte[] requestBody = taskUpdateRequestCodec.toBytes(batchTaskUpdateRequest);
@@ -336,19 +247,101 @@ public URI getLocation()
336247
return location;
337248
}
338249

339-
public URI getTaskUri()
250+
public URI getTaskUri(TaskId taskId)
340251
{
341-
return taskUri;
252+
return HttpUrl.get(location).newBuilder()
253+
.addPathSegment("v1")
254+
.addPathSegment("task")
255+
.addPathSegment(taskId.toString())
256+
.build()
257+
.uri();
342258
}
343259

344-
private URI createTaskUri(URI baseUri, TaskId taskId)
260+
private void scheduleGetResultsRequest(
261+
Request request,
262+
RequestErrorTracker errorTracker,
263+
SettableFuture<PagesResponse> result)
345264
{
346-
return HttpUrl.get(baseUri).newBuilder()
347-
.addPathSegment("v1")
348-
.addPathSegment("task")
349-
.addPathSegment(taskId.toString())
350-
.build()
351-
.uri();
265+
ListenableFuture<Void> permitFuture = (ListenableFuture<Void>) errorTracker.acquireRequestPermit();
266+
addCallback(permitFuture, new FutureCallback<Void>() {
267+
@Override
268+
public void onSuccess(Void ignored)
269+
{
270+
errorTracker.startRequest();
271+
httpClient.newCall(request).enqueue(new Callback() {
272+
@Override
273+
public void onFailure(Call call, IOException e)
274+
{
275+
handleGetResultsFailure(e, errorTracker, request, result);
276+
}
277+
278+
@Override
279+
public void onResponse(Call call, Response response)
280+
{
281+
try {
282+
BaseResponse<PagesResponse> baseResponse = new PageResponseHandler().handle(request, response);
283+
if (baseResponse.hasValue()) {
284+
errorTracker.requestSucceeded();
285+
result.set(baseResponse.getValue());
286+
}
287+
else {
288+
Exception exception = baseResponse.getException();
289+
if (exception != null) {
290+
handleGetResultsFailure(exception, errorTracker, request, result);
291+
}
292+
else {
293+
handleGetResultsFailure(new RuntimeException("Empty response without exception"), errorTracker, request, result);
294+
}
295+
}
296+
}
297+
catch (Exception e) {
298+
handleGetResultsFailure(e, errorTracker, request, result);
299+
}
300+
finally {
301+
response.close();
302+
}
303+
}
304+
});
305+
}
306+
307+
@Override
308+
public void onFailure(Throwable t)
309+
{
310+
result.setException(t);
311+
}
312+
}, executor);
313+
}
314+
315+
private void handleGetResultsFailure(Throwable failure, RequestErrorTracker errorTracker,
316+
Request request, SettableFuture<PagesResponse> result)
317+
{
318+
log.info("Received failure response with exception %s", failure);
319+
if (Arrays.stream(failure.getSuppressed()).anyMatch(t -> t instanceof PrestoException)) {
320+
result.setException(failure);
321+
return;
322+
}
323+
try {
324+
errorTracker.requestFailed(failure);
325+
scheduleGetResultsRequest(request, errorTracker, result);
326+
}
327+
catch (Throwable t) {
328+
result.setException(t);
329+
}
330+
}
331+
332+
private Request prepareGetResultsRequest(TaskId taskId, long token, DataSize maxResponseSize)
333+
{
334+
HttpUrl url = HttpUrl.get(getTaskUri(taskId)).newBuilder()
335+
.addPathSegment("results")
336+
.addPathSegment("0")
337+
.addPathSegment(String.valueOf(token))
338+
.build();
339+
340+
return new Request.Builder()
341+
.url(url)
342+
.get()
343+
.addHeader(PRESTO_MAX_SIZE, maxResponseSize.toString())
344+
.build();
352345
}
353346

354347
private <T> ListenableFuture<T> executeWithRetries(

presto-spark-base/src/main/java/com/facebook/presto/spark/execution/nativeprocess/HttpNativeExecutionTaskInfoFetcher.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import com.facebook.airlift.log.Logger;
1717
import com.facebook.airlift.units.Duration;
18+
import com.facebook.presto.execution.TaskId;
1819
import com.facebook.presto.execution.TaskInfo;
1920
import com.facebook.presto.spark.execution.http.PrestoSparkHttpTaskClient;
2021
import com.google.common.annotations.VisibleForTesting;
@@ -39,6 +40,7 @@ public class HttpNativeExecutionTaskInfoFetcher
3940
{
4041
private static final Logger log = Logger.get(HttpNativeExecutionTaskInfoFetcher.class);
4142

43+
private final TaskId taskId;
4244
private final PrestoSparkHttpTaskClient workerClient;
4345
private final ScheduledExecutorService updateScheduledExecutor;
4446
private final AtomicReference<TaskInfo> taskInfo = new AtomicReference<>();
@@ -50,11 +52,13 @@ public class HttpNativeExecutionTaskInfoFetcher
5052
private ScheduledFuture<?> scheduledFuture;
5153

5254
public HttpNativeExecutionTaskInfoFetcher(
55+
TaskId taskId,
5356
ScheduledExecutorService updateScheduledExecutor,
5457
PrestoSparkHttpTaskClient workerClient,
5558
Duration infoFetchInterval,
5659
Object taskFinished)
5760
{
61+
this.taskId = requireNonNull(taskId, "taskId is null");
5862
this.workerClient = requireNonNull(workerClient, "workerClient is null");
5963
this.updateScheduledExecutor = requireNonNull(updateScheduledExecutor, "updateScheduledExecutor is null");
6064
this.infoFetchInterval = requireNonNull(infoFetchInterval, "infoFetchInterval is null");
@@ -78,7 +82,7 @@ public void stop()
7882
void doGetTaskInfo()
7983
{
8084
try {
81-
TaskInfo result = workerClient.getTaskInfo();
85+
TaskInfo result = workerClient.getTaskInfo(taskId);
8286
onSuccess(result);
8387
}
8488
catch (Throwable t) {

0 commit comments

Comments
 (0)