Skip to content

Commit

Permalink
refactor: Tidy interpolator and benchmark (#1989)
Browse files Browse the repository at this point in the history
  • Loading branch information
trisyoungs authored Oct 22, 2024
1 parent 96386cb commit 9aed5a3
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 120 deletions.
21 changes: 14 additions & 7 deletions benchmark/math/interpolator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// Copyright (c) 2024 Team Dissolve and contributors

#include "math/interpolator.h"
#include "math/data1D.h"
#include <algorithm>
#include <benchmark/benchmark.h>
#include <random>
#include <vector>
Expand All @@ -12,23 +14,28 @@ static void BM_Interpolator(benchmark::State &state)
{
int bytes = state.range(0);
int numVals = (bytes / sizeof(double));
std::vector<double> xs(numVals), ys(numVals), samples(numVals);

// Set up rng
std::uniform_real_distribution<double> unif(-100, 100);
std::default_random_engine re;

for (auto &x : xs)
x = unif(re);
for (auto &y : ys)
y = unif(re);
// Generate data
Data1D data;
data.initialise(numVals);
std::iota(data.xAxis().begin(), data.xAxis().end(), 0.0);
std::generate(data.values().begin(), data.values().end(), [&]() { return unif(re); });

// Generate sampling
std::vector<double> samples(numVals);
for (auto &s : samples)
s = unif(re);

Interpolator interp(xs, ys, Interpolator::SplineInterpolation);
Interpolator interp(data, Interpolator::SplineInterpolation);

for (auto _ : state)
{
auto result = interp.y(samples);
for (auto s : samples)
auto result = interp.y(s);
}

state.SetBytesProcessed(long(state.iterations()) * (long(bytes)));
Expand Down
112 changes: 1 addition & 111 deletions src/math/interpolator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ double Interpolator::y(double x, int interval)
// Approximate data at specified x value using three-point interpolation
double Interpolator::approximate(const Data1D &data, double x)
{
// Grab xand y arrays
// Grab x and y arrays
const auto &xData = data.xAxis();
const auto &yData = data.values();

Expand Down Expand Up @@ -474,113 +474,3 @@ void Interpolator::addInterpolated(Interpolator &source, Data1D &dest, double fa
for (auto &&[x, y] : zip(dest.xAxis(), dest.values()))
y += source.y(x) * factor;
}

// Add test function for benchmarks
std::vector<double> Interpolator::y(const std::vector<double> &xs)
{
// Creates vector to store results in reserving space the same size as xs
std::vector<double> result;
result.reserve(xs.size());

// Do we need to (re)generate the interpolation?
if (lastInterval_ == -1)
{
// Do we know what the interpolation scheme is?
if (scheme_ != Interpolator::NoInterpolation)
interpolate(scheme_);
else
{
// No existing interpolation scheme, so use Spline by default
interpolate(Interpolator::SplineInterpolation);
}
}
// Runs double Interpolator::y(double x) on a loop for speed testing
for (auto x : xs)
{
// Quick check of our interval - if the data is sequential increasing in x then we should be able to quickly determine
// it and avoid the binary chop
if (lastInterval_ != -1)
{
// If the x value exceeds the next interval boundary, try to increment it
if ((lastInterval_ + 1) < x_.size() && x >= x_[lastInterval_ + 1])
{
++lastInterval_;

// If there are still intervals beyond this one, check the next limit
if ((lastInterval_ + 1) < x_.size() && x >= x_[lastInterval_ + 1])
lastInterval_ = -1;
}

// Double-check lower limit
if (lastInterval_ > 0 && (x < x_[lastInterval_]))
lastInterval_ = -1;
}

// Perform binary chop search if no valid interval was found
if (lastInterval_ == -1)
{
lastInterval_ = 0;
int i, right = h_.size() - 1;
while ((right - lastInterval_) > 1)
{
i = (right + lastInterval_) / 2;
if (x_[i] > x)
right = i;
else
lastInterval_ = i;
}
}

if (lastInterval_ < 0)
result.push_back(y_.front());
switch (scheme_)
{
case (Interpolator::SplineInterpolation):
{
if (x >= x_.back())
result.push_back(y_.back());

auto h = x - x_[lastInterval_];
auto hh = h * h;
result.push_back(a_[lastInterval_] + b_[lastInterval_] * h + c_[lastInterval_] * hh +
d_[lastInterval_] * hh * h);
}
break;
// case (Interpolator::ConstrainedSplineInterpolation):
// {
// auto h = x;
// auto hh = h*h;
// return a_[interval] + b_[interval]*h + c_[interval]*hh + d_[interval]*hh*h;
// }
// break;
case (Interpolator::LinearInterpolation):
{
if (lastInterval_ >= (x_.size() - 1))
result.push_back(y_.back());

auto delta = (x - x_[lastInterval_]) / h_[lastInterval_];
result.push_back(y_[lastInterval_] + delta * a_[lastInterval_]);
}
break;
case (Interpolator::ThreePointInterpolation):
{
if (lastInterval_ >= (x_.size() - 3))
result.push_back(y_.back());

auto ppp = (x - x_[lastInterval_]) / h_[lastInterval_];

auto vk0 = y_[lastInterval_];
auto vk1 = y_[lastInterval_ + 1];
auto vk2 = y_[lastInterval_ + 2];
auto t1 = vk0 + (vk1 - vk0) * ppp;
auto t2 = vk1 + (vk2 - vk1) * (ppp - 1.0);
result.push_back(t1 + (t2 - t1) * ppp * 0.5);
}
break;
default:
// if no interpolation scheme selected then fills vector with 0.0
result.push_back(0.0);
}
}
return result;
}
2 changes: 0 additions & 2 deletions src/math/interpolator.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ class Interpolator
double y(double x);
// Return interpolated y value for supplied x, specifying containing interval
double y(double x, int interval);
// Return interpolated y values for all supplied x values
std::vector<double> y(const std::vector<double> &xs);

/*
* Static Functions
Expand Down

0 comments on commit 9aed5a3

Please sign in to comment.