Skip to content

Commit

Permalink
Merge pull request #9 from minnerbe/levenberg-marquardt-cleanup
Browse files Browse the repository at this point in the history
Levenberg marquardt cleanup
  • Loading branch information
tpietzsch authored May 7, 2024
2 parents 4f8acf5 + f1ca5c6 commit 29e1bcd
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
package net.imglib2.algorithm.localization;

/**
* A n-dimensional, symmetric Gaussian peak function.
* An n-dimensional, symmetric Gaussian peak function.
* <p>
* This fitting target function is defined over dimension <code>n</code>, by the
* following <code>n+2</code> parameters:
Expand Down Expand Up @@ -70,7 +70,7 @@ public final double val(final double[] x, final double[] a) {
}

/**
* Partial derivatives indices are ordered as follow:
* Partial derivatives indices are ordered as follows:
* <pre>k = 0..n-1 - x_i (with i = k)
*k = n - A
*k = n+1 - b</pre>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import Jama.Matrix;

/**
* A plain implementation of Levenberg-Marquardt least-square curve fitting algorithm.
* A plain implementation of Levenberg-Marquardt least-squares curve fitting algorithm.
* This solver makes use of only the function value and its gradient. That is:
* candidate functions need only to implement the {@link FitFunction#val(double[], double[])}
* and {@link FitFunction#grad(double[], double[], int)} methods to operate with this
Expand All @@ -48,7 +48,7 @@ public class LevenbergMarquardtSolver implements FunctionFitter {
private final double termEpsilon;

/**
* Creates a new Levenberg-Marquardt solver for least-square curve fitting problems.
* Creates a new Levenberg-Marquardt solver for least-squares curve fitting problems.
* @param lambda blend between steepest descent (lambda high) and
* jump to bottom of quadratic (lambda zero). Start with 0.001.
* @param termEpsilon termination accuracy (0.01)
Expand All @@ -66,11 +66,11 @@ public LevenbergMarquardtSolver(int maxIteration, double lambda, double termEpsi

@Override
public String toString() {
return "Levenberg-Marquardt least-square curve fitting algorithm";
return "Levenberg-Marquardt least-squares curve fitting algorithm";
}

/**
* Creates a new Levenberg-Marquardt solver for least-square curve fitting problems,
* Creates a new Levenberg-Marquardt solver for least-squares curve fitting problems,
* with default parameters set to:
* <ul>
* <li> <code>lambda = 1e-3</code>
Expand All @@ -83,13 +83,13 @@ public LevenbergMarquardtSolver() {
}

/*
* MEETHODS
* METHODS
*/


@Override
public void fit(double[][] x, double[] y, double[] a, FitFunction f) throws Exception {
solve(x, a, y, f, lambda, termEpsilon, maxIteration);
public void fit(double[][] x, double[] y, double[] a, FitFunction f) {
fit(x, y, a, f, maxIteration, lambda, termEpsilon);
}


Expand All @@ -100,60 +100,102 @@ public void fit(double[][] x, double[] y, double[] a, FitFunction f) throws Exce

/**
* Calculate the current sum-squared-error
* This is deprecated in favor of {@link LevenbergMarquardtSolver#computeSquaredError(double[][], double[], double[], FitFunction)}.
*/
@Deprecated
public static double chiSquared(final double[][] x, final double[] a, final double[] y, final FitFunction f) {
return computeSquaredError(x, y, a, f);
}

/**
* Calculate the squared least-squares error of the given data.
*/
public static final double chiSquared(final double[][] x, final double[] a, final double[] y, final FitFunction f) {
public static double computeSquaredError(final double[][] x, final double[] y, final double[] a, final FitFunction f) {
int npts = y.length;
double sum = 0.;

for( int i = 0; i < npts; i++ ) {
double d = y[i] - f.val(x[i], a);
sum = sum + (d*d);
sum += d * d;
}

return sum;
} //chiSquared
}

/**
* Minimize E = sum {(y[k] - f(x[k],a)) }^2
* Note that function implements the value and gradient of f(x,a),
* NOT the value and gradient of E with respect to a!
* This is deprecated, use {@link LevenbergMarquardtSolver#fit(double[][], double[], double[], FitFunction, int, double, double)} instead.
*
* @param x array of domain points, each may be multidimensional
* @param y corresponding array of values
* @param a the parameters/state of the model
* @param y corresponding array of values
* @param f the function to fit
* @param lambda blend between steepest descent (lambda high) and
* jump to bottom of quadratic (lambda zero). Start with 0.001.
* @param termepsilon termination accuracy (0.01)
* @param maxiter stop and return after this many iterations if not done
*
* @return the number of iteration used by minimization
*/
public static final int solve(double[][] x, double[] a, double[] y, FitFunction f,
double lambda, double termepsilon, int maxiter) throws Exception {
@Deprecated
public static int solve(double[][] x, double[] a, double[] y, FitFunction f,
double lambda, double termepsilon, int maxiter) {
return fit(x, y, a, f, maxiter, lambda, termepsilon);
}

/**
* Minimize E = sum {(y[k] - f(x[k],a)) }^2
* Note that function implements the value and gradient of f(x,a),
* NOT the value and gradient of E with respect to a!
*
* @param x array of domain points, each may be multidimensional
* @param y corresponding array of values
* @param a the parameters/state of the model
* @param f the function to fit
* @param maxiter stop and return after this many iterations if not done
* @param lambda blend between steepest descent (lambda high) and
* jump to bottom of quadratic (lambda zero). Start with 0.001.
* @param termepsilon termination accuracy (0.01)
*
* @return the number of iteration used by minimization
*/
public static int fit(double[][] x, double[] y, double[] a, FitFunction f, int maxiter, double lambda, double termepsilon) {
int npts = y.length;
int nparm = a.length;

double e0 = chiSquared(x, a, y, f);
double e0 = computeSquaredError(x, y, a, f);
boolean done = false;

// g = gradient, H = hessian, d = step to minimum
// H d = -g, solve for d
double[][] H = new double[nparm][nparm];
double[] g = new double[nparm];

double[] valf = new double[npts];
double[][] gradf = new double[nparm][npts];

int iter = 0;
int term = 0; // termination count test

do {
++iter;

// precompute values and gradients of f
for (int i = 0; i < npts; i++) {
valf[i] = f.val(x[i], a);
for (int k = 0; k < nparm; k++) {
gradf[k][i] = f.grad(x[i], a, k);
}
}

// hessian approximation
for( int r = 0; r < nparm; r++ ) {
for( int c = 0; c < nparm; c++ ) {
H[r][c] = 0.;
for( int i = 0; i < npts; i++ ) {
double[] xi = x[i];
H[r][c] += f.grad(xi, a, r) * f.grad(xi, a, c);
H[r][c] += gradf[r][i] * gradf[c][i];
} //npts
} //c
} //r
Expand All @@ -166,12 +208,11 @@ public static final int solve(double[][] x, double[] a, double[] y, FitFunction
for( int r = 0; r < nparm; r++ ) {
g[r] = 0.;
for( int i = 0; i < npts; i++ ) {
double[] xi = x[i];
g[r] += (y[i]-f.val(xi,a)) * f.grad(xi, a, r);
g[r] += (y[i]-valf[i]) * gradf[r][i];
}
} //npts

double[] d = null;
double[] d;
try {
d = (new Matrix(H)).lu().solve(new Matrix(g, nparm)).getRowPackedCopy();
} catch (RuntimeException re) {
Expand All @@ -180,7 +221,7 @@ public static final int solve(double[][] x, double[] a, double[] y, FitFunction
continue;
}
double[] na = (new Matrix(a, nparm)).plus(new Matrix(d,nparm)).getRowPackedCopy();
double e1 = chiSquared(x, na, y, f);
double e1 = computeSquaredError(x, y, na, f);

// termination test (slightly different than NR)
if (Math.abs(e1-e0) > termepsilon) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@


/**
* An fit initializer suitable for the fitting of elliptic orthogonal gaussians
* A fit initializer suitable for the fitting of elliptic orthogonal gaussians
* ({@link EllipticGaussianOrtho}, ellipse axes must be parallel to image axes)
* functions on n-dimensional image data. It uses plain maximum-likelihood
* estimator for a normal distribution.
Expand All @@ -45,7 +45,7 @@
* {@link #initializeFit(Localizable, Observation)} is based on
* maximum-likelihood estimator for a normal distribution, which requires the
* background of the image (out of peaks) to be close to 0. Returned parameters
* are ordered as follow:
* are ordered as follows:
*
* <pre>0 → ndims-1 x₀ᵢ
* ndims. A
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@


/**
* An fit initializer suitable for the fitting of gaussian peaks (
* {@link Gaussian}, on n-dimensional image data. It uses plain
* maximum-likelohood estimator for a normal distribution.
* A fit initializer suitable for the fitting of gaussian peaks (
* {@link Gaussian}), on n-dimensional image data. It uses plain
* maximum-likelihood estimator for a normal distribution.
* <p>
* The problem dimensionality is specified at construction by <code>nDims</code>
* parameter.
Expand All @@ -42,8 +42,8 @@
* <p>
* Parameters estimation returned by
* {@link #initializeFit(Localizable, Observation)} is based on
* maximum-likelihood esimtation, which requires the background of the image
* (out of peaks) to be close to 0. Returned parameters are ordered as follow:
* maximum-likelihood estimation, which requires the background of the image
* (out of peaks) to be close to 0. Returned parameters are ordered as follows:
*
* <pre>
* 0. A
Expand Down Expand Up @@ -81,7 +81,7 @@ public MLGaussianEstimator(double typicalSigma, int nDims) {

@Override
public String toString() {
return "Maximum-likelihood estimator for symetric gaussian peaks";
return "Maximum-likelihood estimator for symmetric gaussian peaks";
}


Expand Down Expand Up @@ -121,7 +121,7 @@ public double[] initializeFit(final Localizable point, final Observation data) {
}

// Estimate b in all dimension
double bs[] = new double[nDims];
double[] bs = new double[nDims];
for (int j = 0; j < nDims; j++) {
double C = 0;
double dx;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import net.imglib2.algorithm.Benchmark;
import net.imglib2.algorithm.MultiThreaded;
import net.imglib2.algorithm.OutputAlgorithm;
import net.imglib2.img.Img;
import net.imglib2.type.numeric.RealType;

/**
Expand Down Expand Up @@ -131,7 +130,7 @@ public void run() {
double[] I = data.I;
fitter.fit(X, I, params, peakFunction);
} catch (Exception e) {
errorHolder.append(BASE_ERROR_MESSAGE +
errorHolder.append(BASE_ERROR_MESSAGE +
"Problem fitting around " + peak +
": " + e.getMessage() + ".\n");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
* <li> They must be able to provide a starting point to the curve fitting solver,
* based on the image data around the coarse peak location.
* </ul>
* Depending on the problem they are taylored for, implementations can be very crude:
* Depending on the problem they are tailored for, implementations can be very crude:
* One can return plain constants if the typical parameters of all peaks are known
* and uniform. Refined method are also possible.
* <p>
Expand Down Expand Up @@ -69,7 +69,7 @@ public interface StartPointEstimator {
public long[] getDomainSpan();

/**
* Returns a new double array containing an starting point estimate for a
* Returns a new double array containing a starting point estimate for a
* specific curve fitting problem. Depending on the implementation, this
* estimate can be calculated from the specified point and the specified
* image data.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public class PeakFitterTest {
private static final double LOCALIZATION_TOLERANCE = 0.1d;

@Test
public void testSymetricGaussian() {
public void testSymmetricGaussian() {

int width = 200;
int height = 200;
Expand All @@ -64,8 +64,8 @@ public void testSymetricGaussian() {
long[] dimensions = new long[] { width, height };
ArrayImg<UnsignedByteType,ByteArray> img = ArrayImgs.unsignedBytes(dimensions);

Collection<Localizable> peaks = new HashSet<Localizable>(nspots);
Map<Localizable, double[]> groundTruth = new HashMap<Localizable, double[]>(nspots);
Collection<Localizable> peaks = new HashSet<>(nspots);
Map<Localizable, double[]> groundTruth = new HashMap<>(nspots);

for (int i = 1; i < nspots; i++) {

Expand All @@ -88,7 +88,7 @@ public void testSymetricGaussian() {
}

// Instantiate fitter once
PeakFitter<UnsignedByteType> fitter = new PeakFitter<UnsignedByteType>(img, peaks,
PeakFitter<UnsignedByteType> fitter = new PeakFitter<>(img, peaks,
new LevenbergMarquardtSolver(), new Gaussian(), new MLGaussianEstimator(2d, 2));

if ( !fitter.checkInput() || !fitter.process()) {
Expand All @@ -106,7 +106,7 @@ public void testSymetricGaussian() {
assertEquals("Bad accuracy on amplitude parameter A: ", truth[2], params[2], TOLERANCE * truth[2]);
assertEquals("Bad accuracy on peak location x0: ", truth[0], params[0], LOCALIZATION_TOLERANCE);
assertEquals("Bad accuracy on peak location y0: ", truth[1], params[1], LOCALIZATION_TOLERANCE);
assertEquals("Bad accuracy on peak paramter b: ", truth[3], params[3], TOLERANCE * truth[3]);
assertEquals("Bad accuracy on peak parameter b: ", truth[3], params[3], TOLERANCE * truth[3]);
}
}

Expand All @@ -120,8 +120,8 @@ public void testEllipticGaussian() {
long[] dimensions = new long[] { width, height };
ArrayImg<UnsignedByteType,ByteArray> img = ArrayImgs.unsignedBytes(dimensions);

Collection<Localizable> peaks = new HashSet<Localizable>(nspots);
Map<Localizable, double[]> groundTruth = new HashMap<Localizable, double[]>(nspots);
Collection<Localizable> peaks = new HashSet<>(nspots);
Map<Localizable, double[]> groundTruth = new HashMap<>(nspots);

for (int i = 1; i < nspots; i++) {

Expand All @@ -145,7 +145,7 @@ public void testEllipticGaussian() {
}

// Instantiate fitter once
PeakFitter<UnsignedByteType> fitter = new PeakFitter<UnsignedByteType>(img, peaks,
PeakFitter<UnsignedByteType> fitter = new PeakFitter<>(img, peaks,
new LevenbergMarquardtSolver(), new EllipticGaussianOrtho(), new MLEllipticGaussianEstimator(new double[] { 2d, 2d}));

if ( !fitter.checkInput() || !fitter.process()) {
Expand All @@ -163,8 +163,8 @@ public void testEllipticGaussian() {
assertEquals("Bad accuracy on amplitude parameter A: ", truth[2], params[2], TOLERANCE * truth[2]);
assertEquals("Bad accuracy on peak location x0: ", truth[0], params[0], LOCALIZATION_TOLERANCE);
assertEquals("Bad accuracy on peak location y0: ", truth[1], params[1], LOCALIZATION_TOLERANCE);
assertEquals("Bad accuracy on peak paramter bx: ", truth[3], params[3], TOLERANCE * truth[3]);
assertEquals("Bad accuracy on peak paramter by: ", truth[4], params[4], TOLERANCE * truth[4]);
assertEquals("Bad accuracy on peak parameter bx: ", truth[3], params[3], TOLERANCE * truth[3]);
assertEquals("Bad accuracy on peak parameter by: ", truth[4], params[4], TOLERANCE * truth[4]);

// System.out.println(String.format("- For " + peak + "\n - Found : " +
// "A = %6.2f, x0 = %6.2f, y0 = %6.2f, sx = %5.2f, sy = %5.2f",
Expand Down

0 comments on commit 29e1bcd

Please sign in to comment.