improve(persistence): transactional behavior
- make prediction service participate in engine's transaction if present
- enforce some constraints in database tables
- perform database type check when creating tables
- create tables only if not yet present
ThorbenLindhauer committed Sep 23, 2016
1 parent dc38a51 commit a8d1f3b
Showing 17 changed files with 304 additions and 143 deletions.
24 changes: 22 additions & 2 deletions README.md
@@ -1,3 +1,6 @@
TODO:
* h2 only

# Camunda CMMN Prediction

This project extends the Camunda CMMN engine with capabilities to make predictions over case instance variables and case instance activity execution. It can be used for recommending tasks to case workers in the context of a case instance. Probability distributions are learned over time whenever a case instance is closed. The formalism used is [Bayesian Networks](https://en.wikipedia.org/wiki/Bayesian_network) with multinomial distributions (aka table-based distributions).
@@ -132,14 +135,31 @@ Map<String, Double> probabilities = predictionService.estimate(model, targetVari

## How It Works

Bayesian networks describe a probability distribution over all network variables and how this distribution factorizes. In our example CMMN case, we describe the joint distribution `P(PlanItem_Estimate_Value, PlanItem_Test_Drive, price, boot_size, doors)`. The dependencies encode that this distribution factorizes as `P(PlanItem_Estimate_Value, PlanItem_Test_Drive, price, boot_size, doors) = P(PlanItem_Estimate_Value | price) * P(PlanItem_Test_Drive | boot_size, doors) * P(price | boot_size) * P(boot_size) * P(doors)`. Assuming we know the family and parameters of these distributions, we can compute marginal probabilities like `P(PlanItem_Estimate_Value = true)` or `P(PlanItem_Estimate_Value = true | boot_size = 'large')`. These kinds of computations, called *inference*, are implemented in the [graphical-models](https://github.com/ThorbenLindhauer/graphical-models) library that this extension is based on.
Bayesian networks describe a probability distribution over all network variables and how this distribution factorizes. In our example CMMN case, we describe the following joint distribution:

```
P(PlanItem_Estimate_Value, PlanItem_Test_Drive, price, boot_size, doors)
```

The dependencies encode that this distribution factorizes as:

```
P(PlanItem_Estimate_Value, PlanItem_Test_Drive, price, boot_size, doors) =
P(PlanItem_Estimate_Value | price)
* P(PlanItem_Test_Drive | boot_size, doors)
* P(price | boot_size)
* P(boot_size)
* P(doors)
```

Assuming we know the family and parameters of these distributions, we can compute marginal probabilities like `P(PlanItem_Estimate_Value = true)` or `P(PlanItem_Estimate_Value = true | boot_size = 'large')`. These kinds of computations, called *inference*, are implemented in the [graphical-models](https://github.com/ThorbenLindhauer/graphical-models) library that this extension is based on.
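
To make the idea of inference concrete, here is a minimal hand-rolled sketch in Java. It does not use the graphical-models library, and every probability in it is a made-up illustrative number, not something learned by this extension. It assumes binary `boot_size` (small/large) and `price` (low/high). Because `PlanItem_Test_Drive` and `doors` are not ancestors of `PlanItem_Estimate_Value`, their factors sum to 1 and drop out of both marginals.

```java
// Illustrative sketch only: the class name and all numbers are assumptions.
class FactorizationSketch {

  // P(boot_size); index 0 = "small", index 1 = "large".
  static final double[] P_BOOT_SIZE = {0.6, 0.4};

  // P(price = "high" | boot_size), indexed by boot_size.
  static final double[] P_PRICE_HIGH_GIVEN_BOOT = {0.2, 0.7};

  // P(PlanItem_Estimate_Value = true | price); index 0 = "low", index 1 = "high".
  static final double[] P_ESTIMATE_GIVEN_PRICE = {0.3, 0.8};

  // P(PlanItem_Estimate_Value = true | boot_size): sums out price.
  static double estimateGivenBootSize(int bootSize) {
    double pHigh = P_PRICE_HIGH_GIVEN_BOOT[bootSize];
    double pLow = 1.0 - pHigh;
    return pLow * P_ESTIMATE_GIVEN_PRICE[0] + pHigh * P_ESTIMATE_GIVEN_PRICE[1];
  }

  // P(PlanItem_Estimate_Value = true): additionally sums out boot_size.
  static double estimateMarginal() {
    double sum = 0.0;
    for (int b = 0; b < P_BOOT_SIZE.length; b++) {
      sum += P_BOOT_SIZE[b] * estimateGivenBootSize(b);
    }
    return sum;
  }

  public static void main(String[] args) {
    System.out.println("P(estimate = true) = " + estimateMarginal());
    System.out.println("P(estimate = true | boot_size = large) = " + estimateGivenBootSize(1));
  }
}
```

With these numbers the two quantities come out at roughly 0.5 and 0.65. Enumerating variables by hand like this only works for tiny networks; a general inference library avoids the exponential blow-up where the network structure allows it.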

In this application of Bayesian networks, we choose the involved distributions to be [multinomial](https://en.wikipedia.org/wiki/Multinomial_distribution), i.e. discrete distributions that describe a repeated experiment with `n` different outcomes. We can learn the parameters of the individual distributions that the joint distribution factorizes into. Since we don't know these parameters for sure, we can define another distribution over the parameters. For multinomial distributions, this distribution is typically a [Dirichlet distribution](https://en.wikipedia.org/wiki/Dirichlet_distribution) due to its nice computational properties. A Dirichlet has parameters itself and these can be interpreted as how often each outcome of the multinomial was previously observed. When making a prediction, we choose a single set of multinomial parameters by taking the Dirichlet's expectation. Whenever a case instance finishes, we update the Dirichlet's parameters according to the observed outcomes. The current implementation requires that all variables be observed. Again, this learning component is part of the graphical-models library.
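
The update-then-predict cycle described above can be sketched in a few lines of Java. This is not the graphical-models implementation; the class `DirichletSketch` and its methods are hypothetical names chosen for illustration. It treats the Dirichlet parameters as pseudo-counts, adds 1 to the count of each observed outcome, and predicts with the expectation `alpha[i] / sum(alpha)`.

```java
import java.util.Arrays;

// Illustrative sketch only: a Dirichlet over the parameters of one multinomial.
class DirichletSketch {

  // One pseudo-count per outcome of the multinomial.
  private final double[] alpha;

  DirichletSketch(double[] priorCounts) {
    this.alpha = priorCounts.clone();
  }

  // Learning step: a closed case instance showed this outcome once.
  void observe(int outcome) {
    alpha[outcome] += 1.0;
  }

  // Prediction step: the Dirichlet's expectation, used as multinomial parameters.
  double[] expectation() {
    double total = Arrays.stream(alpha).sum();
    double[] result = new double[alpha.length];
    for (int i = 0; i < alpha.length; i++) {
      result[i] = alpha[i] / total;
    }
    return result;
  }

  public static void main(String[] args) {
    // Uniform prior over {performed, not performed}.
    DirichletSketch d = new DirichletSketch(new double[] {1.0, 1.0});
    d.observe(0);
    d.observe(0);
    d.observe(0);
    d.observe(1);
    // Pseudo-counts are now {4, 2}, so the expectation is about {0.667, 0.333}.
    System.out.println(Arrays.toString(d.expectation()));
  }
}
```

In a network, one such Dirichlet would be kept per variable and per combination of parent values, which is why every variable has to be observed for this simple counting update to apply.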


## What Is Possible But Not Implemented?

* Providing **prior distributions** along with the Bayesian network. A prior encodes the belief about probability distributions before seeing any data. For example, if a domain expert knows under which circumstances an activity is typically performed, this can be encoded in the deployed model and the predictions become more accurate in the beginning with little observed case instance.
* Providing **prior distributions** along with the Bayesian network. A prior encodes the belief about probability distributions before seeing any data. For example, if a domain expert knows under which circumstances an activity is typically performed, this can be encoded in the deployed model and the predictions become more accurate in the beginning with only few observed case instances.
* **Hybrid networks** of multinomial distributions and conditional linear Gaussian distributions (i.e. Gaussian distributions linearly dependent on other Gaussian or multinomial distributed variables). Haven't looked into the math yet ;)
* **Learning with incomplete observations**, i.e. where not all variables in a case instance have been observed.
* **Learning the structure of the network** from case instance data.
@@ -46,7 +46,6 @@ protected List<CasePredictionTO> queryPredictions(String caseInstanceId) {

List<CasePredictionTO> result = new ArrayList<CasePredictionTO>();

Map<String, Object> variables = caseService.getVariables(caseInstanceId);
String caseDefinitionId = caseService.createCaseExecutionQuery()
.caseExecutionId(caseInstanceId).singleResult().getCaseDefinitionId();

@@ -70,6 +69,8 @@ protected List<CasePredictionTO> queryPredictions(String caseInstanceId) {

ParsedPredictionModel parsedModel = predictionService.parseModel(model);

Map<String, Object> variables = caseService.getVariables(caseInstanceId);

for (String planItemId : planItemIds) {
if (!parsedModel.getVariables().containsKey(planItemId)) {
// ignore activities not defined in model
@@ -3,6 +3,7 @@
import java.io.InputStream;
import java.io.InputStreamReader;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Map;

import javax.sql.DataSource;
@@ -12,23 +13,68 @@
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.SqlSessionFactory;
import org.apache.ibatis.session.SqlSessionFactoryBuilder;
import org.apache.ibatis.transaction.Transaction;
import org.apache.ibatis.transaction.TransactionFactory;
import org.apache.ibatis.transaction.managed.ManagedTransactionFactory;
import org.apache.ibatis.transaction.jdbc.JdbcTransactionFactory;
import org.camunda.bpm.hackdays.prediction.model.ParsedPredictionModel;

public class CmmnPredictionService {

protected DataSource dataSource;
protected SqlSessionFactory sqlSessionFactory;
protected TransactionFactory txFactory;

protected CmmnPredictionService() {
}

public void createDbTables() {
inTransactionDo(new ConnectionFunction() {

// TODO: perhaps this should internally get a connection from datasource
// to be consistent with other APIs
public void createDbTables(Connection dbConnection) {
new CreateTablesCmd(dbConnection).execute(this);
public void call(Connection connection) {
new CreateTablesCmd(connection).execute(CmmnPredictionService.this);
}
});
}

protected void inTransactionDo(ConnectionFunction callable) {
Transaction transaction = null;

try {
transaction = txFactory.newTransaction(dataSource, null, false);
callable.call(transaction.getConnection());

transaction.commit();
} catch (Exception e) {
if (transaction != null) {
try {
transaction.rollback();
} catch (SQLException e1) {
throw new RuntimeException("Could not rollback tx", e);
}
}
}
finally {
if (transaction != null) {
try {
transaction.close();
} catch (SQLException e) {
throw new RuntimeException("Could not close tx", e);
}
}
}
}

protected static interface ConnectionFunction {
void call(Connection connection);
}

public void dropDbTables(Connection dbConnection) {
new DropTablesCmd(dbConnection).execute(this);
public void dropDbTables() {
inTransactionDo(new ConnectionFunction() {

public void call(Connection connection) {
new DropTablesCmd(connection).execute(CmmnPredictionService.this);
}
});
}

public PredictionModel getModel(String name) {
@@ -56,27 +102,24 @@ public Map<String, Double> estimate(PredictionModel model, String variableName,
return new EstimateDistributionCmd(model, variableName, variableAssignments, expressionContext).execute(this);
}

public static CmmnPredictionService build(DataSource dataSource) {
public static CmmnPredictionService buildStandalone(DataSource dataSource) {
return buildWithTxFactory(dataSource, new JdbcTransactionFactory());
}

public static CmmnPredictionService buildWithTxFactory(DataSource dataSource, TransactionFactory txFactory) {
InputStream config = CmmnPredictionService.class.getClassLoader().getResourceAsStream("mybatis/mybatis-config.xml");
SqlSessionFactory sqlSessionFactory = createMyBatisSqlSessionFactory(config, dataSource);

CmmnPredictionService service = new CmmnPredictionService();
service.sqlSessionFactory = sqlSessionFactory;
service.dataSource = dataSource;
return service;
SqlSessionFactory sqlSessionFactory = createMyBatisSqlSessionFactory(config, dataSource, txFactory);

CmmnPredictionService service = new CmmnPredictionService();
service.sqlSessionFactory = sqlSessionFactory;
service.dataSource = dataSource;
service.txFactory = txFactory;
return service;
}

protected static SqlSessionFactory createMyBatisSqlSessionFactory(InputStream config, DataSource dataSource) {
// use this transaction factory if you work in a non transactional
// environment
// TransactionFactory transactionFactory = new JdbcTransactionFactory();

// use ManagedTransactionFactory if you work in a transactional
// environment (e.g. called within the engine or using JTA)

TransactionFactory transactionFactory = new ManagedTransactionFactory();
protected static SqlSessionFactory createMyBatisSqlSessionFactory(InputStream config, DataSource dataSource, TransactionFactory txFactory) {

Environment environment = new Environment("cmmn-prediction", transactionFactory, dataSource);
Environment environment = new Environment("cmmn-prediction", txFactory, dataSource);

XMLConfigBuilder parser = new XMLConfigBuilder(new InputStreamReader(config));
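
Review note on `inTransactionDo`: as far as this flattened diff shows, the `catch (Exception e)` block rolls back but does not rethrow, so a failing `CreateTablesCmd` would return normally to the caller. The rollback handler also wraps `e` rather than `e1`. Below is a minimal sketch of a variant that propagates the failure. It is an assumption-laden illustration, not the project's code: `Tx` is a hypothetical stand-in for MyBatis's `Transaction`, and `Work` and `RecordingTx` are invented names for the demo.

```java
import java.util.ArrayList;
import java.util.List;

// Illustrative sketch only: commit/rollback/close template that rethrows.
class TxTemplateSketch {

  // Hypothetical stand-in for org.apache.ibatis.transaction.Transaction.
  interface Tx {
    void commit() throws Exception;
    void rollback() throws Exception;
    void close() throws Exception;
  }

  // Hypothetical unit of work, comparable to ConnectionFunction in the diff.
  interface Work {
    void run() throws Exception;
  }

  static void inTransactionDo(Tx tx, Work work) {
    RuntimeException failure = null;
    try {
      work.run();
      tx.commit();
    } catch (Exception e) {
      failure = new RuntimeException("Work failed, rolling back", e);
      try {
        tx.rollback();
      } catch (Exception rollbackFailure) {
        // Keep the original cause and attach the rollback problem to it.
        failure.addSuppressed(rollbackFailure);
      }
    }
    // Close after commit or rollback (java.lang.Error is not handled in this sketch).
    try {
      tx.close();
    } catch (Exception closeFailure) {
      if (failure == null) {
        failure = new RuntimeException("Could not close tx", closeFailure);
      } else {
        failure.addSuppressed(closeFailure);
      }
    }
    if (failure != null) {
      throw failure;
    }
  }

  // Test double that records which transaction methods were called.
  static class RecordingTx implements Tx {
    final List<String> calls = new ArrayList<>();
    public void commit() { calls.add("commit"); }
    public void rollback() { calls.add("rollback"); }
    public void close() { calls.add("close"); }
  }

  public static void main(String[] args) {
    RecordingTx tx = new RecordingTx();
    inTransactionDo(tx, () -> { });
    System.out.println(tx.calls);
  }
}
```

Whether to rethrow is a design choice for the author; the sketch only shows that doing so need not lose the rollback or close errors, since they can be attached as suppressed exceptions.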

@@ -22,6 +22,8 @@ public Void execute(CmmnPredictionService predictionService) {
}
}

sqlSession.commit();

} finally {
sqlSession.close();
}
@@ -4,6 +4,7 @@
import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;

public class CreateTablesCmd implements Command<Void> {

@@ -22,10 +23,20 @@ public Void execute(CmmnPredictionService predictionService) {

String sqlString = new String(sqlBytes, StandardCharsets.UTF_8);

Statement statement = null;
try {
dbConnection.createStatement().execute(sqlString);
statement = dbConnection.createStatement();
statement.execute(sqlString);
} catch (SQLException e) {
throw new CmmnPredictionException("Could not create tables", e);
} finally {
try {
if (statement != null) {
statement.close();
}
} catch (SQLException e) {
throw new RuntimeException(e);
}
}

return null;
@@ -20,6 +20,8 @@ public Void execute(CmmnPredictionService predictionService) {
for (PredictionModelPrior prior : model.getPriors()) {
sqlSession.update("PredictionModelPrior.update", prior);
}

sqlSession.commit();
} finally {
sqlSession.close();
}
25 changes: 20 additions & 5 deletions embedded-api/src/main/resources/sql/create.h2.sql
@@ -1,12 +1,27 @@
CREATE TABLE PREDICTION_MODEL (
ID_ varchar(64) not null,
NAME_ varchar(255) not null, -- add unique constraint
NAME_ varchar(255) not null,
RESOURCE_ longvarbinary,
primary key (ID_)
);

CREATE TABLE PREDICTION_PRIOR (
MODEL_ID_ varchar(64) not null, -- make foreign key
DESCRIBED_VARIABLE_ varchar(255) not null, -- make composite key with model_id
DATA_ longvarbinary
);
MODEL_ID_ varchar(64) not null,
DESCRIBED_VARIABLE_ varchar(255) not null,
DATA_ longvarbinary,

);

alter table PREDICTION_MODEL
add constraint PREDICTION_MODEL_NAME_UNIQUE
unique (NAME_);

alter table PREDICTION_PRIOR
add constraint PREDICTION_PRIOR_MODEL_ID_FK
foreign key (MODEL_ID_)
references PREDICTION_MODEL (ID_);

alter table PREDICTION_PRIOR
add constraint PREDICTION_PRIOR_MODEL_VARIABLE_UNIQUE
unique (MODEL_ID_, DESCRIBED_VARIABLE_);

@@ -3,43 +3,38 @@
import static org.assertj.core.api.Assertions.assertThat;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.util.HashSet;
import java.util.Set;

import javax.sql.DataSource;

import org.apache.ibatis.datasource.pooled.PooledDataSource;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

public class CreateTablesTest {

protected Connection connection;
protected DataSource dataSource;

@Before
public void setUp() throws Exception {
Class.forName("org.h2.Driver");
connection = DriverManager.getConnection("jdbc:h2:mem:foo");
// Class.forName("org.h2.Driver");
dataSource = new PooledDataSource("org.h2.Driver", "jdbc:h2:mem:foo", null);
// connection = DriverManager.getConnection("jdbc:h2:mem:foo");
}

@After
public void tearDown() throws Exception
{
if (connection != null) {
connection.close();
}
}

@Test
public void shouldCreateTables() throws Exception {
// given
CmmnPredictionService service = new CmmnPredictionService();
CmmnPredictionService service = CmmnPredictionService.buildStandalone(dataSource);

// when
service.createDbTables(connection);
service.createDbTables();

// then
ResultSet resultSet = connection.createStatement().executeQuery("SHOW TABLES");
ResultSet resultSet = dataSource.getConnection().createStatement().executeQuery("SHOW TABLES");
Set<String> tableNames = new HashSet<String>();

while (resultSet.next())
@@ -23,43 +23,14 @@ protected void starting(Description description) {
"",
"");

predictionService = CmmnPredictionService.build(dataSource);

Connection connection = null;
try {
connection = dataSource.getConnection();
predictionService.createDbTables(connection);
} catch (Exception e) {
throw new RuntimeException("Could not create tables", e);
} finally {
if (connection != null) {
try {
connection.close();
} catch (SQLException e) {
throw new RuntimeException("could not close connection", e);
}
}
}
predictionService = CmmnPredictionService.buildStandalone(dataSource);
predictionService.createDbTables();
}

@Override
protected void finished(Description description) {

Connection connection = null;
try {
connection = dataSource.getConnection();
predictionService.dropDbTables(connection);
} catch (Exception e) {
throw new RuntimeException("Could not create tables", e);
} finally {
if (connection != null) {
try {
connection.close();
} catch (SQLException e) {
throw new RuntimeException("could not close connection", e);
}
}
}
predictionService.dropDbTables();

dataSource.forceCloseAll();
}
@@ -29,7 +29,7 @@ public class PostDeployTest {

@Before
public void initPredictionService() {
predictionService = CmmnPredictionService.build(engineRule.getProcessEngineConfiguration().getDataSource());
predictionService = CmmnPredictionService.buildStandalone(engineRule.getProcessEngineConfiguration().getDataSource());
}

@After
2 changes: 1 addition & 1 deletion process-engine-plugin/pom.xml
@@ -12,7 +12,7 @@
</parent>

<properties>
<version.camunda>7.6.0-SNAPSHOT</version.camunda>
<version.camunda>7.6.0-alpha4</version.camunda>
<version.graphmod>0.0.1-SNAPSHOT</version.graphmod>
<version.h2>1.3.168</version.h2>
<version.junit>4.11</version.junit>