Skip to content

Commit c673491

Browse files
Added RANSAC algorithm for estimating the parameters of a mathematical model.
1 parent a0959f4 commit c673491

File tree

13 files changed

+945
-0
lines changed

13 files changed

+945
-0
lines changed

hipparchus-fitting/src/changes/changes.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ If the output is not quite correct, check for invisible trailing spaces!
4949
<title>Hipparchus Fitting Release Notes</title>
5050
</properties>
5151
<body>
52+
<release version="4.1" date="TBD" description="TBD">
53+
<action dev="bryan" type="add" issue="issues/424">
54+
Added RANSAC algorithm for estimating the parameters of a mathematical model.
55+
</action>
56+
</release>
5257
<release version="4.0.2" date="2025-09-08" description="This is a patch release.">
5358
<action dev="bryan" type="update">
5459
No changes directly in this module. However, lower level Hipparchus modules did change,
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* https://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.hipparchus.fitting.ransac;
18+
19+
import java.util.List;
20+
21+
/**
22+
* Base class for mathematical model fitter used with {@link RansacFitter}.
23+
* @param <M> mathematical model representing the parameters to estimate
24+
* @since 4.1
25+
*/
26+
public interface IModelFitter<M> {
27+
28+
/**
29+
* Fits the mathematical model parameters based on the set of observed data.
30+
* @param points set of observed data
31+
* @return the fitted model parameters
32+
*/
33+
M fitModel(final List<double[]> points);
34+
35+
/**
36+
* Computes the error between the model and an observed data.
37+
* <p>
38+
* This method is used to determine if the observed data is an inlier or an outlier.
39+
* </p>
40+
* @param model fitted model
41+
* @param point observed data
42+
* @return the error between the model and the observed data
43+
*/
44+
double computeModelError(final M model, final double[] point);
45+
}
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* https://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.hipparchus.fitting.ransac;
18+
19+
import java.util.List;
20+
import java.util.stream.IntStream;
21+
import org.hipparchus.exception.LocalizedCoreFormats;
22+
import org.hipparchus.exception.MathIllegalArgumentException;
23+
import org.hipparchus.linear.Array2DRowRealMatrix;
24+
import org.hipparchus.linear.ArrayRealVector;
25+
import org.hipparchus.linear.RealMatrix;
26+
import org.hipparchus.linear.RealVector;
27+
import org.hipparchus.linear.SingularValueDecomposition;
28+
import org.hipparchus.util.FastMath;
29+
30+
/**
31+
* Fitter for polynomial model.
32+
* @since 4.1
33+
*/
34+
public class PolynomialModelFitter implements IModelFitter<PolynomialModelFitter.Model> {
35+
36+
/** Class representing the polynomial model to fit. */
37+
public static final class Model {
38+
39+
/** Coefficients of the polynomial model. */
40+
private final double[] coefficients;
41+
42+
/**
43+
* Constructor.
44+
* @param coefficients coefficients of the polynomial model
45+
*/
46+
public Model(final double[] coefficients) {
47+
this.coefficients = coefficients.clone();
48+
}
49+
50+
/**
51+
* Predicts the model value for the input point.
52+
* @param x point
53+
* @return the model value for the given point
54+
*/
55+
public double predict(final double x) {
56+
return IntStream.range(0, coefficients.length).mapToDouble(i -> coefficients[i] * FastMath.pow(x, i)).sum();
57+
}
58+
59+
/**
60+
* Get the coefficients of the polynomial model.
61+
* <p>
62+
* The coefficients are sort by degree.
63+
* For instance, for a quadratic equation the coefficients are as followed:
64+
* <code>y = coefficients[2] * x * x + coefficients[1] * x + coefficients[0]</code>
65+
* </p>
66+
* @return the coefficients of the polynomial model
67+
*/
68+
public double[] getCoefficients() {
69+
return coefficients;
70+
}
71+
}
72+
73+
/** Degree of the polynomial to fit. */
74+
private final int degree;
75+
76+
/**
77+
* Constructor.
78+
* @param degree degree of the polynomial to fit
79+
*/
80+
public PolynomialModelFitter(final int degree) {
81+
if (degree < 1) {
82+
throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, degree, 1);
83+
}
84+
this.degree = degree;
85+
}
86+
87+
/** {@inheritDoc. */
88+
@Override
89+
public Model fitModel(final List<double[]> points) {
90+
// Reference: Wikipedia page "Polynomial regression"
91+
final int size = points.size();
92+
checkSampleSize(size);
93+
94+
// Fill the data
95+
final double[][] x = new double[size][degree + 1];
96+
final double[] y = new double[size];
97+
for (int i = 0; i < size; i++) {
98+
final double currentX = points.get(i)[0];
99+
final double currentY = points.get(i)[1];
100+
double value = 1.0;
101+
for (int j = 0; j <= degree; j++) {
102+
x[i][j] = value;
103+
value *= currentX;
104+
}
105+
y[i] = currentY;
106+
}
107+
108+
// Computes (X^T.X)^-1 X^T.Y to determine the coefficients "C" of the polynomial (Y = X.C)
109+
final RealMatrix matrixX = new Array2DRowRealMatrix(x);
110+
final RealVector matrixY = new ArrayRealVector(y);
111+
final RealMatrix matrixXTranspose = matrixX.transpose();
112+
final RealMatrix xTx = matrixXTranspose.multiply(matrixX);
113+
final RealVector xTy = matrixXTranspose.operate(matrixY);
114+
final RealVector coefficients = new SingularValueDecomposition(xTx).getSolver().solve(xTy);
115+
return new Model(coefficients.toArray());
116+
}
117+
118+
/** {@inheritDoc}. */
119+
@Override
120+
public double computeModelError(final Model model, final double[] point) {
121+
return FastMath.abs(point[1] - model.predict(point[0]));
122+
}
123+
124+
/**
125+
* Verifies that the size of the set of observed data is consistent with the degree of the polynomial to fit.
126+
* @param size size of the set of observed data
127+
*/
128+
private void checkSampleSize(final int size) {
129+
if (size < degree + 1) {
130+
throw new IllegalArgumentException(String.format("Not enough points to fit polynomial model, at least %d points are required", degree + 1));
131+
}
132+
}
133+
}
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* https://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.hipparchus.fitting.ransac;
18+
19+
import java.util.ArrayList;
20+
import java.util.Collections;
21+
import java.util.List;
22+
import java.util.Optional;
23+
import java.util.Random;
24+
import java.util.stream.Collectors;
25+
import org.hipparchus.exception.LocalizedCoreFormats;
26+
import org.hipparchus.exception.MathIllegalArgumentException;
27+
28+
/**
29+
* Class implementing Random sample consensus (RANSAC) algorithm.
30+
* <p>
31+
* RANSAC is a robust method for estimating the parameters of a
32+
* mathematical model from a set of observed data.
33+
* It works iteratively selecting random subsets of the input data,
34+
* fitting a model to these subsets, and then determining how many
35+
* data points from the entire set are consistent with the estimated
36+
* model parameters.
37+
* The model can yields the largest number of inliers (i.e., point
38+
* that fit well) is considered the best estimate.
39+
* </p>
40+
* <p>
41+
* This implementation is designed to be generic and can be used with
42+
* different types of models, such as {@link PolynomialModelFitter
43+
* polynomial models}.
44+
* </p>
45+
* @param <M> mathematical model representing the parameters to estimate
46+
* @since 4.1
47+
*/
48+
public class RansacFitter<M> {
49+
50+
/** Mathematical model fitter. */
51+
private final IModelFitter<M> fitter;
52+
53+
/** The minimum number of data points to estimate the model parameters. */
54+
private final int sampleSize;
55+
56+
/** The maximum number of iterations allowed to fit the model. */
57+
private final int maxIterations;
58+
59+
/** Threshold to assert that a data point fits the model. */
60+
private final double threshold;
61+
62+
/** The minimum number of close data points required to assert that the model fits the input data. */
63+
private final int minInliers;
64+
65+
/** Random generator. */
66+
private final Random random;
67+
68+
/**
69+
* Constructor.
70+
* @param fitter mathematical model fitter
71+
* @param sampleSize minimum number of data points to estimate the model parameters
72+
* @param maxIterations maximum number of iterations allowed to fit the model
73+
* @param threshold threshold to assert that a data point fits the model
74+
* @param minInliers minimum number of close data points required to assert that the model fits the input data
75+
* @param seed seed for the random generator
76+
*/
77+
public RansacFitter(final IModelFitter<M> fitter, final int sampleSize,
78+
final int maxIterations, final double threshold,
79+
final int minInliers, final int seed) {
80+
this.fitter = fitter;
81+
this.sampleSize = sampleSize;
82+
this.maxIterations = maxIterations;
83+
this.threshold = threshold;
84+
this.minInliers = minInliers;
85+
this.random = new Random(seed);
86+
checkInputs();
87+
}
88+
89+
/**
90+
* Fits the set of observed data to determine the model parameters.
91+
* @param points set of observed data
92+
* @return a java class containing the best estimate of the model parameters
93+
*/
94+
public RansacFitterOutputs<M> fit(final List<double[]> points) {
95+
96+
// Initialize the best model data
97+
final List<double[]> data = new ArrayList<>(points);
98+
Optional<M> bestModel = Optional.empty();
99+
List<double[]> bestInliers = new ArrayList<>();
100+
101+
// Iterative loop to determine the best model
102+
for (int iteration = 0; iteration < maxIterations; iteration++) {
103+
104+
// Random permute the set of observed data and determine the inliers
105+
Collections.shuffle(data, random);
106+
final List<double[]> inliers = determineCurrentInliersFromRandomlyPermutedPoints(data);
107+
108+
// Verifies if the current inliers are fit better the model than the previous ones
109+
if (isCurrentInliersSetBetterThanPreviousOne(inliers, bestInliers)) {
110+
bestModel = Optional.of(fitter.fitModel(inliers));
111+
bestInliers = inliers;
112+
}
113+
114+
}
115+
116+
// Returns the best model data
117+
return new RansacFitterOutputs<>(bestModel, bestInliers);
118+
}
119+
120+
/**
121+
* Determines the current inliers (i.e., points that fit well the model) from the input randomly permuted data.
122+
* @param permutedPoints randomly permuted data
123+
* @return the list of inliers
124+
*/
125+
private List<double[]> determineCurrentInliersFromRandomlyPermutedPoints(final List<double[]> permutedPoints) {
126+
M model = fitter.fitModel(permutedPoints.subList(0, sampleSize));
127+
return permutedPoints.stream().filter(point -> fitter.computeModelError(model, point) < threshold).collect(Collectors.toList());
128+
}
129+
130+
/**
131+
* Verifies is the current inliers are better than the previous ones.
132+
* @param current current inliers
133+
* @param previous previous inliers
134+
* @return true is the current inlier are better than the previous ones
135+
*/
136+
private boolean isCurrentInliersSetBetterThanPreviousOne(final List<double[]> current, final List<double[]> previous) {
137+
return current.size() > previous.size() && current.size() >= minInliers;
138+
}
139+
140+
/**
141+
* Checks that the fitter inputs are correct.
142+
*/
143+
private void checkInputs() {
144+
if (maxIterations < 0) {
145+
throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, maxIterations, 0);
146+
}
147+
if (sampleSize < 0) {
148+
throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, sampleSize, 0);
149+
}
150+
if (threshold < 0.) {
151+
throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, threshold, 0);
152+
}
153+
if (minInliers < 0) {
154+
throw new MathIllegalArgumentException(LocalizedCoreFormats.NUMBER_TOO_SMALL, minInliers, 0);
155+
}
156+
}
157+
}

0 commit comments

Comments
 (0)