Skip to content

Commit

Permalink
fix: handle composite map types (#182)
Browse files Browse the repository at this point in the history
  • Loading branch information
h4rikris authored Jul 25, 2022
1 parent 997191c commit 5540d5f
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import com.google.protobuf.Descriptors;
import com.google.protobuf.DynamicMessage;
import com.google.protobuf.MapEntry;
import com.google.protobuf.WireFormat;
import io.odpf.dagger.common.serde.typehandler.TypeHandler;
import io.odpf.dagger.common.serde.typehandler.RowFactory;
import io.odpf.dagger.common.serde.typehandler.TypeHandler;
import io.odpf.dagger.common.serde.typehandler.TypeHandlerFactory;
import io.odpf.dagger.common.serde.typehandler.TypeInformationFactory;
import io.odpf.dagger.common.serde.typehandler.repeated.RepeatedMessageHandler;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.types.Row;
Expand All @@ -27,6 +27,7 @@
public class MapHandler implements TypeHandler {

private Descriptors.FieldDescriptor fieldDescriptor;
private TypeHandler repeatedMessageHandler;

/**
* Instantiates a new Map proto handler.
Expand All @@ -35,6 +36,7 @@ public class MapHandler implements TypeHandler {
*/
public MapHandler(Descriptors.FieldDescriptor fieldDescriptor) {
this.fieldDescriptor = fieldDescriptor;
this.repeatedMessageHandler = new RepeatedMessageHandler(fieldDescriptor);
}

@Override
Expand All @@ -47,38 +49,44 @@ public DynamicMessage.Builder transformToProtoBuilder(DynamicMessage.Builder bui
if (!canHandle() || field == null) {
return builder;
}

if (field instanceof Map) {
convertFromMap(builder, (Map<String, String>) field);
}

if (field instanceof Object[]) {
convertFromRow(builder, (Object[]) field);
Map<?, ?> mapField = (Map<?, ?>) field;
ArrayList<Row> rows = new ArrayList<>();
for (Entry<?, ?> entry : mapField.entrySet()) {
rows.add(Row.of(entry.getKey(), entry.getValue()));
}
return repeatedMessageHandler.transformToProtoBuilder(builder, rows.toArray());
}

return builder;
return repeatedMessageHandler.transformToProtoBuilder(builder, field);
}

@Override
public Object transformFromPostProcessor(Object field) {
ArrayList<Row> rows = new ArrayList<>();
if (field != null) {
Map<String, String> mapField = (Map<String, String>) field;
for (Entry<String, String> entry : mapField.entrySet()) {
rows.add(getRowFromMap(entry));
if (field == null) {
return rows.toArray();
}
if (field instanceof Map) {
Map<String, ?> mapField = (Map<String, ?>) field;
for (Entry<String, ?> entry : mapField.entrySet()) {
Descriptors.FieldDescriptor keyDescriptor = fieldDescriptor.getMessageType().findFieldByName("key");
Descriptors.FieldDescriptor valueDescriptor = fieldDescriptor.getMessageType().findFieldByName("value");
TypeHandler handler = TypeHandlerFactory.getTypeHandler(keyDescriptor);
Object key = handler.transformFromPostProcessor(entry.getKey());
Object value = TypeHandlerFactory.getTypeHandler(valueDescriptor).transformFromPostProcessor(entry.getValue());
rows.add(Row.of(key, value));
}
return rows.toArray();
}
if (field instanceof List) {
return repeatedMessageHandler.transformFromPostProcessor(field);
}
return rows.toArray();
}

@Override
public Object transformFromProto(Object field) {
ArrayList<Row> rows = new ArrayList<>();
if (field != null) {
List<DynamicMessage> protos = (List<DynamicMessage>) field;
protos.forEach(proto -> rows.add(getRowFromMap(proto)));
}
return rows.toArray();
return repeatedMessageHandler.transformFromProto(field);
}

@Override
Expand Down Expand Up @@ -127,53 +135,4 @@ public Object transformToJson(Object field) {
public TypeInformation getTypeInformation() {
return Types.OBJECT_ARRAY(TypeInformationFactory.getRowType(fieldDescriptor.getMessageType()));
}

private Row getRowFromMap(Entry<String, String> entry) {
Row row = new Row(2);
row.setField(0, entry.getKey());
row.setField(1, entry.getValue());
return row;
}

private Row getRowFromMap(DynamicMessage proto) {
Row row = new Row(2);
row.setField(0, parse(proto, "key"));
row.setField(1, parse(proto, "value"));
return row;
}

private Object parse(DynamicMessage proto, String fieldName) {
Object field = proto.getField(proto.getDescriptorForType().findFieldByName(fieldName));
if (DynamicMessage.class.equals(field.getClass())) {
field = RowFactory.createRow((DynamicMessage) field);
}
return field;
}

private void convertFromRow(DynamicMessage.Builder builder, Object[] field) {
for (Object inputValue : field) {
Row inputRow = (Row) inputValue;
if (inputRow.getArity() != 2) {
throw new IllegalArgumentException("Row: " + inputRow.toString() + " of size: " + inputRow.getArity() + " cannot be converted to map");
}
MapEntry<String, String> mapEntry = MapEntry
.newDefaultInstance(fieldDescriptor.getMessageType(), WireFormat.FieldType.STRING, "", WireFormat.FieldType.STRING, "");
builder.addRepeatedField(fieldDescriptor,
mapEntry.toBuilder()
.setKey((String) inputRow.getField(0))
.setValue((String) inputRow.getField(1))
.buildPartial());
}
}

private void convertFromMap(DynamicMessage.Builder builder, Map<String, String> field) {
for (Entry<String, String> entry : field.entrySet()) {
MapEntry<String, String> mapEntry = MapEntry.newDefaultInstance(fieldDescriptor.getMessageType(), WireFormat.FieldType.STRING, "", WireFormat.FieldType.STRING, "");
builder.addRepeatedField(fieldDescriptor,
mapEntry.toBuilder()
.setKey(entry.getKey())
.setValue(entry.getValue())
.buildPartial());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType;
import org.junit.Assert;
import org.junit.Test;

import java.util.ArrayList;
Expand Down Expand Up @@ -81,13 +80,11 @@ public void shouldSetMapFieldIfStringMapPassed() {
inputMap.put("b", "456");

DynamicMessage.Builder returnedBuilder = mapHandler.transformToProtoBuilder(builder, inputMap);
List<MapEntry> entries = (List<MapEntry>) returnedBuilder.getField(mapFieldDescriptor);
List<DynamicMessage> entries = (List<DynamicMessage>) returnedBuilder.getField(mapFieldDescriptor);

assertEquals(2, entries.size());
assertEquals("a", entries.get(0).getAllFields().values().toArray()[0]);
assertEquals("123", entries.get(0).getAllFields().values().toArray()[1]);
assertEquals("b", entries.get(1).getAllFields().values().toArray()[0]);
assertEquals("456", entries.get(1).getAllFields().values().toArray()[1]);
assertArrayEquals(Arrays.asList("a", "123").toArray(), entries.get(0).getAllFields().values().toArray());
assertArrayEquals(Arrays.asList("b", "456").toArray(), entries.get(1).getAllFields().values().toArray());
}

@Test
Expand All @@ -111,28 +108,29 @@ public void shouldSetMapFieldIfArrayofObjectsHavingRowsWithStringFieldsPassed()
inputRows.add(inputRow2);

DynamicMessage.Builder returnedBuilder = mapHandler.transformToProtoBuilder(builder, inputRows.toArray());
List<MapEntry> entries = (List<MapEntry>) returnedBuilder.getField(mapFieldDescriptor);
List<DynamicMessage> entries = (List<DynamicMessage>) returnedBuilder.getField(mapFieldDescriptor);

assertEquals(2, entries.size());
assertEquals("a", entries.get(0).getAllFields().values().toArray()[0]);
assertEquals("123", entries.get(0).getAllFields().values().toArray()[1]);
assertEquals("b", entries.get(1).getAllFields().values().toArray()[0]);
assertEquals("456", entries.get(1).getAllFields().values().toArray()[1]);
assertArrayEquals(Arrays.asList("a", "123").toArray(), entries.get(0).getAllFields().values().toArray());
assertArrayEquals(Arrays.asList("b", "456").toArray(), entries.get(1).getAllFields().values().toArray());
}

@Test
public void shouldThrowExceptionIfRowsPassedAreNotOfArityTwo() {
Descriptors.FieldDescriptor mapFieldDescriptor = TestBookingLogMessage.getDescriptor().findFieldByName("metadata");
MapHandler mapHandler = new MapHandler(mapFieldDescriptor);
DynamicMessage.Builder builder = DynamicMessage.newBuilder(mapFieldDescriptor.getContainingType());
public void shouldHandleComplexTypeValuesForSerialization() throws InvalidProtocolBufferException {
Row inputValue1 = Row.of("12345", Row.of(Arrays.asList("a", "b")));
Row inputValue2 = Row.of(1234123, Row.of(Arrays.asList("d", "e")));
Object input = Arrays.asList(inputValue1, inputValue2).toArray();

ArrayList<Row> inputRows = new ArrayList<>();
Descriptors.FieldDescriptor intMessageDescriptor = TestComplexMap.getDescriptor().findFieldByName("int_message");
DynamicMessage.Builder builder = DynamicMessage.newBuilder(TestComplexMap.getDescriptor());

Row inputRow = new Row(3);
inputRows.add(inputRow);
IllegalArgumentException exception = Assert.assertThrows(IllegalArgumentException.class,
() -> mapHandler.transformToProtoBuilder(builder, inputRows.toArray()));
assertEquals("Row: +I[null, null, null] of size: 3 cannot be converted to map", exception.getMessage());
byte[] data = new MapHandler(intMessageDescriptor).transformToProtoBuilder(builder, input).build().toByteArray();
TestComplexMap actualMsg = TestComplexMap.parseFrom(data);
assertArrayEquals(Arrays.asList(12345L, 1234123L).toArray(), actualMsg.getIntMessageMap().keySet().toArray());
TestComplexMap.IdMessage idMessage = (TestComplexMap.IdMessage) actualMsg.getIntMessageMap().values().toArray()[0];
assertTrue(idMessage.getIdsList().containsAll(Arrays.asList("a", "b")));
idMessage = (TestComplexMap.IdMessage) actualMsg.getIntMessageMap().values().toArray()[1];
assertTrue(idMessage.getIdsList().containsAll(Arrays.asList("d", "e")));
}

@Test
Expand All @@ -158,12 +156,8 @@ public void shouldReturnArrayOfRowHavingFieldsSetAsInputMapAndOfSizeTwoForTransf

List<Object> outputValues = Arrays.asList((Object[]) mapHandler.transformFromPostProcessor(inputMap));

assertEquals("a", ((Row) outputValues.get(0)).getField(0));
assertEquals("123", ((Row) outputValues.get(0)).getField(1));
assertEquals(2, ((Row) outputValues.get(0)).getArity());
assertEquals("b", ((Row) outputValues.get(1)).getField(0));
assertEquals("456", ((Row) outputValues.get(1)).getField(1));
assertEquals(2, ((Row) outputValues.get(1)).getArity());
assertEquals(Row.of("a", "123"), outputValues.get(0));
assertEquals(Row.of("b", "456"), outputValues.get(1));
}

@Test
Expand Down Expand Up @@ -210,12 +204,8 @@ public void shouldReturnArrayOfRowHavingFieldsSetAsInputMapAndOfSizeTwoForTransf

List<Object> outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor)));

assertEquals("a", ((Row) outputValues.get(0)).getField(0));
assertEquals("123", ((Row) outputValues.get(0)).getField(1));
assertEquals(2, ((Row) outputValues.get(0)).getArity());
assertEquals("b", ((Row) outputValues.get(1)).getField(0));
assertEquals("456", ((Row) outputValues.get(1)).getField(1));
assertEquals(2, ((Row) outputValues.get(1)).getArity());
assertEquals(Row.of("a", "123"), outputValues.get(0));
assertEquals(Row.of("b", "456"), outputValues.get(1));
}

@Test
Expand Down Expand Up @@ -247,16 +237,11 @@ public void shouldReturnArrayOfRowsHavingFieldsSetAsInputMapHavingComplexDataFie

List<Object> outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor)));

assertEquals(1, ((Row) outputValues.get(0)).getField(0));
assertEquals("123", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(0));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(1));
assertEquals("abc", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(2));
assertEquals(2, ((Row) outputValues.get(0)).getArity());
assertEquals(2, ((Row) outputValues.get(1)).getField(0));
assertEquals("456", ((Row) ((Row) outputValues.get(1)).getField(1)).getField(0));
assertEquals("", ((Row) ((Row) outputValues.get(1)).getField(1)).getField(1));
assertEquals("efg", ((Row) ((Row) outputValues.get(1)).getField(1)).getField(2));
assertEquals(2, ((Row) outputValues.get(1)).getArity());
Row mapEntry1 = Row.of(1, Row.of("123", "", "abc"));
Row mapEntry2 = Row.of(2, Row.of("456", "", "efg"));

assertEquals(mapEntry1, outputValues.get(0));
assertEquals(mapEntry2, outputValues.get(1));
}

@Test
Expand All @@ -271,11 +256,8 @@ public void shouldReturnArrayOfRowsHavingFieldsSetAsInputMapHavingComplexDataFie

List<Object> outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor)));

assertEquals(0, ((Row) outputValues.get(0)).getField(0));
assertEquals("123", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(0));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(1));
assertEquals("abc", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(2));
assertEquals(2, ((Row) outputValues.get(0)).getArity());
Row expected = Row.of(0, Row.of("123", "", "abc"));
assertEquals(expected, outputValues.get(0));
}

@Test
Expand All @@ -290,11 +272,9 @@ public void shouldReturnArrayOfRowsHavingFieldsSetAsInputMapHavingComplexDataFie

List<Object> outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor)));

assertEquals(1, ((Row) outputValues.get(0)).getField(0));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(0));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(1));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(2));
assertEquals(2, ((Row) outputValues.get(0)).getArity());
Row expected = Row.of(1, Row.of("", "", ""));

assertEquals(expected, outputValues.get(0));
}

@Test
Expand All @@ -309,11 +289,9 @@ public void shouldReturnArrayOfRowsHavingFieldsSetAsInputMapHavingComplexDataFie

List<Object> outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor)));

assertEquals(0, ((Row) outputValues.get(0)).getField(0));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(0));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(1));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(2));
assertEquals(2, ((Row) outputValues.get(0)).getArity());
Row expected = Row.of(0, Row.of("", "", ""));

assertEquals(expected, outputValues.get(0));
}

@Test
Expand All @@ -328,11 +306,8 @@ public void shouldReturnArrayOfRowsHavingFieldsSetAsInputMapHavingComplexDataFie

List<Object> outputValues = Arrays.asList((Object[]) mapHandler.transformFromProto(dynamicMessage.getField(mapFieldDescriptor)));

assertEquals(0, ((Row) outputValues.get(0)).getField(0));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(0));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(1));
assertEquals("", ((Row) ((Row) outputValues.get(0)).getField(1)).getField(2));
assertEquals(2, ((Row) outputValues.get(0)).getArity());
Row expected = Row.of(0, Row.of("", "", ""));
assertEquals(expected, outputValues.get(0));
}

@Test
Expand Down
5 changes: 5 additions & 0 deletions dagger-common/src/test/proto/TestMessage.proto
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@ message TestEnumMessage {
}

message TestComplexMap {
message IdMessage {
repeated string ids = 1;
}
map<int32, TestMessage> complex_map = 1;
map<int64, IdMessage> int_message = 2;
map<string, IdMessage> string_message = 3;
}

message TestRepeatedPrimitiveMessage {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ public void shouldGetCorrectJsonPayloadForComplexFields() throws InvalidProtocol
DynamicMessage dynamicMessage = DynamicMessage.parseFrom(complexMapMessage.getDescriptor(), complexMapMessage.toByteArray());
RowManager rowManager = getRowManagerForMessage(dynamicMessage);

String expectedJsonPayload = "{\"complex_map\":[{\"key\":1,\"value\":{\"order_number\":\"order-number-123\",\"order_url\":\"https://order-url\",\"order_details\":\"pickup\"}}]}";
String expectedJsonPayload = "{\"complex_map\":[{\"key\":1,\"value\":{\"order_number\":\"order-number-123\",\"order_url\":\"https://order-url\",\"order_details\":\"pickup\"}}],\"int_message\":[],\"string_message\":[]}";
String actualJsonPayload = (String) jsonPayloadFunction.getResult(rowManager);

assertEquals(expectedJsonPayload, actualJsonPayload);
Expand Down

0 comments on commit 5540d5f

Please sign in to comment.