Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add chunk metadata to ChunkEmbeddings output #14462

Open
wants to merge 2 commits into
base: release/552-release-candidate
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ class ChunkEmbeddings(override val uid: String)
begin = chunk.begin,
end = chunk.end,
result = chunk.result,
metadata = Map(
metadata = chunk.metadata ++ Map(
"sentence" -> sentenceIdx.toString,
"chunk" -> chunkIdx.toString,
"token" -> chunk.result,
Expand Down
46 changes: 46 additions & 0 deletions src/test/scala/com/johnsnowlabs/nlp/AnnotationUtils.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.johnsnowlabs.nlp

import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT
import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row}

/** Test helpers for building Spark DataFrames that carry Spark NLP annotation
  * columns, without running a DocumentAssembler or other upstream annotators.
  */
object AnnotationUtils {

  // Shared test SparkSession, initialized lazily on first use.
  private lazy val spark = SparkAccessor.spark

  /** Enriches [[Annotation]] with a conversion to a Spark SQL [[Row]] whose
    * field order matches `Annotation.arrayType`.
    */
  implicit class AnnotationRow(annotation: Annotation) {

    def toRow(): Row = {
      // Field order must mirror the Annotation struct schema exactly.
      val fields = Seq(
        annotation.annotatorType,
        annotation.begin,
        annotation.end,
        annotation.result,
        annotation.metadata,
        annotation.embeddings)
      Row.fromSeq(fields)
    }
  }

  /** Enriches [[String]] with a conversion to a document-annotation [[Row]],
    * i.e. a single DOCUMENT annotation covering the whole string.
    */
  implicit class DocumentRow(s: String) {
    def toRow(metadata: Map[String, String] = Map("sentence" -> "0")): Row = {
      // NOTE(review): end is set to s.length here; Spark NLP annotators usually
      // produce an inclusive end of s.length - 1 — confirm this is intentional.
      val documentAnnotation = Annotation(DOCUMENT, 0, s.length, s, metadata)
      Row(Seq(documentAnnotation.toRow()))
    }
  }

  /** Create a DataFrame with the given column name, annotator type and annotations row Output
    * column will be compatible with the Spark NLP annotators
    *
    * @param columnName output column name
    * @param annotatorType annotator type written into the column metadata
    *                      (read by downstream annotators to validate inputs)
    * @param annotationsRow a Row wrapping a Seq of annotation Rows
    * @return single-row DataFrame with one annotation-array column
    */
  def createAnnotatorDataframe(
      columnName: String,
      annotatorType: String,
      annotationsRow: Row): DataFrame = {
    // Downstream annotators look up "annotatorType" in the column metadata.
    val columnMetadata = new MetadataBuilder()
      .putString("annotatorType", annotatorType)
      .build()
    val schema = StructType(
      Array(StructField(columnName, Annotation.arrayType, nullable = false, columnMetadata)))
    val singleRowRdd = spark.sparkContext.parallelize(Seq(annotationsRow))
    spark.createDataFrame(singleRowRdd, schema)
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@

package com.johnsnowlabs.nlp.embeddings

import com.johnsnowlabs.nlp.AnnotatorType.{CHUNK, DOCUMENT}
import com.johnsnowlabs.nlp.annotator.{Chunker, PerceptronModel}
import com.johnsnowlabs.nlp.annotators.sbd.pragmatic.SentenceDetector
import com.johnsnowlabs.nlp.annotators.{NGramGenerator, StopWordsCleaner, Tokenizer}
import com.johnsnowlabs.nlp.base.DocumentAssembler
import com.johnsnowlabs.nlp.util.io.ResourceHelper
import com.johnsnowlabs.nlp.{AnnotatorBuilder, EmbeddingsFinisher, Finisher}
import com.johnsnowlabs.nlp.{Annotation, AnnotatorBuilder, EmbeddingsFinisher, Finisher}
import com.johnsnowlabs.tags.FastTest
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.Row
import org.scalatest.flatspec.AnyFlatSpec

class ChunkEmbeddingsTestSpec extends AnyFlatSpec {
Expand Down Expand Up @@ -266,4 +268,53 @@ class ChunkEmbeddingsTestSpec extends AnyFlatSpec {

}

"ChunkEmbeddings" should "return chunk metadata at output" taggedAs FastTest in {
import com.johnsnowlabs.nlp.AnnotationUtils._
val document = "Record: Bush Blue, ZIPCODE: XYZ84556222, phone: (911) 45 88".toRow()

val chunks = Row(
Seq(
Annotation(
CHUNK,
8,
16,
"Bush Blue",
Map("entity" -> "NAME", "sentence" -> "0", "chunk" -> "0", "confidence" -> "0.98"))
.toRow(),
Annotation(
CHUNK,
48,
58,
"(911) 45 88",
Map("entity" -> "PHONE", "sentence" -> "0", "chunk" -> "1", "confidence" -> "1.0"))
.toRow()))

val df = createAnnotatorDataframe("sentence", DOCUMENT, document)
.crossJoin(createAnnotatorDataframe("chunk", CHUNK, chunks))

val token = new Tokenizer()
.setInputCols("sentence")
.setOutputCol("token")

val wordEmbeddings = WordEmbeddingsModel
.pretrained()
.setInputCols("sentence", "token")
.setOutputCol("embeddings")

val chunkEmbeddings = new ChunkEmbeddings()
.setInputCols("chunk", "embeddings")
.setOutputCol("chunk_embeddings")
.setPoolingStrategy("AVERAGE")

val pipeline = new Pipeline().setStages(Array(token, wordEmbeddings, chunkEmbeddings))
val result_df = pipeline.fit(df).transform(df)
// result_df.selectExpr("explode(chunk_embeddings) as embeddings").show(false)
val annotations = Annotation.collect(result_df, "chunk_embeddings").flatten
assert(annotations.length == 2)
assert(annotations(0).metadata("entity") == "NAME")
assert(annotations(1).metadata("entity") == "PHONE")
val expectedMetadataKeys = Set("entity", "sentence", "chunk", "confidence")
assert(annotations.forall(anno => expectedMetadataKeys.forall(anno.metadata.contains)))
}

}
Loading