diff --git a/python/sparknlp/annotator/seq2seq/__init__.py b/python/sparknlp/annotator/seq2seq/__init__.py index 5abf7be0d12dfb..87ff1b42f02821 100644 --- a/python/sparknlp/annotator/seq2seq/__init__.py +++ b/python/sparknlp/annotator/seq2seq/__init__.py @@ -19,3 +19,4 @@ from sparknlp.annotator.seq2seq.bart_transformer import * from sparknlp.annotator.seq2seq.llama2_transformer import * from sparknlp.annotator.seq2seq.m2m100_transformer import * +from sparknlp.annotator.seq2seq.olmo_transformer import * diff --git a/python/sparknlp/annotator/seq2seq/olmo_transformer.py b/python/sparknlp/annotator/seq2seq/olmo_transformer.py new file mode 100644 index 00000000000000..3e0eb2269c3a83 --- /dev/null +++ b/python/sparknlp/annotator/seq2seq/olmo_transformer.py @@ -0,0 +1,326 @@ +# Copyright 2017-2022 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Contains classes for the OLMoTransformer.""" + +from sparknlp.common import * + + +class OLMoTransformer(AnnotatorModel, HasBatchedAnnotate, HasEngine): + """OLMo: Open Language Models + + OLMo is a series of Open Language Models designed to enable the science of language models. + The OLMo models are trained on the Dolma dataset. We release all code, checkpoints, logs + (coming soon), and details involved in training these models. + + Pretrained models can be loaded with :meth:`.pretrained` of the companion + object: + + >>> olmo = OLMoTransformer.pretrained() \\ + ... .setInputCols(["document"]) \\ + ... .setOutputCol("generation") + + + The default model is ``"llam2-7b"``, if no name is provided. For available + pretrained models please see the `Models Hub + `__. + + ====================== ====================== + Input Annotation types Output Annotation type + ====================== ====================== + ``DOCUMENT`` ``DOCUMENT`` + ====================== ====================== + + Parameters + ---------- + configProtoBytes + ConfigProto from tensorflow, serialized into byte array. + minOutputLength + Minimum length of the sequence to be generated, by default 0 + maxOutputLength + Maximum length of output text, by default 20 + doSample + Whether or not to use sampling; use greedy decoding otherwise, by default False + temperature + The value used to module the next token probabilities, by default 1.0 + topK + The number of highest probability vocabulary tokens to keep for + top-k-filtering, by default 50 + topP + Top cumulative probability for vocabulary tokens, by default 1.0 + + If set to float < 1, only the most probable tokens with probabilities + that add up to ``topP`` or higher are kept for generation. + repetitionPenalty + The parameter for repetition penalty, 1.0 means no penalty. , by default + 1.0 + noRepeatNgramSize + If set to int > 0, all ngrams of that size can only occur once, by + default 0 + ignoreTokenIds + A list of token ids which are ignored in the decoder's output, by + default [] + + Notes + ----- + This is a very computationally expensive module especially on larger + sequence. The use of an accelerator such as GPU is recommended. + + References + ---------- + - `OLMo Project Page. + `__ + - `OLMO GitHub Repository. + `__ + - `OLMo: Accelerating the Science of Language Models + `__ + + **Paper Abstract:** + + *Language models (LMs) have become ubiquitous in both NLP research and in commercial product offerings. + As their commercial importance has surged, the most powerful models have become closed off, gated behind + proprietary interfaces, with important details of their training data, architectures, and development + undisclosed. Given the importance of these details in scientifically studying these models, including + their biases and potential risks, we believe it is essential for the research community to have access + to powerful, truly open LMs. To this end, this technical report details the first release of OLMo, + a state-of-the-art, truly Open Language Model and its framework to build and study the science of + language modeling. Unlike most prior efforts that have only released model weights and inference code, + we release OLMo and the whole framework, including training data and training and evaluation code. + We hope this release will empower and strengthen the open research community and inspire a new wave + of innovation.* + + Examples + -------- + >>> import sparknlp + >>> from sparknlp.base import * + >>> from sparknlp.annotator import * + >>> from pyspark.ml import Pipeline + >>> documentAssembler = DocumentAssembler() \\ + ... .setInputCol("text") \\ + ... .setOutputCol("documents") + >>> olmo = OLMoTransformer.pretrained("olmo-7b") \\ + ... .setInputCols(["documents"]) \\ + ... .setMaxOutputLength(50) \\ + ... .setOutputCol("generation") + >>> pipeline = Pipeline().setStages([documentAssembler, olmo]) + >>> data = spark.createDataFrame([["My name is Leonardo."]]).toDF("text") + >>> result = pipeline.fit(data).transform(data) + >>> result.select("summaries.generation").show(truncate=False) + +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + |result | + +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + |[My name is Leonardo . I am a student of the University of California, Berkeley. I am interested in the field of Artificial Intelligence and its applications in the real world. I have a strong | + | passion for learning and am always looking for ways to improve my knowledge and skills] | + -----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + """ + + name = "OLMoTransformer" + + inputAnnotatorTypes = [AnnotatorType.DOCUMENT] + + outputAnnotatorType = AnnotatorType.DOCUMENT + + configProtoBytes = Param(Params._dummy(), "configProtoBytes", + "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()", + TypeConverters.toListInt) + + minOutputLength = Param(Params._dummy(), "minOutputLength", "Minimum length of the sequence to be generated", + typeConverter=TypeConverters.toInt) + + maxOutputLength = Param(Params._dummy(), "maxOutputLength", "Maximum length of output text", + typeConverter=TypeConverters.toInt) + + doSample = Param(Params._dummy(), "doSample", "Whether or not to use sampling; use greedy decoding otherwise", + typeConverter=TypeConverters.toBoolean) + + temperature = Param(Params._dummy(), "temperature", "The value used to module the next token probabilities", + typeConverter=TypeConverters.toFloat) + + topK = Param(Params._dummy(), "topK", + "The number of highest probability vocabulary tokens to keep for top-k-filtering", + typeConverter=TypeConverters.toInt) + + topP = Param(Params._dummy(), "topP", + "If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or higher are kept for generation", + typeConverter=TypeConverters.toFloat) + + repetitionPenalty = Param(Params._dummy(), "repetitionPenalty", + "The parameter for repetition penalty. 1.0 means no penalty. See `this paper `__ for more details", + typeConverter=TypeConverters.toFloat) + + noRepeatNgramSize = Param(Params._dummy(), "noRepeatNgramSize", + "If set to int > 0, all ngrams of that size can only occur once", + typeConverter=TypeConverters.toInt) + + ignoreTokenIds = Param(Params._dummy(), "ignoreTokenIds", + "A list of token ids which are ignored in the decoder's output", + typeConverter=TypeConverters.toListInt) + + def setIgnoreTokenIds(self, value): + """A list of token ids which are ignored in the decoder's output. + + Parameters + ---------- + value : List[int] + The words to be filtered out + """ + return self._set(ignoreTokenIds=value) + + def setConfigProtoBytes(self, b): + """Sets configProto from tensorflow, serialized into byte array. + + Parameters + ---------- + b : List[int] + ConfigProto from tensorflow, serialized into byte array + """ + return self._set(configProtoBytes=b) + + def setMinOutputLength(self, value): + """Sets minimum length of the sequence to be generated. + + Parameters + ---------- + value : int + Minimum length of the sequence to be generated + """ + return self._set(minOutputLength=value) + + def setMaxOutputLength(self, value): + """Sets maximum length of output text. + + Parameters + ---------- + value : int + Maximum length of output text + """ + return self._set(maxOutputLength=value) + + def setDoSample(self, value): + """Sets whether or not to use sampling, use greedy decoding otherwise. + + Parameters + ---------- + value : bool + Whether or not to use sampling; use greedy decoding otherwise + """ + return self._set(doSample=value) + + def setTemperature(self, value): + """Sets the value used to module the next token probabilities. + + Parameters + ---------- + value : float + The value used to module the next token probabilities + """ + return self._set(temperature=value) + + def setTopK(self, value): + """Sets the number of highest probability vocabulary tokens to keep for + top-k-filtering. + + Parameters + ---------- + value : int + Number of highest probability vocabulary tokens to keep + """ + return self._set(topK=value) + + def setTopP(self, value): + """Sets the top cumulative probability for vocabulary tokens. + + If set to float < 1, only the most probable tokens with probabilities + that add up to ``topP`` or higher are kept for generation. + + Parameters + ---------- + value : float + Cumulative probability for vocabulary tokens + """ + return self._set(topP=value) + + def setRepetitionPenalty(self, value): + """Sets the parameter for repetition penalty. 1.0 means no penalty. + + Parameters + ---------- + value : float + The repetition penalty + + References + ---------- + See `Ctrl: A Conditional Transformer Language Model For Controllable + Generation `__ for more details. + """ + return self._set(repetitionPenalty=value) + + def setNoRepeatNgramSize(self, value): + """Sets size of n-grams that can only occur once. + + If set to int > 0, all ngrams of that size can only occur once. + + Parameters + ---------- + value : int + N-gram size can only occur once + """ + return self._set(noRepeatNgramSize=value) + + @keyword_only + def __init__(self, classname="com.johnsnowlabs.nlp.annotators.seq2seq.OLMoTransformer", java_model=None): + super(OLMoTransformer, self).__init__(classname=classname, java_model=java_model) + self._setDefault(minOutputLength=0, maxOutputLength=20, doSample=False, temperature=0.6, topK=50, topP=0.9, + repetitionPenalty=1.0, noRepeatNgramSize=0, ignoreTokenIds=[], batchSize=1) + + @staticmethod + def loadSavedModel(folder, spark_session): + """Loads a locally saved model. + + Parameters + ---------- + folder : str + Folder of the saved model + spark_session : pyspark.sql.SparkSession + The current SparkSession + + Returns + ------- + OLMoTransformer + The restored model + """ + from sparknlp.internal import _OLMoLoader + jModel = _OLMoLoader(folder, spark_session._jsparkSession)._java_obj + return OLMoTransformer(java_model=jModel) + + @staticmethod + def pretrained(name="olmo-1b", lang="en", remote_loc=None): + """Downloads and loads a pretrained model. + + Parameters + ---------- + name : str, optional + Name of the pretrained model, by default "olmo-7b" + lang : str, optional + Language of the pretrained model, by default "en" + remote_loc : str, optional + Optional remote address of the resource, by default None. Will use + Spark NLPs repositories otherwise. + + Returns + ------- + OLMoTransformer + The restored model + """ + from sparknlp.pretrained import ResourceDownloader + return ResourceDownloader.downloadModel(OLMoTransformer, name, lang, remote_loc) diff --git a/python/sparknlp/internal/__init__.py b/python/sparknlp/internal/__init__.py index 54180480bdce63..48135df0ddab39 100644 --- a/python/sparknlp/internal/__init__.py +++ b/python/sparknlp/internal/__init__.py @@ -210,7 +210,10 @@ def __init__(self, path, jspark): super(_MPNetLoader, self).__init__( "com.johnsnowlabs.nlp.embeddings.MPNetEmbeddings.loadSavedModel", path, jspark) - +class _OLMoLoader(ExtendedJavaWrapper): + def __init__(self, path, jspark): + super(_OLMoLoader, self).__init__( + "com.johnsnowlabs.nlp.annotators.seq2seq.OLMoTransformer.loadSavedModel", path, jspark) class _RoBertaLoader(ExtendedJavaWrapper): def __init__(self, path, jspark): super(_RoBertaLoader, self).__init__("com.johnsnowlabs.nlp.embeddings.RoBertaEmbeddings.loadSavedModel", path, diff --git a/python/test/annotator/seq2seq/olmo_transformer_test.py b/python/test/annotator/seq2seq/olmo_transformer_test.py new file mode 100644 index 00000000000000..8c09b3cfa2e4cf --- /dev/null +++ b/python/test/annotator/seq2seq/olmo_transformer_test.py @@ -0,0 +1,47 @@ +# Copyright 2017-2024 John Snow Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import pytest + +from sparknlp.annotator import * +from sparknlp.base import * +from test.util import SparkContextForTest + + +@pytest.mark.slow +class OLMoTransformerTextGenerationTestSpec(unittest.TestCase): + def setUp(self): + self.spark = SparkContextForTest.spark + + def runTest(self): + data = self.spark.createDataFrame([ + [1, """Leonardo Da Vinci invented the microscope?""".strip().replace("\n", " ")]]).toDF("id", "text") + + document_assembler = DocumentAssembler() \ + .setInputCol("text") \ + .setOutputCol("documents") + + olmo = OLMoTransformer \ + .pretrained() \ + .setMaxOutputLength(50) \ + .setDoSample(False) \ + .setInputCols(["documents"]) \ + .setOutputCol("generation") + + pipeline = Pipeline().setStages([document_assembler, olmo]) + results = pipeline.fit(data).transform(data) + + results.select("generation.result").show(truncate=False) + diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/OLMo.scala b/src/main/scala/com/johnsnowlabs/ml/ai/OLMo.scala new file mode 100644 index 00000000000000..fd9e330abf8690 --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/ml/ai/OLMo.scala @@ -0,0 +1,361 @@ +/* + * Copyright 2017 - 2023 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.ml.ai + +import ai.onnxruntime.{OnnxTensor, OrtEnvironment, OrtSession} +import com.johnsnowlabs.ml.ai.util.Generation.{Generate, GenerationConfig} +import com.johnsnowlabs.ml.onnx.OnnxSession +import com.johnsnowlabs.ml.onnx.OnnxWrapper.DecoderWrappers +import com.johnsnowlabs.ml.onnx.TensorResources.implicits._ +import com.johnsnowlabs.ml.tensorflow.sentencepiece.SentencePieceWrapper +import com.johnsnowlabs.nlp.Annotation +import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT +import com.johnsnowlabs.nlp.annotators.common.SentenceSplit +import com.johnsnowlabs.nlp.annotators.tokenizer.bpe.{BpeTokenizer, OLMoTokenizer} +import org.tensorflow.{Session, Tensor} + +import scala.collection.JavaConverters._ + +private[johnsnowlabs] class OLMo( + val onnxWrappers: DecoderWrappers, + merges: Map[(String, String), Int], + vocabulary: Map[String, Int], + generationConfig: GenerationConfig) + extends Serializable + with Generate { + + private val onnxSessionOptions: Map[String, String] = new OnnxSession().getSessionOptions + val bpeTokenizer: OLMoTokenizer = BpeTokenizer + .forModel("olmo", merges = merges, vocab = vocabulary, padWithSequenceTokens = false) + .asInstanceOf[OLMoTokenizer] + private val GenerationConfig( + bosTokenId: Int, + paddingTokenId: Int, + eosTokenId: Int, + vocabSize: Int, + beginSuppressTokens, + suppressTokenIds, + forcedDecoderIds) = + generationConfig + + /** Decode a sequence of sentences + * @param sentences + * Sequence of sentences + * @return + * Sequence of decoded sentences + */ + def decode(sentences: Array[Array[Int]]): Seq[String] = { + sentences.map(s => bpeTokenizer.decodeTokens(s.map(_.toInt))) + } + + /** Encode a sequence of sentences + * @param sentences + * Sequence of sentences + * @return + * Sequence of encoded sentences + */ + def encode(sentences: Seq[Annotation]): Seq[Array[Int]] = { + SentenceSplit + .unpack(sentences) + .map(s => { + val sentWithTask = s + bpeTokenizer + .tokenize(sentWithTask) + .map(bpeTokenizer.encode) + .flatMap(_.map(_.pieceId)) + }) + } + + def tag( + batch: Seq[Array[Int]], + minOutputLength: Int, + maxOutputLength: Int, + doSample: Boolean, + temperature: Double, + topK: Int, + topP: Double, + repetitionPenalty: Double, + noRepeatNgramSize: Int, + randomSeed: Option[Long], + ignoreTokenIds: Array[Int] = Array(), + beamSize: Int, + maxInputLength: Int): Array[Array[Int]] = { + val (encoderSession, env) = onnxWrappers.decoder.getSession(onnxSessionOptions) + val ignoreTokenIdsInt = ignoreTokenIds + val expandedDecoderInputsVals = batch + val sequencesLength = expandedDecoderInputsVals.map(x => x.length).toArray + val maxSentenceLength = sequencesLength.max // - curLen + + val numReturn_sequences = 1 + // from config + + var effectiveBatch_size = 1 + var effectiveBatch_mult = 1 + + if (doSample) { + effectiveBatch_size = expandedDecoderInputsVals.length * numReturn_sequences + effectiveBatch_mult = numReturn_sequences + } else { + effectiveBatch_size = expandedDecoderInputsVals.length + effectiveBatch_mult = 1 + } + + // Run the prompt through the decoder and get the past +// val decoderOutputs = +// generateGreedyOnnx( +// expandedDecoderInputsVals.toArray, +// (encoderSession, env), +// maxOutputLength) + + // dummy tensors for decoder encode state and attention mask + val decoderEncoderStateTensors = Right(OnnxTensor.createTensor(env, Array(0))) + val encoderAttentionMaskTensors = Right(OnnxTensor.createTensor(env, Array(1))) + + // output with beam search + val modelOutputs = generate( + batch, + decoderEncoderStateTensors, + encoderAttentionMaskTensors, + expandedDecoderInputsVals.toArray, + maxOutputLength + maxSentenceLength, + minOutputLength, + doSample, + beamSize, + 1, + temperature, + topK, + topP, + repetitionPenalty, + noRepeatNgramSize, + this.vocabSize, + this.eosTokenId, + this.paddingTokenId, + randomSeed, + ignoreTokenIdsInt, + Right((env, encoderSession)), + applySoftmax = false) + +// decoderOutputs + modelOutputs + } + + def predict( + sentences: Seq[Annotation], + batchSize: Int, + minOutputLength: Int, + maxOutputLength: Int, + doSample: Boolean, + temperature: Double, + topK: Int, + topP: Double, + repetitionPenalty: Double, + noRepeatNgramSize: Int, + randomSeed: Option[Long] = None, + ignoreTokenIds: Array[Int] = Array(), + beamSize: Int, + maxInputLength: Int): Seq[Annotation] = { + + val batchDecoder = sentences.grouped(batchSize).toArray.flatMap { batch => + val batchSP = encode(batch) + val spIds = tag( + batchSP, + minOutputLength, + maxOutputLength, + doSample, + temperature, + topK, + topP, + repetitionPenalty, + noRepeatNgramSize, + randomSeed, + ignoreTokenIds, + beamSize, + maxInputLength) + + decode(spIds) + + } + + var sentBegin, nextSentEnd = 0 + val annotations = batchDecoder.zip(sentences).map { case (content, sent) => + nextSentEnd += content.length - 1 + val annots = new Annotation( + annotatorType = DOCUMENT, + begin = sentBegin, + end = nextSentEnd, + result = content, + metadata = sent.metadata) + sentBegin += nextSentEnd + 1 + annots + } + annotations + } + + private def getDecoderOutputsWithPast( + inputIds: Array[Array[Int]], + decoderPast: Map[String, OnnxTensor], + onnxSession: (OrtSession, OrtEnvironment)) + : (Array[Array[Float]], Map[String, OnnxTensor]) = { + val (session, env) = onnxSession + + val lastTokens: Array[Array[Long]] = + inputIds.map { tokenIds => + Array(tokenIds.last.toLong) + } + + val lastTokensTensor: OnnxTensor = + OnnxTensor.createTensor(env, lastTokens) + val decoderAttentionMask: OnnxTensor = + OnnxTensor.createTensor(env, lastTokens.map(_.map(_ => 1L))) + val decoderWithPastInputs: java.util.Map[String, OnnxTensor] = (Map( + OnnxSignatures.decoderInputIDs -> lastTokensTensor, + OnnxSignatures.decoderAttentionMask -> decoderAttentionMask) ++ decoderPast).asJava + val sessionOutput = session.run(decoderWithPastInputs) + val logits = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput) + val decoderPresent = sessionOutput.getOnnxTensors(OnnxSignatures.decoderPresent) + lastTokensTensor.close() + val batchLogits = logits.grouped(vocabSize).toArray + (batchLogits, decoderPresent) + + } + + override def getModelOutput( + encoderInputIds: Seq[Array[Int]], + decoderInputIds: Seq[Array[Int]], + decoderEncoderStateTensors: Either[Tensor, OnnxTensor], + encoderAttentionMaskTensors: Either[Tensor, OnnxTensor], + maxLength: Int, + session: Either[Session, (OrtEnvironment, OrtSession)]): Array[Array[Float]] = { + + session.fold( + tfSession => { + // not implemented yet + Array() + }, + onnxSession => { + val (env, decoderSession) = onnxSession + val decoderOutputs = + getDecoderOutputs(decoderInputIds.toArray, onnxSession = (decoderSession, env)) + decoderOutputs + }) + + } + private def getDecoderOutputs( + inputIds: Array[Array[Int]], + onnxSession: (OrtSession, OrtEnvironment)): (Array[Array[Float]]) = { + val (session, env) = onnxSession + + val inputIdsLong: Array[Array[Long]] = + inputIds.map { tokenIds => tokenIds.map(_.toLong) } + + val inputPositionIDsLong: Array[Array[Long]] = + inputIds.map { tokenIds => + tokenIds.zipWithIndex.map { case (_, i) => + i.toLong + } + } + + val inputIdsLongTensor: OnnxTensor = + OnnxTensor.createTensor(env, inputIdsLong) + val decoderAttentionMask: OnnxTensor = + OnnxTensor.createTensor(env, inputIdsLong.map(_.map(_ => 1L))) + val decoderPositionIDs: OnnxTensor = + OnnxTensor.createTensor(env, inputPositionIDsLong) + + val decoderInputs: java.util.Map[String, OnnxTensor] = Map( + OnnxSignatures.decoderInputIDs -> inputIdsLongTensor, + OnnxSignatures.decoderAttentionMask -> decoderAttentionMask, + OnnxSignatures.decoderPositionIDs -> decoderPositionIDs).asJava + val sessionOutput = session.run(decoderInputs) + + val sequenceLength = inputIds.head.length + val batchSize = inputIds.length + +// val logits = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput) +// inputIdsLongTensor.close() +// decoderPositionIDs.close() +// decoderAttentionMask.close() +// val batchLogits = logits.grouped(vocabSize).toArray +// batchLogits + + val logitsRaw = sessionOutput.getFloatArray(OnnxSignatures.decoderOutput) + val decoderOutputs = (0 until batchSize).map(i => { + logitsRaw + .slice( + i * sequenceLength * vocabSize + (sequenceLength - 1) * vocabSize, + i * sequenceLength * vocabSize + sequenceLength * vocabSize) + }) + decoderOutputs.toArray + } + + /** Gets the index with the highest score + * + * @param scores + * Array of Scores to max + * @return + * Index of the highest score + */ + private def argmax(scores: Array[Float]): Int = + scores.zipWithIndex.maxBy { case (score, _) => + score + }._2 + private def greedyGenerationFinished( + decoderIds: Seq[Array[Int]], + eosTokenId: Int, + maxOutputLength: Int): Boolean = + decoderIds.map(_.last).forall(_ == eosTokenId) || decoderIds.head.length == maxOutputLength + + private def generateGreedyOnnx( + inputIds: Array[Array[Int]], + onnxSession: (OrtSession, OrtEnvironment), + maxOutputLength: Int): (Array[Array[Int]]) = { + + val sequencesLength = inputIds.map(x => x.length).toArray + val maxSentenceLength = sequencesLength.max // - curLen + var generatedIds: Array[Array[Int]] = inputIds + while (!greedyGenerationFinished( + generatedIds, + eosTokenId, + maxOutputLength + maxSentenceLength)) { + + val (batchLogits: Array[Array[Float]]) = + Array(getDecoderOutputs(generatedIds, onnxSession).last) + + val nextTokenIds: Array[Int] = batchLogits.map(argmax) + generatedIds = + generatedIds.zip(nextTokenIds).map { case (currentIds: Array[Int], nextId: Int) => + currentIds ++ Array(nextId) + } + } + generatedIds + } + + private object OnnxSignatures { + val decoderInputIDs: String = "input_ids" + val decoderAttentionMask: String = "attention_mask" + val decoderPositionIDs: String = "position_ids" + + // create decoder past for 32 layers of key and value eg. past_key_values.0.key and past_key_values.0.value + val decoderPast: Array[String] = (0 until 32) + .flatMap(i => Seq(s"past_key_values.$i.key", s"past_key_values.$i.value")) + .toArray + val decoderOutput: String = "logits" + val decoderPresent: Array[String] = + (0 until 32).flatMap(i => Seq(s"present.$i.key", s"present.$i.value")).toArray + } + +} diff --git a/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Generate.scala b/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Generate.scala index ee96819081fd3d..d4b31bcd361660 100644 --- a/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Generate.scala +++ b/src/main/scala/com/johnsnowlabs/ml/ai/util/Generation/Generate.scala @@ -301,7 +301,7 @@ trait Generate { beamIndices(beamIdx(elem)) :+ beamIdx(elem) } currentLength = currentLength + 1 - if (beamScorer.isDone || (expandedInputs.head.length >= maxLength)) { + if (beamScorer.isDone || (expandedInputs.head.length > maxLength)) { break } diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/OLMoTransformer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/OLMoTransformer.scala new file mode 100644 index 00000000000000..17ee68463d18b0 --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/seq2seq/OLMoTransformer.scala @@ -0,0 +1,427 @@ +/* + * Copyright 2017-2024 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.nlp.annotators.seq2seq + +import com.johnsnowlabs.ml.ai.util.Generation.GenerationConfig +import com.johnsnowlabs.ml.ai.OLMo +import com.johnsnowlabs.ml.onnx.OnnxWrapper.DecoderWrappers +import com.johnsnowlabs.ml.onnx.{OnnxWrapper, ReadOnnxModel, WriteOnnxModel} +import com.johnsnowlabs.ml.util.LoadExternalModel.{ + loadJsonStringAsset, + loadSentencePieceAsset, + loadTextAsset, + modelSanityCheck, + notSupportedEngineError +} +import com.johnsnowlabs.ml.util.ONNX +import com.johnsnowlabs.nlp.AnnotatorType.DOCUMENT +import com.johnsnowlabs.nlp._ +import com.johnsnowlabs.ml.tensorflow.sentencepiece.{ + ReadSentencePieceModel, + SentencePieceWrapper, + WriteSentencePieceModel +} +import com.johnsnowlabs.nlp.serialization.MapFeature +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.SparkSession +import com.johnsnowlabs.nlp.serialization.{MapFeature, StructFeature} +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +/** OLMo: Open Language Models + * + * OLMo is a series of Open Language Models designed to enable the science of language models. + * The OLMo models are trained on the Dolma dataset. + * + * Pretrained models can be loaded with `pretrained` of the companion object: + * {{{ + * val OLMo = OLMoTransformer.pretrained() + * .setInputCols("document") + * .setOutputCol("generation") + * }}} + * The default model is `"OLMo-1b"`, if no name is provided. For available pretrained models + * please see the [[https://sparknlp.org/models?q=OLMo Models Hub]]. + * + * For extended examples of usage, see + * [[https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/OLMoTestSpec.scala OLMoTestSpec]]. + * + * '''References:''' + * - [[https://allenai.org/olmo OLMo Project Page.]] + * - [[https://github.com/allenai/OLMo OLMo GitHub Repository.]] + * - [[https://arxiv.org/pdf/2402.00838.pdf OLMo: Accelerating the Science of Language Models]] + * + * '''Paper Abstract:''' + * + * ''Language models (LMs) have become ubiquitous in both NLP research and in commercial product + * offerings. As their commercial importance has surged, the most powerful models have become + * closed off, gated behind proprietary interfaces, with important details of their training + * data, architectures, and development undisclosed. Given the importance of these details in + * scientifically studying these models, including their biases and potential risks, we believe + * it is essential for the research community to have access to powerful, truly open LMs. To this + * end, this technical report details the first release of OLMo, a state-of-the-art, truly Open + * Language Model and its framework to build and study the science of language modeling. Unlike + * most prior efforts that have only released model weights and inference code, we release OLMo + * and the whole framework, including training data and training and evaluation code. We hope + * this release will empower and strengthen the open research community and inspire a new wave of + * innovation.'' + * + * '''Note:''' + * + * This is a very computationally expensive module especially on larger sequence. The use of an + * accelerator such as GPU is recommended. + * + * ==Example== + * {{{ + * import spark.implicits._ + * import com.johnsnowlabs.nlp.base.DocumentAssembler + * import com.johnsnowlabs.nlp.annotators.seq2seq.OLMoTransformer + * import org.apache.spark.ml.Pipeline + * + * val documentAssembler = new DocumentAssembler() + * .setInputCol("text") + * .setOutputCol("documents") + * + * val OLMo = OLMoTransformer.pretrained("OLMo-7b") + * .setInputCols(Array("documents")) + * .setMinOutputLength(10) + * .setMaxOutputLength(50) + * .setDoSample(false) + * .setTopK(50) + * .setNoRepeatNgramSize(3) + * .setOutputCol("generation") + * + * val pipeline = new Pipeline().setStages(Array(documentAssembler, OLMo)) + * + * val data = Seq( + * "My name is Leonardo." + * ).toDF("text") + * val result = pipeline.fit(data).transform(data) + * + * results.select("generation.result").show(truncate = false) + * +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + * |result | + * +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + * |[ My name is Leonardo . I am a student of the University of California, Berkeley. I am interested in the field of Artificial Intelligence and its applications in the real world. I have a strong | + * | passion for learning and am always looking for ways to improve my knowledge and skills] | + * +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ + * }}} + * + * @param uid + * required uid for storing annotator to disk + * @groupname anno Annotator types + * @groupdesc anno + * Required input and expected output annotator types + * @groupname Ungrouped Members + * @groupname param Parameters + * @groupname setParam Parameter setters + * @groupname getParam Parameter getters + * @groupname Ungrouped Members + * @groupprio param 1 + * @groupprio anno 2 + * @groupprio Ungrouped 3 + * @groupprio setParam 4 + * @groupprio getParam 5 + * @groupdesc param + * A list of (hyper-)parameter keys this annotator can take. Users can set and get the + * parameter values through setters and getters, respectively. + */ +class OLMoTransformer(override val uid: String) + extends AnnotatorModel[OLMoTransformer] + with HasBatchedAnnotate[OLMoTransformer] + with ParamsAndFeaturesWritable + with WriteOnnxModel + with HasGeneratorProperties + with HasEngine { + + def this() = this(Identifiable.randomUID("OLMoTRANSFORMER")) + + /** Input annotator type : DOCUMENT + * + * @group param + */ + override val inputAnnotatorTypes: Array[AnnotatorType] = Array(DOCUMENT) + + /** Output annotator type : DOCUMENT + * + * @group param + */ + override val outputAnnotatorType: String = DOCUMENT + + /** @group setParam */ + def setRandomSeed(value: Int): OLMoTransformer.this.type = { + if (randomSeed.isEmpty) { + this.randomSeed = Some(value) + } + this + } + + /** A list of token ids which are ignored in the decoder's output (Default: `Array()`) + * + * @group param + */ + var ignoreTokenIds = new IntArrayParam( + this, + "ignoreTokenIds", + "A list of token ids which are ignored in the decoder's output") + + /** @group setParam */ + def setIgnoreTokenIds(tokenIds: Array[Int]): OLMoTransformer.this.type = { + set(ignoreTokenIds, tokenIds) + } + + /** @group getParam */ + def getIgnoreTokenIds: Array[Int] = $(ignoreTokenIds) + + /** Vocabulary used to encode the words to ids with bpeTokenizer.encode + * + * @group param + */ + val vocabulary: MapFeature[String, Int] = new MapFeature(this, "vocabulary").setProtected() + + /** @group setParam */ + def setVocabulary(value: Map[String, Int]): this.type = set(vocabulary, value) + + /** Holding merges.txt coming from RoBERTa model + * + * @group param + */ + val merges: MapFeature[(String, String), Int] = new MapFeature(this, "merges").setProtected() + + /** @group setParam */ + def setMerges(value: Map[(String, String), Int]): this.type = set(merges, value) + + private var _model: Option[Broadcast[OLMo]] = None + + val generationConfig: StructFeature[GenerationConfig] = + new StructFeature(this, "generationConfig").setProtected() + + def setGenerationConfig(value: GenerationConfig): this.type = + set(generationConfig, value) + + def getGenerationConfig: GenerationConfig = $$(generationConfig) + + /** @group setParam */ + def setModelIfNotSet(spark: SparkSession, onnxWrappers: DecoderWrappers): this.type = { + if (_model.isEmpty) { + _model = Some( + spark.sparkContext.broadcast( + new OLMo( + onnxWrappers, + $$(merges), + $$(vocabulary), + generationConfig = getGenerationConfig))) + } + this + } + + /** @group getParam */ + def getModelIfNotSet: OLMo = _model.get.value + + setDefault( + minOutputLength -> 0, + maxOutputLength -> 20, + doSample -> false, + temperature -> 0.6, + topK -> 50, + topP -> 0.9, + repetitionPenalty -> 1.0, + noRepeatNgramSize -> 3, + ignoreTokenIds -> Array(), + batchSize -> 1, + beamSize -> 1, + maxInputLength -> 4096) + + /** takes a document and annotations and produces new annotations of this annotator's annotation + * type + * + * @param batchedAnnotations + * Annotations that correspond to inputAnnotationCols generated by previous annotators if any + * @return + * any number of annotations processed for every input annotation. Not necessary one to one + * relationship + */ + override def batchAnnotate(batchedAnnotations: Seq[Array[Annotation]]): Seq[Seq[Annotation]] = { + + val allAnnotations = batchedAnnotations + .filter(_.nonEmpty) + .zipWithIndex + .flatMap { case (annotations, i) => + annotations.filter(_.result.nonEmpty).map(x => (x, i)) + } + val processedAnnotations = if (allAnnotations.nonEmpty) { + this.getModelIfNotSet.predict( + sentences = allAnnotations.map(_._1), + batchSize = $(batchSize), + minOutputLength = $(minOutputLength), + maxOutputLength = $(maxOutputLength), + doSample = $(doSample), + temperature = $(temperature), + topK = $(topK), + topP = $(topP), + repetitionPenalty = $(repetitionPenalty), + noRepeatNgramSize = $(noRepeatNgramSize), + randomSeed = this.randomSeed, + ignoreTokenIds = $(ignoreTokenIds), + beamSize = $(beamSize), + maxInputLength = $(maxInputLength)) + } else { + Seq() + } + Seq(processedAnnotations) + } + + override def onWrite(path: String, spark: SparkSession): Unit = { + super.onWrite(path, spark) + getEngine match { + case ONNX.name => + val wrappers = getModelIfNotSet.onnxWrappers + writeOnnxModels( + path, + spark, + Seq((wrappers.decoder, "decoder_model.onnx")), + OLMoTransformer.suffix) + } + } +} + +trait ReadablePretrainedOLMoTransformerModel + extends ParamsAndFeaturesReadable[OLMoTransformer] + with HasPretrained[OLMoTransformer] { + override val defaultModelName: Some[String] = Some("OLMo-7b") + + /** Java compliant-overrides */ + override def pretrained(): OLMoTransformer = super.pretrained() + + override def pretrained(name: String): OLMoTransformer = super.pretrained(name) + + override def pretrained(name: String, lang: String): OLMoTransformer = + super.pretrained(name, lang) + + override def pretrained(name: String, lang: String, remoteLoc: String): OLMoTransformer = + super.pretrained(name, lang, remoteLoc) +} + +trait ReadOLMoTransformerDLModel extends ReadOnnxModel { + this: ParamsAndFeaturesReadable[OLMoTransformer] => + + override val onnxFile: String = "olmo_onnx" + val suffix: String = "_olmo" + + def readModel(instance: OLMoTransformer, path: String, spark: SparkSession): Unit = { + instance.getEngine match { + case ONNX.name => + val wrappers = + readOnnxModels(path, spark, Seq("decoder_model.onnx"), suffix) + val onnxWrappers = + DecoderWrappers(decoder = wrappers("decoder_model.onnx")) + instance.setModelIfNotSet(spark, onnxWrappers) + case _ => + throw new Exception(notSupportedEngineError) + } + } + + addReader(readModel) + + def loadSavedModel(modelPath: String, spark: SparkSession): OLMoTransformer = { + implicit val formats: DefaultFormats.type = DefaultFormats // for json4 + val (localModelPath, detectedEngine) = + modelSanityCheck(modelPath, isDecoder = true) + val modelConfig: JValue = + parse(loadJsonStringAsset(localModelPath, "config.json")) + + val beginSuppressTokens: Array[Int] = + (modelConfig \ "begin_suppress_tokens").extract[Array[Int]] + + val suppressTokenIds: Array[Int] = + (modelConfig \ "suppress_tokens").extract[Array[Int]] + + val forcedDecoderIds: Array[(Int, Int)] = + (modelConfig \ "forced_decoder_ids").extract[Array[Array[Int]]].map { + case idxWithTokenId: Array[Int] if idxWithTokenId.length == 2 => + (idxWithTokenId(0), idxWithTokenId(1)) + case _ => + throw new Exception( + "Could not extract forced_decoder_ids. Should be a list of tuples with 2 entries.") + } + + def arrayOrNone[T](array: Array[T]): Option[Array[T]] = + if (array.nonEmpty) Some(array) else None + + var bosTokenId = -1 + try { + bosTokenId = (modelConfig \ "bos_token_id").extract[Int] + } catch { + case _: Exception => + println("Could not extract bos_token_id from config.json, assigning default value -1") + } + val eosTokenId = (modelConfig \ "eos_token_id").extract[Int] + val padTokenId = (modelConfig \ "eos_token_id").extract[Int] + val vocabSize = (modelConfig \ "vocab_size").extract[Int] + + val vocabs = loadTextAsset(localModelPath, "vocab.txt").zipWithIndex.toMap + + val bytePairs = loadTextAsset(localModelPath, "merges.txt") + .map(_.split(" ")) + .filter(w => w.length == 2) + .map { case Array(c1, c2) => (c1, c2) } + .zipWithIndex + .toMap + + val annotatorModel = new OLMoTransformer() + .setGenerationConfig( + GenerationConfig( + bosTokenId, + padTokenId, + eosTokenId, + vocabSize, + arrayOrNone(beginSuppressTokens), + arrayOrNone(suppressTokenIds), + arrayOrNone(forcedDecoderIds))) + .setVocabulary(vocabs) + .setMerges(bytePairs) + + annotatorModel.set(annotatorModel.engine, detectedEngine) + + detectedEngine match { + case ONNX.name => + val onnxWrapperDecoder = + OnnxWrapper.read( + localModelPath, + zipped = false, + useBundle = true, + modelName = "decoder_model", + dataFileSuffix = ".onnx_data") + + val onnxWrappers = DecoderWrappers(onnxWrapperDecoder) + + annotatorModel + .setModelIfNotSet(spark, onnxWrappers) + + case _ => + throw new Exception(notSupportedEngineError) + } + + annotatorModel + } + +} + +object OLMoTransformer + extends ReadablePretrainedOLMoTransformerModel + with ReadOLMoTransformerDLModel diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeSpecialTokens.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeSpecialTokens.scala index 691ca8522dddb5..0d6e6a4cec88d3 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeSpecialTokens.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeSpecialTokens.scala @@ -137,6 +137,14 @@ private[johnsnowlabs] object SpecialTokens { unkTokenString = "", maskTokenString = "", padTokenString = "") + case "olmo" => + SpecialTokens( + vocab, + startTokenString = "<|endoftext|>", + endTokenString = "<|endoftext|>", + unkTokenString = "<|endoftext|>", + maskTokenString = "<|endoftext|>", + padTokenString = "<|padding|>") case "clip" => SpecialTokens( vocab, diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeTokenizer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeTokenizer.scala index c948661b0e039b..d701b4ad57bdda 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeTokenizer.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/BpeTokenizer.scala @@ -352,6 +352,13 @@ object BpeTokenizer { modelSpecialTokens(), padWithSequenceTokens, addPrefixSpaceToSentence = addPrefixSpaceToSentence) + case "olmo" => + new OLMoTokenizer( + merges, + vocab, + modelSpecialTokens(), + padWithSequenceTokens, + addPrefixSpaceToSentence = addPrefixSpaceToSentence) case "clip" => new CLIPTokenizer(merges, vocab, modelSpecialTokens()) case _ => diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/OLMoTokenizer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/OLMoTokenizer.scala new file mode 100644 index 00000000000000..95f046f5913670 --- /dev/null +++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/OLMoTokenizer.scala @@ -0,0 +1,31 @@ +/* + * Copyright 2017-2023 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.nlp.annotators.tokenizer.bpe + +class OLMoTokenizer( + merges: Map[(String, String), Int], + vocab: Map[String, Int], + specialTokens: SpecialTokens, + padWithSequenceTokens: Boolean = false, + addPrefixSpaceToSentence: Boolean = false) + extends Gpt2Tokenizer( + merges, + vocab, + specialTokens, + padWithSequenceTokens, + prependString = "Ġ", + addPrefixSpaceToSentence) diff --git a/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/OLMoTestSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/OLMoTestSpec.scala new file mode 100644 index 00000000000000..a1a1ee5365b47d --- /dev/null +++ b/src/test/scala/com/johnsnowlabs/nlp/annotators/seq2seq/OLMoTestSpec.scala @@ -0,0 +1,52 @@ +/* + * Copyright 2017-2023 John Snow Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.johnsnowlabs.nlp.annotators.seq2seq + +import com.johnsnowlabs.nlp.base.DocumentAssembler +import com.johnsnowlabs.nlp.util.io.ResourceHelper +import com.johnsnowlabs.tags.{FastTest, SlowTest} +import org.apache.spark.ml.Pipeline +import org.scalatest.flatspec.AnyFlatSpec + +class OLMoTestSpec extends AnyFlatSpec { + + "olmo" should "should handle temperature=0 correctly and not crash when predicting more than 1 element with doSample=True" taggedAs FastTest in { + // Even tough the Paper states temperature in interval [0,1), using temperature=0 will result in division by 0 error. + // Also DoSample=True may result in infinities being generated and distFiltered.length==0 which results in exception if we don't return 0 instead internally. + val testData = ResourceHelper.spark + .createDataFrame(Seq((1, "My name is Leonardo."))) + .toDF("id", "text") + .repartition(1) + val documentAssembler = new DocumentAssembler() + .setInputCol("text") + .setOutputCol("documents") + + val bart = OLMoTransformer + .pretrained() + .setInputCols(Array("documents")) + .setDoSample(false) + .setMaxOutputLength(100) + .setOutputCol("generation") + .setBeamSize(1) + new Pipeline() + .setStages(Array(documentAssembler, bart)) + .fit(testData) + .transform(testData) + .show(truncate = false) + + } +}