diff --git a/src/main/scala/com/johnsnowlabs/nlp/embeddings/ChunkEmbeddings.scala b/src/main/scala/com/johnsnowlabs/nlp/embeddings/ChunkEmbeddings.scala index c374e62ca80dfb..b69f499ad973e2 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/embeddings/ChunkEmbeddings.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/embeddings/ChunkEmbeddings.scala @@ -260,7 +260,7 @@ class ChunkEmbeddings(override val uid: String) begin = chunk.begin, end = chunk.end, result = chunk.result, - metadata = Map( + metadata = chunk.metadata ++ Map( "sentence" -> sentenceIdx.toString, "chunk" -> chunkIdx.toString, "token" -> chunk.result, diff --git a/src/test/scala/com/johnsnowlabs/nlp/AnnotationUtils.scala b/src/test/scala/com/johnsnowlabs/nlp/AnnotationUtils.scala new file mode 100644 index 00000000000000..4cefc51979dc90 --- /dev/null +++ b/src/test/scala/com/johnsnowlabs/nlp/AnnotationUtils.scala @@ -0,0 +1,46 @@ +package com.johnsnowlabs.nlp + +import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT +import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType} +import org.apache.spark.sql.{DataFrame, Row} + +object AnnotationUtils { + + private lazy val spark = SparkAccessor.spark + + implicit class AnnotationRow(annotation: Annotation) { + + def toRow(): Row = { + Row( + annotation.annotatorType, + annotation.begin, + annotation.end, + annotation.result, + annotation.metadata, + annotation.embeddings) + } + } + + implicit class DocumentRow(s: String) { + def toRow(metadata: Map[String, String] = Map("sentence" -> "0")): Row = { + Row(Seq(Annotation(DOCUMENT, 0, s.length, s, metadata).toRow())) + } + } + + /** Create a DataFrame with the given column name, annotator type and annotations row Output + * column will be compatible with the Spark NLP annotators + */ + def createAnnotatorDataframe( + columnName: String, + annotatorType: String, + annotationsRow: Row): DataFrame = { + val metadataBuilder: MetadataBuilder = new MetadataBuilder() + metadataBuilder.putString("annotatorType", annotatorType) + val documentField = + StructField(columnName, Annotation.arrayType, nullable = false, metadataBuilder.build) + val struct = StructType(Array(documentField)) + val rdd = spark.sparkContext.parallelize(Seq(annotationsRow)) + spark.createDataFrame(rdd, struct) + } + +} diff --git a/src/test/scala/com/johnsnowlabs/nlp/embeddings/ChunkEmbeddingsTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/embeddings/ChunkEmbeddingsTestSpec.scala index b3f1a491088a68..7c907b1fb37163 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/embeddings/ChunkEmbeddingsTestSpec.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/embeddings/ChunkEmbeddingsTestSpec.scala @@ -16,14 +16,16 @@ package com.johnsnowlabs.nlp.embeddings +import com.johnsnowlabs.nlp.AnnotatorType.{CHUNK, DOCUMENT} import com.johnsnowlabs.nlp.annotator.{Chunker, PerceptronModel} import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector import com.johnsnowlabs.nlp.annotators.{NGramGenerator, StopWordsCleaner, Tokenizer} import com.johnsnowlabs.nlp.base.DocumentAssembler import com.johnsnowlabs.nlp.util.io.ResourceHelper -import com.johnsnowlabs.nlp.{AnnotatorBuilder, EmbeddingsFinisher, Finisher} +import com.johnsnowlabs.nlp.{Annotation, AnnotatorBuilder, EmbeddingsFinisher, Finisher} import com.johnsnowlabs.tags.FastTest import org.apache.spark.ml.Pipeline +import org.apache.spark.sql.Row import org.scalatest.flatspec.AnyFlatSpec class ChunkEmbeddingsTestSpec extends AnyFlatSpec { @@ -266,4 +268,53 @@ class ChunkEmbeddingsTestSpec extends AnyFlatSpec { } + "ChunkEmbeddings" should "return chunk metadata at output" taggedAs FastTest in { + import com.johnsnowlabs.nlp.AnnotationUtils._ + val document = "Record: Bush Blue, ZIPCODE: XYZ84556222, phone: (911) 45 88".toRow() + + val chunks = Row( + Seq( + Annotation( + CHUNK, + 8, + 16, + "Bush Blue", + Map("entity" -> "NAME", "sentence" -> "0", "chunk" -> "0", "confidence" -> "0.98")) + .toRow(), + Annotation( + CHUNK, + 48, + 58, + "(911) 45 88", + Map("entity" -> "PHONE", "sentence" -> "0", "chunk" -> "1", "confidence" -> "1.0")) + .toRow())) + + val df = createAnnotatorDataframe("sentence", DOCUMENT, document) + .crossJoin(createAnnotatorDataframe("chunk", CHUNK, chunks)) + + val token = new Tokenizer() + .setInputCols("sentence") + .setOutputCol("token") + + val wordEmbeddings = WordEmbeddingsModel + .pretrained() + .setInputCols("sentence", "token") + .setOutputCol("embeddings") + + val chunkEmbeddings = new ChunkEmbeddings() + .setInputCols("chunk", "embeddings") + .setOutputCol("chunk_embeddings") + .setPoolingStrategy("AVERAGE") + + val pipeline = new Pipeline().setStages(Array(token, wordEmbeddings, chunkEmbeddings)) + val result_df = pipeline.fit(df).transform(df) + // result_df.selectExpr("explode(chunk_embeddings) as embeddings").show(false) + val annotations = Annotation.collect(result_df, "chunk_embeddings").flatten + assert(annotations.length == 2) + assert(annotations(0).metadata("entity") == "NAME") + assert(annotations(1).metadata("entity") == "PHONE") + val expectedMetadataKeys = Set("entity", "sentence", "chunk", "confidence") + assert(annotations.forall(anno => expectedMetadataKeys.forall(anno.metadata.contains))) + } + }