From 988110efd57f01aa5d13ab9293f8cb05ca85d55b Mon Sep 17 00:00:00 2001
From: OrigamiDream
Date: Wed, 2 Nov 2022 13:41:21 +0900
Subject: [PATCH] Add TensorFlow-Serving support and Dockerfile

---
 Dockerfile                |   9 +++
 README.md                 |  18 +++++
 configure_docker_image.py | 159 ++++++++++++++++++++++++++++++++++++++
 cort/modeling.py          |  33 ++++++++
 requirements.txt          |   1 +
 5 files changed, 220 insertions(+)
 create mode 100644 Dockerfile
 create mode 100644 configure_docker_image.py

diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..f01154c
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,9 @@
+FROM tensorflow/serving:2.8.2
+
+ENV MODEL_DIR ./models
+ENV MODEL_NAME cort
+ENV MODEL_VERSION 1
+
+COPY $MODEL_DIR/$MODEL_NAME/$MODEL_VERSION /models/$MODEL_NAME/$MODEL_VERSION
+
+EXPOSE 8500
diff --git a/README.md b/README.md
index 0d24ac7..835b5c8 100644
--- a/README.md
+++ b/README.md
@@ -158,6 +158,24 @@ It has the following arguments:
 Perform inference for metrics by (for example) `python run_inference.py --checkpoint_path ./finetuning-checkpoints/wandb_run_id/ckpt-0 --tfrecord_path ./data/tfrecords/{model_name}/valid.fold-1-of-10.tfrecord --concat_hidden_states 2 --repr_act tanh --repr_classifier bi_lstm --repr_size 1024`.
 `--concat_hidden_states`, `--repr_act`, `--repr_classifier`, `--repr_size` must be same with configurations that used for fine-tuned model's architecture.
 
+### Serving
+
+CoRT supports [TensorFlow Serving](https://www.tensorflow.org/tfx/guide/serving) on Docker. Use `configure_docker_image.py` to export a SavedModel and build a Docker image. The `--repr_classifier`, `--repr_act`, `--concat_hidden_states` and `--repr_size` arguments must match the architecture of the fine-tuned checkpoint.
+It has the following arguments:
+- `--checkpoint_path`: Location of the trained model checkpoint. (Required)
+- `--saved_model_dir`: Location where the SavedModel is stored. (default: './models')
+- `--model_spec_name`: Name of the model spec. (default: 'cort')
+- `--model_spec_version`: Version of the model spec. (default: '1')
+- `--signature_name`: Name of the SavedModel signature. (default: 'serving_default')
+- `--model_name`: Name of the pre-trained model. (One of korscibert, korscielectra or a Hugging Face model name is allowed)
+- `--tfrecord_path`: Location of the TFRecord file used to build warmup requests. `{model_name}` is a placeholder.
+- `--num_warmup_requests`: Number of warmup requests. Pass 0 to skip. (default: 10)
+- `--repr_classifier`: Name of the classification head. (One of 'seq_cls' and 'bi_lstm' is allowed)
+- `--repr_act`: Name of the activation function for the representation. (One of 'tanh' and 'gelu' is allowed)
+- `--concat_hidden_states`: Number of hidden states to concatenate. (default: 1)
+- `--repr_size`: Number of representation dense units. (default: 1024)
+- `--num_labels`: Number of labels. (default: 9)
+
 ### Performance
 
 [LAN (Label Attention Network)](https://aida.kisti.re.kr/gallery/17) has been proposed in [2021 KISTI AI/ML Competition](https://aida.kisti.re.kr/notice/7).
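Once the container is running, the served model can be queried over gRPC. The sketch below is illustrative only and is not part of the patch: the host and port, the sequence length of 512, and the all-zero `input_ids` are assumptions; real requests must carry token ids produced by the same tokenizer and sequence length used at training time, against whatever address `docker run` exposes.

```python
# Minimal gRPC client sketch for the served CoRT model (assumptions: localhost:8500,
# MAX_LEN = 512, pre-tokenized input ids).
import grpc
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc

MAX_LEN = 512  # assumption: must equal max_position_embeddings of the exported model

channel = grpc.insecure_channel('localhost:8500')
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

request = predict_pb2.PredictRequest()
request.model_spec.name = 'cort'                        # --model_spec_name
request.model_spec.signature_name = 'serving_default'   # --signature_name

# Placeholder input: a single all-zero sequence of token ids.
input_ids = tf.zeros((1, MAX_LEN), dtype=tf.int32)
request.inputs['input_ids'].CopyFrom(tf.make_tensor_proto(input_ids))

response = stub.Predict(request, timeout=10.0)
probs = tf.make_ndarray(response.outputs['probs'])                  # (batch, num_labels)
correlations = tf.make_ndarray(response.outputs['correlations'])    # (batch, seq_len)
print(probs.argmax(axis=-1), correlations.shape)
```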
diff --git a/configure_docker_image.py b/configure_docker_image.py
new file mode 100644
index 0000000..db56d7d
--- /dev/null
+++ b/configure_docker_image.py
@@ -0,0 +1,159 @@
+import os
+import logging
+import argparse
+
+import tensorflow as tf
+
+from utils import utils, formatting_utils
+from cort.config import Config
+from cort.modeling import CortForSequenceClassification
+from tensorflow_serving.apis.predict_pb2 import PredictRequest
+from tensorflow_serving.apis.prediction_log_pb2 import PredictionLog, PredictLog
+
+formatting_utils.setup_formatter(logging.INFO)
+
+
+def parse_tfrecords(tfrecord_path, model_name, maxlen, num_samples):
+    feature_desc = {
+        'input_ids': tf.io.FixedLenFeature([maxlen], tf.int64),
+        'sections': tf.io.FixedLenFeature([1], tf.int64),
+        'labels': tf.io.FixedLenFeature([1], tf.int64)
+    }
+
+    def _parse_feature_desc(example_proto):
+        example = tf.io.parse_single_example(example_proto, feature_desc)
+
+        # tf.int64 is accepted, but tf.int32 performs better on most devices.
+        for name in list(example.keys()):
+            tensor = example[name]
+            if tensor.dtype == tf.int64:
+                tensor = tf.cast(tensor, tf.int32)
+            example[name] = tensor
+        return example
+
+    def _reconfigure_inputs(example):
+        return example['input_ids']
+
+    fname = tfrecord_path.format(model_name=model_name.replace('/', '_'))
+    logging.info('Parsing TFRecords from {}'.format(fname))
+
+    dataset = tf.data.TFRecordDataset(fname)
+    dataset = dataset.map(_parse_feature_desc).map(_reconfigure_inputs)
+    dataset = dataset.shuffle(buffer_size=1024).repeat().batch(num_samples)
+
+    input_ids = None
+    for input_ids in dataset:
+        break
+    return input_ids
+
+
+def store_warmup_requests(args, input_ids, saved_model_path):
+    warmup_request_dir = os.path.join(saved_model_path, 'assets.extra')
+    os.makedirs(warmup_request_dir, exist_ok=True)
+    warmup_request_path = os.path.join(warmup_request_dir, 'tf_serving_warmup_requests')
+
+    with tf.io.TFRecordWriter(warmup_request_path) as writer:
+        input_ids = tf.make_tensor_proto(input_ids)
+
+        request = PredictRequest()
+        request.model_spec.name = args.model_spec_name
+        request.model_spec.signature_name = args.signature_name
+        request.inputs['input_ids'].CopyFrom(input_ids)
+
+        log = PredictionLog(predict_log=PredictLog(request=request))
+        writer.write(log.SerializeToString())
+    logging.info('{} warmup requests have been stored at: {}'.format(args.num_warmup_requests, warmup_request_path))
+
+
+def restore_cort_classifier(args, config: Config):
+    cort_model = CortForSequenceClassification(config, num_labels=config.num_labels)
+    cort_model.trainable = False
+
+    # Restore weights from the fine-tuned checkpoint.
+    checkpoint = tf.train.Checkpoint(model=cort_model)
+    checkpoint.restore(args.checkpoint_path).expect_partial()
+
+    serving = CortForSequenceClassification.Serving(config, cort_model)
+    serving(serving.dummy_inputs)
+    logging.info('Restored model checkpoint from: {}'.format(args.checkpoint_path))
+    return serving
+
+
+def store_as_saved_model(cort_model, signature_name, filepath):
+    maxlen = cort_model.config.pretrained_config.max_position_embeddings
+
+    @tf.function(input_signature=[tf.TensorSpec(shape=(None, maxlen), dtype=tf.int32, name='input_ids')])
+    def _eval_wrapper(input_ids):
+        return cort_model(input_ids)
+
+    signatures = _eval_wrapper.get_concrete_function()
+    tf.saved_model.save(cort_model, filepath, signatures={
+        signature_name: signatures
+    })
+    logging.info('Servable CoRT classifier has been written in SavedModel format at: {}'.format(filepath))
+
+
+def main():
+    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
+    parser.add_argument('--checkpoint_path', required=True,
+                        help='Location of the trained model checkpoint.')
+    parser.add_argument('--saved_model_dir', default='./models',
+                        help='Location where the SavedModel is stored.')
+    parser.add_argument('--model_spec_name', default='cort',
+                        help='Name of the model spec.')
+    parser.add_argument('--model_spec_version', default='1',
+                        help='Version of the model spec.')
+    parser.add_argument('--signature_name', default='serving_default',
+                        help='Name of the SavedModel signature.')
+    parser.add_argument('--model_name', default='klue/roberta-base',
+                        help='Name of the pre-trained model. (One of korscibert, korscielectra or a Hugging Face model name)')
+    parser.add_argument('--tfrecord_path', default='./data/tfrecords/{model_name}/eval.tfrecord',
+                        help='Location of the TFRecord file used for warmup requests. {model_name} is a placeholder.')
+    parser.add_argument('--num_warmup_requests', default=10, type=int,
+                        help='Number of warmup requests. Pass 0 to skip.')
+    parser.add_argument('--repr_classifier', default='seq_cls',
+                        help='Name of the classification head. (One of seq_cls and bi_lstm is allowed)')
+    parser.add_argument('--repr_act', default='tanh',
+                        help='Name of the activation function for the representation. (One of tanh and gelu is allowed)')
+    parser.add_argument('--concat_hidden_states', default=1, type=int,
+                        help='Number of hidden states to concatenate.')
+    parser.add_argument('--repr_size', default=1024, type=int,
+                        help='Number of representation dense units.')
+    parser.add_argument('--num_labels', default=9, type=int,
+                        help='Number of labels.')
+
+    # Configurable pre-defined variables
+    parser.add_argument('--korscibert_vocab', default='./cort/pretrained/korscibert/vocab_kisti.txt')
+    parser.add_argument('--korscibert_ckpt', default='./cort/pretrained/korscibert/model.ckpt-262500')
+    parser.add_argument('--korscielectra_vocab', default='./cort/pretrained/korscielectra/data/vocab.txt')
+    parser.add_argument('--korscielectra_ckpt', default='./cort/pretrained/korscielectra/data/models/korsci_base')
+    parser.add_argument('--classifier_dropout_prob', default=0.1, type=float)
+
+    # Parse arguments and build the model configuration.
+    args = parser.parse_args()
+    config = Config(**vars(args))
+    config.pretrained_config = utils.parse_pretrained_config(config)
+    saved_model_path = os.path.join(args.saved_model_dir, args.model_spec_name, args.model_spec_version)
+
+    cort_serving = restore_cort_classifier(args, config)
+
+    store_as_saved_model(cort_serving, args.signature_name, saved_model_path)
+
+    if args.num_warmup_requests > 0:
+        maxlen = config.pretrained_config.max_position_embeddings
+        input_ids = parse_tfrecords(args.tfrecord_path, args.model_name, maxlen, num_samples=args.num_warmup_requests)
+        store_warmup_requests(args, input_ids, saved_model_path)
+
+    logging.info('Finished all necessary jobs')
+    logging.info('Run the following commands to build and run the Docker container:')
+    logging.info(
+        ' MODEL_DIR={} MODEL_NAME={} MODEL_VERSION={} docker build -t cort/serving:latest .'
+        .format(args.saved_model_dir,
+                args.model_spec_name,
+                args.model_spec_version)
+    )
+    logging.info(' docker run -d -p 8500:8500 --name cort-serving cort/serving')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/cort/modeling.py b/cort/modeling.py
index 4033ebe..392decf 100644
--- a/cort/modeling.py
+++ b/cort/modeling.py
@@ -377,6 +377,39 @@ def call(self, inputs, training=None, mask=None):
     def get_config(self):
         return super(CortForSequenceClassification, self).get_config()
 
+    class Serving(models.Model):
+
+        def __init__(self,
+                     config: ConfigLike,
+                     cort_model: "CortForSequenceClassification",
+                     calc_correlation=True, **kwargs):
+            super(CortForSequenceClassification.Serving, self).__init__(**kwargs)
+            self.config = Config.parse_config(config)
+            self.cort_model = cort_model
+            self.calc_correlation = calc_correlation
+            self.dummy_inputs = tf.zeros(
+                shape=(1, self.config.pretrained_config.max_position_embeddings), dtype=tf.int32
+            )
+
+        def call(self, inputs, training=None, mask=None):
+            _, cort_outputs = self.cort_model(inputs)
+
+            outputs = {
+                'logits': cort_outputs['logits'],
+                'probs': cort_outputs['probs']
+            }
+            if self.calc_correlation:
+                attentions = cort_outputs['attentions']
+                attention_maps = []
+                for attention in attentions:
+                    reduced = tf.reduce_mean(attention, axis=1)  # average over attention heads
+                    attention_maps.append(reduced)
+
+                reduced_attention = tf.concat(attention_maps, axis=1)  # stack attention maps of all layers
+                reduced_attention = tf.reduce_mean(reduced_attention, axis=1)  # average attention per token
+                outputs['correlations'] = reduced_attention
+            return outputs
+
 
 class CortMainLayer(layers.Layer):
diff --git a/requirements.txt b/requirements.txt
index bbcc31d..98e9ab5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -127,6 +127,7 @@ tensorflow==2.10.0
 tensorflow-addons==0.18.0
 tensorflow-estimator==2.10.0
 tensorflow-io-gcs-filesystem==0.27.0
+tensorflow-serving-api==2.10.0
 termcolor==2.0.1
 terminado==0.15.0
 threadpoolctl==3.1.0
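Before building the image, it can be worth confirming that the exported signature behaves as expected. The following sketch is illustrative and not part of the patch; the path `./models/cort/1` follows the script defaults, and the sequence length of 512 is an assumption that must match the exported model's `max_position_embeddings`.

```python
# Sanity-check sketch for the exported SavedModel (assumptions: default export
# path ./models/cort/1, MAX_LEN = 512).
import tensorflow as tf

MAX_LEN = 512  # assumption: must match the exported input signature

loaded = tf.saved_model.load('./models/cort/1')
serving_fn = loaded.signatures['serving_default']

# The signature accepts int32 token ids shaped (batch, MAX_LEN) under the name 'input_ids'.
outputs = serving_fn(input_ids=tf.zeros((1, MAX_LEN), dtype=tf.int32))
print({name: tensor.shape for name, tensor in outputs.items()})
# Expected keys: 'logits', 'probs' and, with calc_correlation enabled, 'correlations'.
```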