
End-to-end example of training a model using PyTorch and using it for inference from the JVM


bzz/pytorch-jvm-onnx


Train a PyTorch model + JVM inference with ONNX

This is an illustrative example of preparing a PyTorch model to be used from a JVM environment.

The Problem and the data

Predict the functional groups of a yeast gene

A multi-label classification problem with structured data in libsvm format.

  • classes: 14
  • features: 103
  • data points: 1,500 (train) / 917 (test)
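In the libsvm multi-label format each line starts with a comma-separated list of labels followed by sparse `index:value` pairs. The parser below is a minimal sketch of reading one such line into a 14-dimensional multi-hot target and a dense 103-dimensional feature vector; it is a hypothetical helper, not code from the repo, and it assumes 0-based labels and 1-based feature indices (the usual libsvm convention).

```python
# Hypothetical helper, not part of the repo. Assumes labels are 0-based
# integers below 14 and feature indices are 1-based, as in standard libsvm.
NUM_CLASSES = 14
NUM_FEATURES = 103


def parse_libsvm_multilabel(line, num_classes=NUM_CLASSES, num_features=NUM_FEATURES):
    """Parse '<l1>,<l2> <idx>:<val> ...' into (multi-hot labels, dense features)."""
    tokens = line.split()
    labels = [0.0] * num_classes
    features = [0.0] * num_features  # absent features stay 0.0 (sparse format)
    # The first token holds comma-separated labels, unless the example has none.
    if tokens and ":" not in tokens[0]:
        for label in tokens[0].split(","):
            labels[int(label)] = 1.0
        tokens = tokens[1:]
    for token in tokens:
        index, value = token.split(":")
        features[int(index) - 1] = float(value)  # shift 1-based index to 0-based
    return labels, features


labels, features = parse_libsvm_multilabel("2,3 1:0.5 103:-0.25")
print([i for i, y in enumerate(labels) if y == 1.0])  # [2, 3]
print(features[0], features[102])                     # 0.5 -0.25
```

Because an example can carry several labels, the target is a multi-hot vector rather than a single class index.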

Exploration

TODO: add a notebook checking the dataset for imbalanced classes.
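Until that notebook exists, a quick imbalance check can be sketched in a few lines: count how many examples carry each label, list labels that never occur, and compare the most and least frequent ones. The function names and the toy data are illustrative assumptions; in practice the label sets would come from the parsed training file.

```python
from collections import Counter


def label_counts(label_sets, num_classes=14):
    """Number of examples carrying each label. Rows can have several labels."""
    counts = Counter(label for labels in label_sets for label in labels)
    return [counts[c] for c in range(num_classes)]


def imbalance_report(label_sets, num_classes=14):
    counts = label_counts(label_sets, num_classes)
    present = [c for c in counts if c > 0]
    return {
        "counts": counts,
        "missing_labels": [i for i, c in enumerate(counts) if c == 0],
        # Ratio of the most to the least frequent label that occurs at least once.
        "imbalance_ratio": max(present) / min(present),
        # Average number of labels per example ("label cardinality").
        "cardinality": sum(len(s) for s in label_sets) / len(label_sets),
    }


# Toy label sets for illustration only.
toy = [[2, 3], [2], [3, 4], [2, 3, 4], [0]]
report = imbalance_report(toy)
print(report["imbalance_ratio"])  # 3.0
print(report["cardinality"])      # 1.8
```

A large imbalance ratio would suggest per-label weighting (for example a positive-class weight in the loss) or per-label decision thresholds.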

Train

PyTorch

# install dependencies
virtualenv .venv
source .venv/bin/activate
pip install -r requirements.txt

# get the data
wget 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel/yeast_train.svm.bz2'
wget 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel/yeast_test.svm.bz2'
bzip2 -d *.bz2

# test the dataloader
./train.py

# train the model
./train.py --model model.pt

This script does not include logging, early stopping, model checkpointing, or many other nice goodies.
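Early stopping, one of the goodies missing above, boils down to a patience counter on a validation metric. PyTorch Lightning ships it as its `EarlyStopping` callback; the class below is only a hypothetical stdlib sketch of the same idea, not code from `train.py`, and the loss values are made up.

```python
class PatienceStopper:
    """Signal a stop once validation loss hasn't improved for `patience` epochs.

    Hypothetical helper for illustration; not part of train.py.
    """

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta  # minimum decrease that counts as improvement
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = PatienceStopper(patience=2)
losses = [0.9, 0.7, 0.65, 0.66, 0.67, 0.5]  # made-up validation losses
for epoch, loss in enumerate(losses):
    if stopper.step(loss):
        print(f"stopping after epoch {epoch}, best loss {stopper.best}")
        break
```

With `patience=2` the loop stops after epoch 4, never seeing the later 0.5; choosing the patience is a trade-off between wasted epochs and stopping on a temporary plateau.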

PyTorch Lightning

PyTorch Lightning, on the other hand, includes all of that and much more for free 🎉

./train_ptl.py --model models/ptl/model.pt

Monitor the progress with TensorBoard

tensorboard --logdir lightning_logs/
open http://localhost:6006

Three epochs at about 400 it/s reach a precision of 0.768, compared to 0.762 reported in the original paper.
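For the yeast benchmark, "precision" usually refers to the ranking-based average precision used in the multi-label literature: for each relevant label, the fraction of labels ranked at or above it that are also relevant, averaged over relevant labels and then over examples. The sketch below assumes that definition; the repo may compute a different metric, and the scores are made up.

```python
def average_precision(scores, relevant):
    """Ranking-based average precision for one example."""
    # Rank labels by decreasing score; rank 1 is the top-scored label.
    order = sorted(range(len(scores)), key=lambda l: scores[l], reverse=True)
    rank = {label: r + 1 for r, label in enumerate(order)}
    total = 0.0
    for l in relevant:
        # Relevant labels ranked at or above label l.
        hits = sum(1 for other in relevant if rank[other] <= rank[l])
        total += hits / rank[l]
    return total / len(relevant)


def mean_average_precision(all_scores, all_relevant):
    """Average over examples, skipping any example with no relevant labels."""
    pairs = [(s, r) for s, r in zip(all_scores, all_relevant) if r]
    return sum(average_precision(s, r) for s, r in pairs) / len(pairs)


scores = [[0.1, 0.9, 0.8, 0.3], [0.9, 0.8, 0.1]]
relevant = [[1, 3], [0, 1]]
print(round(average_precision(scores[0], relevant[0]), 4))  # 0.8333
print(round(mean_average_precision(scores, relevant), 4))   # 0.9167
```

Unlike thresholded precision, this metric depends only on the ordering of the scores, so it needs no decision threshold.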

Predict

PyTorch inference in Python

./predict.py --model model.pt < single_example.txt

The correct answer (ground-truth labels) is 2, 3.
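Getting a label set like "2, 3" out of a multi-label model means turning its 14 per-label scores into a subset. One common approach, assumed here rather than taken from `predict.py`, is an element-wise sigmoid followed by a 0.5 threshold; the optional top-1 fallback and the logits are illustrative.

```python
import math


def stable_sigmoid(z):
    """Logistic function that avoids overflow in math.exp for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def predict_labels(logits, threshold=0.5, at_least_one=False):
    """Map per-label logits to the indices of predicted labels."""
    probs = [stable_sigmoid(z) for z in logits]
    chosen = [i for i, p in enumerate(probs) if p >= threshold]
    if not chosen and at_least_one:
        # Optional fallback: keep the single most likely label.
        chosen = [max(range(len(probs)), key=probs.__getitem__)]
    return chosen


logits = [-3.0] * 14
logits[2], logits[3] = 1.2, 0.4  # made-up scores for illustration
print(predict_labels(logits))  # [2, 3]
```

A threshold of 0.5 on the sigmoid is equivalent to thresholding the raw logits at 0; tuning it per label is a common refinement when classes are imbalanced.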

Libtorch JNI bindings

TODO https://github.com/pytorch/java-demo/blob/master/src/main/java/demo/App.java

ONNX export the model

https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html

Input: the trained model.pt. Output: model.onnx.

./onnx_export.py --model model.pt --out model.onnx

ONNX inference in Python

./onnx_predict.py --model model.onnx < single_example.txt

ONNX inference in JVM

Using the JNI-based Java API of ONNX Runtime

cp model.onnx onnx-predict-java/src/main/resources/
cd onnx-predict-java
./gradlew jar

java -jar ./build/libs/onnx-predict-java.jar < single_example.txt
  • see this for a discussion of JNI and multiple-classloader support
  • the ONNX Runtime dependency is 92 MB

Reduce the model size

Explore different NN architectures

Architecture-neutral optimizations

Model             | Params | On disk | Train time
----------------- | ------ | ------- | ----------
fp32 mlp          |        | 52 kB   |
onnx mlp          |        | 48 kB   |
fp16 mlp          |        | ?       |
8bit mlp          |        | ?       |
fp32 mlp+hyperopt |        | ?       |
fp32 dcn          |        | ?       |
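The fp16 and 8-bit rows can be estimated before running anything: the raw weight size scales with bits per parameter, so halving precision roughly halves the file. The sketch below shows that arithmetic, an fp16 round trip via the stdlib `struct` half-precision format, and symmetric per-tensor int8 quantization. The 13,000-parameter count and the weights are illustrative assumptions, not the repo's model, and real exporters add format overhead.

```python
import struct


def raw_size_kb(num_params, bits_per_param):
    """Size of the weights alone in kB, ignoring file-format overhead."""
    return num_params * bits_per_param / 8 / 1000


def round_trip_fp16(values):
    """Cast floats to IEEE 754 half precision and back (struct format 'e')."""
    packed = struct.pack(f"<{len(values)}e", *values)
    return list(struct.unpack(f"<{len(values)}e", packed))


def quantize_int8(values):
    """Symmetric per-tensor int8 quantization: q = round(x / scale)."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(v / scale))) for v in values]
    return q, scale


def dequantize_int8(q, scale):
    return [x * scale for x in q]


# 13,000 parameters is an illustrative count, not the repo's actual MLP.
for bits in (32, 16, 8):
    print(bits, raw_size_kb(13_000, bits))

weights = [0.5, -1.27, 0.0, 0.01]  # made-up weights
q, scale = quantize_int8(weights)
restored = dequantize_int8(q, scale)
print(q)  # [50, -127, 0, 1]
```

Symmetric quantization keeps 0.0 exactly representable and bounds the per-weight error by half the scale; whether the accuracy loss is acceptable still has to be measured on the test set.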

Interpret the model

How important are individual features? Explain how the model's weights contribute towards its final decision.

  • Primary attribution with integrated gradients for feature importance using https://captum.ai
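Integrated gradients attribute a prediction to input features by averaging the gradient along the straight path from a baseline to the input and scaling by the input difference. Captum provides this as `IntegratedGradients` using autograd; the sketch below instead uses a toy logistic-regression model with a hand-written gradient and a right Riemann sum, so the weights, baseline, and step count are all assumptions. It checks the method's completeness property: attributions sum approximately to f(x) minus f(baseline).

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def model(x, w, b):
    """Toy stand-in for the network: logistic regression."""
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)


def model_grad(x, w, b):
    """Analytic gradient of the toy model with respect to its inputs."""
    p = model(x, w, b)
    return [p * (1.0 - p) * wi for wi in w]


def integrated_gradients(x, baseline, w, b, steps=200):
    """Right Riemann approximation of the path integral of gradients."""
    total = [0.0] * len(x)
    for k in range(1, steps + 1):
        alpha = k / steps
        point = [bi + alpha * (xi - bi) for xi, bi in zip(x, baseline)]
        for i, g in enumerate(model_grad(point, w, b)):
            total[i] += g
    return [(xi - bi) * t / steps for xi, bi, t in zip(x, baseline, total)]


x, baseline = [1.0, -2.0, 0.5], [0.0, 0.0, 0.0]
w, b = [0.8, -0.5, 2.0], 0.1  # made-up parameters
attr = integrated_gradients(x, baseline, w, b)
print([round(a, 3) for a in attr])
print(round(sum(attr), 3), round(model(x, w, b) - model(baseline, w, b), 3))
```

For a real MLP, the all-zeros baseline is a modelling choice; a feature-mean baseline is a common alternative for tabular data.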

Optimizations

  • PTL profiler
  • new PyTorch profiler
  • is model execution time dominated by loading weights from memory or computing the matrix multiplications?
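One way to reason about that last question before profiling is a roofline estimate: compute a layer's arithmetic intensity (FLOPs per byte moved) and compare it with the hardware's balance point (peak FLOP/s divided by memory bandwidth). Below the balance point a layer is memory-bound. The hardware numbers and the 128-unit hidden layer are hypothetical, and the model assumes each tensor crosses the memory bus exactly once.

```python
def dense_layer_intensity(batch, n_in, n_out, bytes_per_value=4):
    """Arithmetic intensity (FLOPs per byte) of y = x @ W + b in fp32.

    Counts one multiply-add as 2 FLOPs and ignores the bias add; assumes the
    weights, bias, inputs and outputs each cross the memory bus exactly once.
    """
    flops = 2 * batch * n_in * n_out
    values_moved = n_in * n_out + n_out + batch * n_in + batch * n_out
    return flops / (values_moved * bytes_per_value)


def classify(intensity, peak_flops, mem_bandwidth):
    """Compare against the machine balance point (FLOPs per byte)."""
    balance = peak_flops / mem_bandwidth
    return "memory-bound" if intensity < balance else "compute-bound"


# Hypothetical hardware: 1 TFLOP/s peak and 50 GB/s memory bandwidth (balance = 20).
PEAK, BANDWIDTH = 1e12, 50e9
for batch in (1, 1024):
    ai = dense_layer_intensity(batch, n_in=103, n_out=128)  # 128 hidden units is made up
    print(batch, round(ai, 2), classify(ai, PEAK, BANDWIDTH))
```

Under these assumptions single-example inference (as in `single_example.txt`) is dominated by loading weights, while large batches amortize the weight traffic and become compute-bound; a profiler should confirm this on real hardware.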
