diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..f01154c
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,9 @@
+FROM tensorflow/serving:2.8.2
+
+# Build-time parameters. Override them with
+#   docker build --build-arg MODEL_DIR=... --build-arg MODEL_NAME=... --build-arg MODEL_VERSION=... .
+# Host environment variables are NOT visible to `docker build`, so plain ENV
+# defaults could never be changed from the command line.
+ARG MODEL_DIR=./models
+ARG MODEL_NAME=cort
+ARG MODEL_VERSION=1
+
+# Re-export the build arguments so they remain set inside the running container.
+# NOTE(review): the base image's entrypoint is expected to read MODEL_NAME to
+# locate the model under /models - confirm against the tensorflow/serving image.
+ENV MODEL_DIR=${MODEL_DIR}
+ENV MODEL_NAME=${MODEL_NAME}
+ENV MODEL_VERSION=${MODEL_VERSION}
+
+# MODEL_DIR must be a path inside the build context.
+COPY ${MODEL_DIR}/${MODEL_NAME}/${MODEL_VERSION} /models/${MODEL_NAME}/${MODEL_VERSION}
+
+# Port published by the `docker run -p 8500:8500` command.
+EXPOSE 8500
diff --git a/README.md b/README.md
index 0d24ac7..835b5c8 100644
--- a/README.md
+++ b/README.md
@@ -158,6 +158,24 @@ It has the following arguments:
Perform inference for metrics by (for example) `python run_inference.py --checkpoint_path ./finetuning-checkpoints/wandb_run_id/ckpt-0 --tfrecord_path ./data/tfrecords/{model_name}/valid.fold-1-of-10.tfrecord --concat_hidden_states 2 --repr_act tanh --repr_classifier bi_lstm --repr_size 1024`.
`--concat_hidden_states`, `--repr_act`, `--repr_classifier`, `--repr_size` must be same with configurations that used for fine-tuned model's architecture.
+### Serving
+
+CoRT supports [TensorFlow Serving](https://www.tensorflow.org/tfx/guide/serving) on Docker. Use `configure_docker_image.py` to export a fine-tuned checkpoint as a SavedModel and to get the commands for building and running the Docker image.
+It has the following arguments:
+- `--checkpoint_path`: Location of trained model checkpoint. (Required)
+- `--saved_model_dir`: Location of SavedModel to be stored. ('./models' as default)
+- `--model_spec_name`: Name of model spec. ('cort' as default)
+- `--model_spec_version`: Version of model spec. ('1' as default)
+- `--signature_name`: Name of signature of SavedModel ('serving_default' as default)
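+- `--model_name`: Name of the pre-trained model. (One of korscibert, korscielectra, or a Hugging Face model name is allowed; 'klue/roberta-base' as default)
+- `--tfrecord_path`: Location of the TFRecord file used to build warmup requests. {model_name} is a placeholder. ('./data/tfrecords/{model_name}/eval.tfrecord' as default)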
+- `--num_warmup_requests`: Number of warmup requests. Pass 0 to skip (10 as default)
+- `--repr_classifier`: Name of classification head for classifier. (One of 'seq_cls' and 'bi_lstm' is allowed; 'seq_cls' as default)
+- `--repr_act`: Name of activation function for representation. (One of 'tanh' and 'gelu' is allowed; 'tanh' as default)
+- `--concat_hidden_states`: Number of hidden states to concatenate. (1 as default)
+- `--repr_size`: Number of representation dense units. (1024 as default)
+- `--num_labels`: Number of labels. (9 as default)
+
### Performance
[LAN (Label Attention Network)](https://aida.kisti.re.kr/gallery/17) has been proposed in [2021 KISTI AI/ML Competition](https://aida.kisti.re.kr/notice/7).
diff --git a/configure_docker_image.py b/configure_docker_image.py
new file mode 100644
index 0000000..db56d7d
--- /dev/null
+++ b/configure_docker_image.py
@@ -0,0 +1,159 @@
+import os
+import logging
+import argparse
+
+import tensorflow as tf
+
+from utils import utils, formatting_utils
+from cort.config import Config
+from cort.modeling import CortForSequenceClassification
+from tensorflow_serving.apis.predict_pb2 import PredictRequest
+from tensorflow_serving.apis.prediction_log_pb2 import PredictionLog, PredictLog
+
+formatting_utils.setup_formatter(logging.INFO)
+
+
+def parse_tfrecords(tfrecord_path, model_name, maxlen, num_samples):
+ feature_desc = {
+ 'input_ids': tf.io.FixedLenFeature([maxlen], tf.int64),
+ 'sections': tf.io.FixedLenFeature([1], tf.int64),
+ 'labels': tf.io.FixedLenFeature([1], tf.int64)
+ }
+
+ def _parse_feature_desc(example_proto):
+ example = tf.io.parse_single_example(example_proto, feature_desc)
+
+ # tf.int64 is acceptable, but tf.int32 has more performance advantages.
+ for name in list(example.keys()):
+ tensor = example[name]
+ if tensor.dtype == tf.int64:
+ tensor = tf.cast(tensor, tf.int32)
+ example[name] = tensor
+ return example
+
+ def _reconfigure_inputs(example):
+ return example['input_ids']
+
+ fname = tfrecord_path.format(model_name=model_name.replace('/', '_'))
+ logging.info('Parsing TFRecords from {}'.format(fname))
+
+ dataset = tf.data.TFRecordDataset(fname)
+ dataset = dataset.map(_parse_feature_desc).map(_reconfigure_inputs)
+ dataset = dataset.shuffle(buffer_size=1024).repeat().batch(num_samples)
+
+ input_ids = None
+ for input_ids in dataset:
+ break
+ return input_ids
+
+
+def store_warmup_requests(args, input_ids, saved_model_path):
+ warmup_request_dir = os.path.join(saved_model_path, 'assets.extra')
+ os.makedirs(warmup_request_dir, exist_ok=True)
+ warmup_request_path = os.path.join(warmup_request_dir, 'tf_serving_warmup_requests')
+
+ with tf.io.TFRecordWriter(warmup_request_path) as writer:
+ input_ids = tf.make_tensor_proto(input_ids)
+
+ request = PredictRequest()
+ request.model_spec.name = args.model_spec_name
+ request.model_spec.signature_name = args.signature_name
+ request.inputs['input_ids'].CopyFrom(input_ids)
+
+ log = PredictionLog(predict_log=PredictLog(request=request))
+ writer.write(log.SerializeToString())
+ logging.info('{} warmup requests have been stored at: {}'.format(args.num_warmup_requests, warmup_request_path))
+
+
+def restore_cort_classifier(args, config: Config):
+ cort_model = CortForSequenceClassification(config, num_labels=config.num_labels)
+ cort_model.trainable = False
+
+ # Restore from checkpoint
+ checkpoint = tf.train.Checkpoint(model=cort_model)
+ checkpoint.restore(args.checkpoint_path).expect_partial()
+
+ serving = CortForSequenceClassification.Serving(config, cort_model)
+ serving(serving.dummy_inputs)
+ logging.info('Restored model checkpoint from: {}'.format(args.checkpoint_path))
+ return serving
+
+
+def store_as_saved_model(cort_model, signature_name, filepath):
+ maxlen = cort_model.config.pretrained_config.max_position_embeddings
+
+ @tf.function(input_signature=[tf.TensorSpec(shape=(None, maxlen), dtype=tf.int32, name='input_ids')])
+ def _eval_wrapper(input_ids):
+ return cort_model(input_ids)
+
+ signatures = _eval_wrapper.get_concrete_function()
+ tf.saved_model.save(cort_model, filepath, signatures={
+ signature_name: signatures
+ })
+ logging.info('Servable CoRT classifier has been written as SavedModel format at: {}'.format(filepath))
+
+
+def main():
+ parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
+ parser.add_argument('--checkpoint_path', required=True,
+ help='Location of trained model checkpoint.')
+ parser.add_argument('--saved_model_dir', default='./models',
+ help='Location of SavedModel to be stored.')
+ parser.add_argument('--model_spec_name', default='cort',
+ help='Name of model spec.')
+ parser.add_argument('--model_spec_version', default='1',
+ help='Version of model spec.')
+ parser.add_argument('--signature_name', default='serving_default',
+ help='Name of signature of SavedModel')
+ parser.add_argument('--model_name', default='klue/roberta-base',
+ help='Name of pre-trained models. (One of korscibert, korscielectra, huggingface models)')
+ parser.add_argument('--tfrecord_path', default='./data/tfrecords/{model_name}/eval.tfrecord',
+ help='Location of TFRecord file for warmup requests. {model_name} is a placeholder.')
+ parser.add_argument('--num_warmup_requests', default=10, type=int,
+ help='Number of warmup requests. Pass 0 to skip')
+ parser.add_argument('--repr_classifier', default='seq_cls',
+ help='Name of classification head for classifier. (One of seq_cls and bi_lstm is allowed)')
+ parser.add_argument('--repr_act', default='tanh',
+ help='Name of activation function for representation. (One of tanh and gelu is allowed)')
+ parser.add_argument('--concat_hidden_states', default=1, type=int,
+ help='Number of hidden states to concatenate.')
+ parser.add_argument('--repr_size', default=1024, type=int,
+ help='Number of representation dense units')
+ parser.add_argument('--num_labels', default=9, type=int,
+ help='Number of labels')
+
+ # Configurable pre-defined variables
+ parser.add_argument('--korscibert_vocab', default='./cort/pretrained/korscibert/vocab_kisti.txt')
+ parser.add_argument('--korscibert_ckpt', default='./cort/pretrained/korscibert/model.ckpt-262500')
+ parser.add_argument('--korscielectra_vocab', default='./cort/pretrained/korscielectra/data/vocab.txt')
+ parser.add_argument('--korscielectra_ckpt', default='./cort/pretrained/korscielectra/data/models/korsci_base')
+ parser.add_argument('--classifier_dropout_prob', default=0.1, type=float)
+
+ # Parser arguments
+ args = parser.parse_args()
+ config = Config(**vars(args))
+ config.pretrained_config = utils.parse_pretrained_config(config)
+ saved_model_path = os.path.join(args.saved_model_dir, args.model_spec_name, args.model_spec_version)
+
+ cort_serving = restore_cort_classifier(args, config)
+
+ store_as_saved_model(cort_serving, args.signature_name, saved_model_path)
+
+ if args.num_warmup_requests > 0:
+ maxlen = config.pretrained_config.max_position_embeddings
+ input_ids = parse_tfrecords(args.tfrecord_path, args.model_name, maxlen, num_samples=args.num_warmup_requests)
+ store_warmup_requests(args, input_ids, saved_model_path)
+
+ logging.info('Finishing all necessary jobs')
+ logging.info('Run following command to build and run Docker container:')
+ logging.info(
+ ' MODEL_DIR={} MODEL_NAME={} MODEL_VERSION={} docker build -t cort/serving:latest .'
+ .format(args.saved_model_dir,
+ args.model_spec_name,
+ args.model_spec_version)
+ )
+ logging.info(' docker run -d -p 8500:8500 --name cort-serving cort/serving')
+
+
+# Run the export pipeline only when this file is executed as a script.
+if __name__ == '__main__':
+    main()
diff --git a/cort/modeling.py b/cort/modeling.py
index 4033ebe..392decf 100644
--- a/cort/modeling.py
+++ b/cort/modeling.py
@@ -377,6 +377,39 @@ def call(self, inputs, training=None, mask=None):
def get_config(self):
return super(CortForSequenceClassification, self).get_config()
+ class Serving(models.Model):
+
+ def __init__(self,
+ config: ConfigLike,
+ cort_model: "CortForSequenceClassification",
+ calc_correlation=True, **kwargs):
+ super(CortForSequenceClassification.Serving, self).__init__(**kwargs)
+ self.config = Config.parse_config(config)
+ self.cort_model = cort_model
+ self.calc_correlation = calc_correlation
+ self.dummy_inputs = tf.zeros(
+ shape=(1, self.config.pretrained_config.max_position_embeddings), dtype=tf.int32
+ )
+
+ def call(self, inputs, training=None, mask=None):
+ _, cort_outputs = self.cort_model(inputs)
+
+ outputs = {
+ 'logits': cort_outputs['logits'],
+ 'probs': cort_outputs['probs']
+ }
+ if self.calc_correlation:
+ attentions = cort_outputs['attentions']
+ attention_maps = []
+ for attention in attentions:
+ reduced = tf.reduce_mean(attention, axis=1)
+ attention_maps.append(reduced)
+
+ reduced_attention = tf.concat(attention_maps, axis=1)
+ reduced_attention = tf.reduce_mean(reduced_attention, axis=1)
+ outputs['correlations'] = reduced_attention
+ return outputs
+
class CortMainLayer(layers.Layer):
diff --git a/requirements.txt b/requirements.txt
index bbcc31d..98e9ab5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -127,6 +127,7 @@ tensorflow==2.10.0
tensorflow-addons==0.18.0
tensorflow-estimator==2.10.0
tensorflow-io-gcs-filesystem==0.27.0
+tensorflow-serving-api==2.10.0
termcolor==2.0.1
terminado==0.15.0
threadpoolctl==3.1.0