Keras Gaia handles datasets and models for Keras in simple projects. A simple abstraction is used for the dataset and model to allow easy interchange. The training of the network can be done with a command line tool using the project configuration. The trained network can be used to make predictions with a command line tool or an HTTP service.
The following dependencies are used:
- h5py
- ijson
- isodate
- keras
The project settings are stored in a JSON file with the following structure:
- `label`: Name of the project
- `description`: Description of the project
- `weightsHdf5`: Path to the HDF5 weights file. The file is written after training and read before prediction.
- `dataset`: Dataset options
  - `trainingDataJson`: Path to the JSON training data file
  - `testDataJson`: Path to the JSON test data file
- `model`: Model options
  - `ioNamesJson`: Path to the JSON I/O names definition file
  - `topologyPython`: Path to the Python code that creates the Keras model. The code must contain a `create` function which returns the model.
- `training`: Training options
  - `batchSize`: Batch size
  - `epochs`: Number of epochs
  - `shuffle`: Shuffle the dataset on each epoch
  - `lossLogFile`: Path to the CSV loss log file
  - `testLogFile`: Path to the CSV test log file
  - `testLogInterval`: The epoch interval for the tests
  - `checkpointFile`: File pattern for dumping the weights during training. Use Python string templates for the epoch (e.g. `{epoch:06d}`).
  - `checkpointInterval`: The epoch interval for the weight dumps
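Put together, a project file could look like the following sketch (all paths and values are illustrative, not taken from the package):

```json
{
  "label": "example",
  "description": "Example project",
  "weightsHdf5": "weights.hdf5",
  "dataset": {
    "trainingDataJson": "data/training.json",
    "testDataJson": "data/test.json"
  },
  "model": {
    "ioNamesJson": "io-names.json",
    "topologyPython": "topology.py"
  },
  "training": {
    "batchSize": 32,
    "epochs": 100,
    "shuffle": true,
    "lossLogFile": "loss.csv",
    "testLogFile": "test.csv",
    "testLogInterval": 10,
    "checkpointFile": "checkpoint-{epoch:06d}.hdf5",
    "checkpointInterval": 10
  }
}
```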
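The file referenced by `topologyPython` must expose a `create` function that returns the model. A minimal sketch of such a file (the layer sizes and input shape are placeholders, not the values used by the actual examples):

```python
from keras.models import Sequential
from keras.layers import Input, LSTM, Dense

def create():
    # Build a small LSTM model; shapes and sizes are illustrative only.
    model = Sequential([
        Input(shape=(4, 1)),
        LSTM(10),
        Dense(1),
    ])
    model.compile(loss='mean_squared_error', optimizer='adam')
    return model
```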
The training step generates the weights for the model. This can be done from the command line. The following command runs the training for a project:
```
python train.py [options] <projectFile>
```
- `projectFile`: Path to the JSON project file
- `base`: The base path for all files (optional)
- `resume`: Resume the training at the given epoch (optional)
Predictions can be made based on existing weights. This can be done from the command line. The input is read from a JSON file and the output is written to a JSON file. The following command runs a prediction for a project:
```
python predict.py [options] --input=<inputFile> --output=<outputFile> <projectFile>
```
- `projectFile`: Path to the JSON project file
- `input`: Path to the JSON input file
- `output`: Path to the JSON output file
- `base`: The base path for all files (optional)
It's also possible to use an HTTP interface for the predictions. The input must be sent as a JSON string in a POST request to the endpoint URL, e.g. `http://localhost:8080/` if port 8080 is used. The output is returned to the client as a JSON string.
The following command starts the HTTP prediction server for a project:
```
python predict-http.py [options] --port=<port> <projectFile>
```
- `projectFile`: Path to the JSON project file
- `port`: Port for the HTTP server
- `base`: The base path for all files (optional)
The package comes with a simple calculator example.
Use the JavaScript nn-mapping package to generate the example datasets.
Copy the `data` folder from the examples folder of nn-mapping to `examples/calculator/data`.
It contains two datasets (short and long).
The example defines two different models (lstm10 and lstm30).
The four projects use all combinations of the datasets and models.
The `calculator.sh` bash script runs the training for all projects and makes a prediction.