Skip to content

Commit d5a06b4

Browse files
committed
[FLINK-4307] [streaming API] Restore ListState behavior for user-facing ListStates
1 parent 31837a7 commit d5a06b4

File tree

3 files changed

+104
-5
lines changed

3 files changed

+104
-5
lines changed

flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContext.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,8 @@ public <T> ListState<T> getListState(ListStateDescriptor<T> stateProperties) {
141141
requireNonNull(stateProperties, "The state properties must not be null");
142142
try {
143143
stateProperties.initializeSerializerUnlessSet(getExecutionConfig());
144-
return operator.getPartitionedState(stateProperties);
144+
ListState<T> originalState = operator.getPartitionedState(stateProperties);
145+
return new UserFacingListState<T>(originalState);
145146
} catch (Exception e) {
146147
throw new RuntimeException("Error while getting state", e);
147148
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.streaming.api.operators;
20+
21+
import org.apache.flink.api.common.state.ListState;
22+
23+
import java.util.Collections;
24+
25+
/**
26+
* Simple wrapper list state that exposes empty state properly as an empty list.
27+
*
28+
* @param <T> The type of elements in the list state.
29+
*/
30+
class UserFacingListState<T> implements ListState<T> {
31+
32+
private final ListState<T> originalState;
33+
34+
private final Iterable<T> emptyState = Collections.emptyList();
35+
36+
UserFacingListState(ListState<T> originalState) {
37+
this.originalState = originalState;
38+
}
39+
40+
// ------------------------------------------------------------------------
41+
42+
@Override
43+
public Iterable<T> get() throws Exception {
44+
Iterable<T> original = originalState.get();
45+
return original != null ? original : emptyState;
46+
}
47+
48+
@Override
49+
public void add(T value) throws Exception {
50+
originalState.add(value);
51+
}
52+
53+
@Override
54+
public void clear() {
55+
originalState.clear();
56+
}
57+
}

flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StreamingRuntimeContextTest.java

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,19 @@
2222
import org.apache.flink.api.common.TaskInfo;
2323
import org.apache.flink.api.common.accumulators.Accumulator;
2424
import org.apache.flink.api.common.functions.ReduceFunction;
25+
import org.apache.flink.api.common.state.ListState;
2526
import org.apache.flink.api.common.state.ListStateDescriptor;
2627
import org.apache.flink.api.common.state.ReducingStateDescriptor;
2728
import org.apache.flink.api.common.state.StateDescriptor;
2829
import org.apache.flink.api.common.state.ValueStateDescriptor;
2930
import org.apache.flink.api.common.typeutils.TypeSerializer;
31+
import org.apache.flink.api.common.typeutils.base.StringSerializer;
32+
import org.apache.flink.api.common.typeutils.base.VoidSerializer;
3033
import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer;
3134
import org.apache.flink.core.fs.Path;
3235
import org.apache.flink.runtime.execution.Environment;
3336

37+
import org.apache.flink.runtime.state.memory.MemListState;
3438
import org.junit.Test;
3539

3640
import org.mockito.invocation.InvocationOnMock;
@@ -54,7 +58,7 @@ public void testValueStateInstantiation() throws Exception {
5458
final AtomicReference<Object> descriptorCapture = new AtomicReference<>();
5559

5660
StreamingRuntimeContext context = new StreamingRuntimeContext(
57-
createMockOp(descriptorCapture, config),
61+
createDescriptorCapturingMockOp(descriptorCapture, config),
5862
createMockEnvironment(),
5963
Collections.<String, Accumulator<?, ?>>emptyMap());
6064

@@ -78,7 +82,7 @@ public void testReduceingStateInstantiation() throws Exception {
7882
final AtomicReference<Object> descriptorCapture = new AtomicReference<>();
7983

8084
StreamingRuntimeContext context = new StreamingRuntimeContext(
81-
createMockOp(descriptorCapture, config),
85+
createDescriptorCapturingMockOp(descriptorCapture, config),
8286
createMockEnvironment(),
8387
Collections.<String, Accumulator<?, ?>>emptyMap());
8488

@@ -107,7 +111,7 @@ public void testListStateInstantiation() throws Exception {
107111
final AtomicReference<Object> descriptorCapture = new AtomicReference<>();
108112

109113
StreamingRuntimeContext context = new StreamingRuntimeContext(
110-
createMockOp(descriptorCapture, config),
114+
createDescriptorCapturingMockOp(descriptorCapture, config),
111115
createMockEnvironment(),
112116
Collections.<String, Accumulator<?, ?>>emptyMap());
113117

@@ -121,13 +125,29 @@ public void testListStateInstantiation() throws Exception {
121125
assertTrue(serializer instanceof KryoSerializer);
122126
assertTrue(((KryoSerializer<?>) serializer).getKryo().getRegistration(Path.class).getId() > 0);
123127
}
128+
129+
@Test
130+
public void testListStateReturnsEmptyListByDefault() throws Exception {
131+
132+
StreamingRuntimeContext context = new StreamingRuntimeContext(
133+
createPlainMockOp(),
134+
createMockEnvironment(),
135+
Collections.<String, Accumulator<?, ?>>emptyMap());
136+
137+
ListStateDescriptor<String> descr = new ListStateDescriptor<>("name", String.class);
138+
ListState<String> state = context.getListState(descr);
139+
140+
Iterable<String> value = state.get();
141+
assertNotNull(value);
142+
assertFalse(value.iterator().hasNext());
143+
}
124144

125145
// ------------------------------------------------------------------------
126146
//
127147
// ------------------------------------------------------------------------
128148

129149
@SuppressWarnings("unchecked")
130-
private static AbstractStreamOperator<?> createMockOp(
150+
private static AbstractStreamOperator<?> createDescriptorCapturingMockOp(
131151
final AtomicReference<Object> ref, final ExecutionConfig config) throws Exception {
132152

133153
AbstractStreamOperator<?> operatorMock = mock(AbstractStreamOperator.class);
@@ -145,6 +165,27 @@ public Object answer(InvocationOnMock invocationOnMock) throws Throwable {
145165

146166
return operatorMock;
147167
}
168+
169+
@SuppressWarnings("unchecked")
170+
private static AbstractStreamOperator<?> createPlainMockOp() throws Exception {
171+
172+
AbstractStreamOperator<?> operatorMock = mock(AbstractStreamOperator.class);
173+
when(operatorMock.getExecutionConfig()).thenReturn(new ExecutionConfig());
174+
175+
when(operatorMock.getPartitionedState(any(ListStateDescriptor.class))).thenAnswer(
176+
new Answer<ListState<String>>() {
177+
178+
@Override
179+
public ListState<String> answer(InvocationOnMock invocationOnMock) throws Throwable {
180+
ListStateDescriptor<String> descr =
181+
(ListStateDescriptor<String>) invocationOnMock.getArguments()[0];
182+
return new MemListState<String, Void, String>(
183+
StringSerializer.INSTANCE, VoidSerializer.INSTANCE, descr);
184+
}
185+
});
186+
187+
return operatorMock;
188+
}
148189

149190
private static Environment createMockEnvironment() {
150191
Environment env = mock(Environment.class);

0 commit comments

Comments
 (0)