Skip to content

Commit 2d05539

Browse files
authored
Add default resources support to DynamicWorkflowTask and related classes (#317)
Signed-off-by: Rafael Ribeiro Raposo <[email protected]>
1 parent f034f75 commit 2d05539

File tree

5 files changed

+47
-0
lines changed

5 files changed

+47
-0
lines changed

flytekit-api/src/main/java/org/flyte/api/v1/DynamicWorkflowTask.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,8 @@ public interface DynamicWorkflowTask {
2626
DynamicJobSpec run(Map<String, Literal> inputs);
2727

2828
RetryStrategy getRetries();
29+
30+
default Resources getResources() {
31+
return Resources.builder().build();
32+
}
2933
}

flytekit-java/src/main/java/org/flyte/flytekit/SdkDynamicWorkflowTask.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,4 +93,8 @@ public SdkNode<OutputT> apply(
9393
public int getRetries() {
9494
return 0;
9595
}
96+
97+
public SdkResources getResources() {
98+
return SdkResources.empty();
99+
}
96100
}

flytekit-java/src/main/java/org/flyte/flytekit/SdkDynamicWorkflowTaskRegistrar.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.flyte.api.v1.DynamicWorkflowTaskRegistrar;
3434
import org.flyte.api.v1.Literal;
3535
import org.flyte.api.v1.Node;
36+
import org.flyte.api.v1.Resources;
3637
import org.flyte.api.v1.RetryStrategy;
3738
import org.flyte.api.v1.TaskIdentifier;
3839
import org.flyte.api.v1.TypedInterface;
@@ -112,6 +113,11 @@ public DynamicJobSpec run(Map<String, Literal> inputs) {
112113
public RetryStrategy getRetries() {
113114
return RetryStrategy.builder().retries(sdkDynamicWorkflow.getRetries()).build();
114115
}
116+
117+
@Override
118+
public Resources getResources() {
119+
return sdkDynamicWorkflow.getResources().toIdl();
120+
}
115121
}
116122

117123
/**

flytekit-java/src/test/java/org/flyte/flytekit/SdkDynamicWorkflowTaskRegistrarTest.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import static org.hamcrest.Matchers.hasSize;
2323

2424
import com.google.errorprone.annotations.Var;
25+
import java.util.HashMap;
2526
import java.util.List;
2627
import java.util.Map;
2728
import org.flyte.api.v1.Binding;
@@ -30,6 +31,7 @@
3031
import org.flyte.api.v1.Literal;
3132
import org.flyte.api.v1.OutputReference;
3233
import org.flyte.api.v1.Primitive;
34+
import org.flyte.api.v1.Resources;
3335
import org.flyte.api.v1.Scalar;
3436
import org.flyte.api.v1.TaskIdentifier;
3537
import org.flyte.api.v1.TypedInterface;
@@ -70,6 +72,14 @@ void shouldLoad() {
7072
.inputs(SdkLiteralTypes.integers().asSdkType("n").getVariableMap())
7173
.outputs(SdkLiteralTypes.integers().asSdkType("2n").getVariableMap())
7274
.build()));
75+
assertThat(
76+
dynWf.getResources(),
77+
equalTo(
78+
Resources.builder()
79+
.requests(resources("0.5", "2Gi"))
80+
.limits(resources("2", "5Gi"))
81+
.build()));
82+
;
7383
var spec =
7484
dynWf.run(Map.of("n", Literal.ofScalar(Scalar.ofPrimitive(Primitive.ofIntegerValue(3)))));
7585
assertThat(spec.nodes(), hasSize(3));
@@ -103,6 +113,14 @@ public SdkBindingData<Long> run(SdkWorkflowBuilder builder, SdkBindingData<Long>
103113
}
104114
return x;
105115
}
116+
117+
@Override
118+
public SdkResources getResources() {
119+
return SdkResources.builder()
120+
.requests(sdkResources("0.5", "2Gi"))
121+
.limits(sdkResources("2", "5Gi"))
122+
.build();
123+
}
106124
}
107125

108126
static class Mult2 extends SdkRunnableTask<SdkBindingData<Long>, SdkBindingData<Long>> {
@@ -117,4 +135,18 @@ public SdkBindingData<Long> run(SdkBindingData<Long> input) {
117135
return SdkBindingDataFactory.of(input.get() * 2);
118136
}
119137
}
138+
139+
private static Map<Resources.ResourceName, String> resources(String cpu, String memory) {
140+
Map<Resources.ResourceName, String> limits = new HashMap<>();
141+
limits.put(Resources.ResourceName.CPU, cpu);
142+
limits.put(Resources.ResourceName.MEMORY, memory);
143+
return limits;
144+
}
145+
146+
private static Map<SdkResources.ResourceName, String> sdkResources(String cpu, String memory) {
147+
Map<SdkResources.ResourceName, String> limits = new HashMap<>();
148+
limits.put(SdkResources.ResourceName.CPU, cpu);
149+
limits.put(SdkResources.ResourceName.MEMORY, memory);
150+
return limits;
151+
}
120152
}

jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -577,6 +577,7 @@ private static TaskTemplate createTaskTemplateForDynamicWorkflow(
577577
"{{.taskTemplatePath}}"))
578578
.image(image)
579579
.env(emptyList())
580+
.resources(task.getResources())
580581
.build();
581582

582583
return TaskTemplate.builder()

0 commit comments

Comments
 (0)