Skip to content

Commit da7ba12

Browse files
committed
#26 Factor builders for Hfactor
1 parent 7173a7a commit da7ba12

File tree

1 file changed

+122
-0
lines changed

1 file changed

+122
-0
lines changed
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
package ch.idsia.credici.factor;
2+
3+
import ch.idsia.crema.IO;
4+
import ch.idsia.crema.core.Strides;
5+
import ch.idsia.crema.factor.credal.linear.separate.SeparateHalfspaceFactor;
6+
import ch.idsia.crema.factor.credal.linear.separate.SeparateHalfspaceFactorFactory;
7+
import ch.idsia.crema.model.graphical.DAGModel;
8+
import com.google.common.primitives.Ints;
9+
import org.apache.commons.math3.optim.linear.LinearConstraint;
10+
import org.apache.commons.math3.optim.linear.Relationship;
11+
12+
import java.io.IOException;
13+
import java.nio.file.Path;
14+
import java.util.ArrayList;
15+
import java.util.List;
16+
17+
public class HalfSpaceFactorBuilder {
18+
19+
20+
public static LinearConstraint[] buildConstraints(boolean normalized, boolean nonnegative, double[][] coefficients, double[] values, Relationship... rel) {
21+
22+
int left_combinations = coefficients[0].length;
23+
List<LinearConstraint> C = new ArrayList<LinearConstraint>();
24+
25+
26+
// check the coefficient shape
27+
for (double[] c : coefficients) {
28+
if (c.length != left_combinations)
29+
throw new IllegalArgumentException("ERROR: coefficient matrix shape");
30+
}
31+
32+
// check the relationship vector length
33+
if (rel.length == 0) rel = new Relationship[]{Relationship.EQ};
34+
if (rel.length == 1) {
35+
Relationship[] rel_aux = new Relationship[coefficients.length];
36+
for (int i = 0; i < coefficients.length; i++)
37+
rel_aux[i] = rel[0];
38+
rel = rel_aux;
39+
} else if (rel.length != coefficients.length) {
40+
throw new IllegalArgumentException("ERROR: wrong relationship vector length: " + rel.length);
41+
}
42+
43+
for (int i = 0; i < coefficients.length; i++) {
44+
C.add(new LinearConstraint(coefficients[i], rel[i], values[i]));
45+
}
46+
47+
48+
// normalization constraint
49+
if (normalized) {
50+
double[] ones = new double[left_combinations];
51+
for (int i = 0; i < ones.length; i++)
52+
ones[i] = 1.;
53+
C.add(new LinearConstraint(ones, Relationship.EQ, 1.0));
54+
}
55+
56+
// non-negative constraints
57+
if (nonnegative) {
58+
double[] zeros = new double[left_combinations];
59+
for (int i = 0; i < left_combinations; i++) {
60+
double[] c = zeros.clone();
61+
c[i] = 1.;
62+
C.add(new LinearConstraint(c, Relationship.GEQ, 0));
63+
64+
}
65+
}
66+
67+
return C.toArray(LinearConstraint[]::new);
68+
}
69+
70+
71+
public static SeparateHalfspaceFactor deterministic(Strides left, Strides right, int... assignments) {
72+
73+
if (assignments.length != right.getCombinations())
74+
throw new IllegalArgumentException("ERROR: length of assignments should be equal to the number of combinations of the parents");
75+
76+
if (Ints.min(assignments) < 0 || Ints.max(assignments) >= left.getCombinations())
77+
throw new IllegalArgumentException("ERROR: assignments of deterministic function should be in the inteval [0," + left.getCombinations() + ")");
78+
79+
80+
SeparateHalfspaceFactorFactory factory =
81+
SeparateHalfspaceFactorFactory.factory().domain(left, right);
82+
83+
84+
int left_combinations = left.getCombinations();
85+
86+
for (int i = 0; i < right.getCombinations(); i++) {
87+
double[][] coeff = new double[left_combinations][left_combinations];
88+
for (int j = 0; j < left_combinations; j++) {
89+
coeff[j][j] = 1.;
90+
}
91+
double[] values = new double[left_combinations];
92+
values[assignments[i]] = 1.;
93+
94+
// Build the constraints
95+
LinearConstraint[] C = buildConstraints(true, true, coeff, values, Relationship.EQ);
96+
97+
// Add the constraints
98+
for (LinearConstraint c : C) {
99+
factory.constraint(c, i);
100+
}
101+
}
102+
103+
return factory.get();
104+
}
105+
106+
public static void main(String[] args) throws IOException {
107+
108+
Path folder = Path.of(".");
109+
110+
folder.resolve("models/party.uai");
111+
112+
113+
DAGModel m = IO.read(folder.resolve("models/party.uai").toString());
114+
115+
SeparateHalfspaceFactor f = deterministic(m.getDomain(3), m.getDomain(2,1), 0,0,1,1);
116+
System.out.println(f);
117+
118+
119+
120+
}
121+
122+
}

0 commit comments

Comments
 (0)