Skip to content

Commit

Permalink
update with ddm utility
Browse files Browse the repository at this point in the history
  • Loading branch information
PierreMarchand20 committed Nov 29, 2023
1 parent fc946fa commit cd4ee5c
Show file tree
Hide file tree
Showing 9 changed files with 169 additions and 50 deletions.
11 changes: 4 additions & 7 deletions example/use_cluster.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import matplotlib.pyplot as plt
import mpi4py
import numpy as np
from create_geometry import (
create_partitionned_geometries_test,
create_random_geometries,
)
from create_geometry import create_random_geometries

import Htool

Expand Down Expand Up @@ -57,7 +54,7 @@
ax2 = fig.add_subplot(1, 2, 2, projection="3d")

ax1.set_title("target cluster\ndepth 2")
ax2.set_title("local cluster\ntarget partition number 0\ndepth 1")
Htool.plot(ax1, target_cluster, target_points, 1)
Htool.plot(ax2, local_target_cluster, target_points, 4)
ax2.set_title("local cluster\ntarget partition number 0\ndepth 2")
Htool.plot(ax1, target_cluster, target_points, 2)
Htool.plot(ax2, local_target_cluster, target_points, 2)
plt.show()
4 changes: 2 additions & 2 deletions src/htool/hmatrix/interfaces/virtual_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ class VirtualGeneratorWithPermutationPython : public htool::VirtualGeneratorWith
if (M * N > 0) {
py::array_t<CoefficientPrecision, py::array::f_style> mat(std::array<long int, 2>{M, N}, ptr, py::capsule(ptr));

py::array_t<int> py_rows(std::array<long int, 2>{M, 1}, rows, py::capsule(rows));
py::array_t<int> py_cols(std::array<long int, 2>{N, 1}, cols, py::capsule(cols));
py::array_t<int> py_rows(std::array<long int, 1>{M}, rows, py::capsule(rows));
py::array_t<int> py_cols(std::array<long int, 1>{N}, cols, py::capsule(cols));

build_submatrix(py_rows, py_cols, mat);
}
Expand Down
2 changes: 2 additions & 0 deletions src/htool/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,13 @@ PYBIND11_MODULE(Htool, m) {
declare_distributed_operator_utility<double, double>(m);

declare_DDM<double>(m, "Solver");
declare_solver_utility(m);
declare_solver_utility<double, double>(m);

declare_matplotlib_cluster<double>(m);
declare_matplotlib_hmatrix<double, double>(m);

declare_hmatrix_builder<std::complex<double>, double>(m, "ComplexHMatrixBuilder");
declare_HMatrix<std::complex<double>, double>(m, "ComplexHMatrix");
declare_virtual_generator<std::complex<double>>(m, "ComplexVirtualGenerator", "IComplexGenerator");

Expand Down
16 changes: 15 additions & 1 deletion src/htool/matplotlib/cluster.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,21 @@ void declare_matplotlib_cluster(py::module &m) {

// Create Color Map
py::object colormap = plt.attr("get_cmap")("Dark2");
py::object norm = colors.attr("Normalize")("vmin"_a = (*std::min_element(partition_numbers.begin(), partition_numbers.end())), "vmax"_a = (*std::max_element(partition_numbers.begin(), partition_numbers.end())));
if (counter == 9) {
colormap = plt.attr("get_cmap")("Set1");
}
if (counter == 10) {
colormap = plt.attr("get_cmap")("tab10");
}
if (counter > 10 && counter <= 20) {
colormap = plt.attr("get_cmap")("tab20");
}
if (counter > 20) {
htool::Logger::get_instance()
.log(LogLevel::WARNING, "Colormap does not support more than 20 colors.");
}

py::object norm = colors.attr("Normalize")("vmin"_a = (*std::min_element(partition_numbers.begin(), partition_numbers.end())), "vmax"_a = (*std::max_element(partition_numbers.begin(), partition_numbers.end())));

// Figure
if (spatial_dimension == 2) {
Expand Down
32 changes: 25 additions & 7 deletions src/htool/solver/utility.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,41 @@

#include <htool/solvers/utility.hpp>

void declare_solver_utility(py::module &m) {
py::class_<LocalNumberingBuilder> py_class(m, "LocalNumberingBuilder");
py_class.def(py::init<const std::vector<int> &, const std::vector<int> &, const std::vector<std::vector<int>> &>());
py_class.def_property_readonly(
"local_to_global_numbering", [](const LocalNumberingBuilder &self) { return &self.local_to_global_numbering; }, py::return_value_policy::reference_internal);
py_class.def_property_readonly(
"intersections", [](const LocalNumberingBuilder &self) { return &self.intersections; }, py::return_value_policy::reference_internal);
}

template <typename CoefficientPrecision, typename CoordinatePrecision>
void declare_solver_utility(py::module &m, std::string prefix = "") {

using DefaultSolverBuilder = DefaultSolverBuilder<CoefficientPrecision, CoordinatePrecision>;
using DefaultDDMSolverBuilder = DefaultDDMSolverBuilder<CoefficientPrecision, CoordinatePrecision>;
std::string default_solver_name = prefix + "DefaultSolverBuilder";
std::string default_ddm_solver_name = prefix + "DefaultDDMSolverBuilder";
using DefaultSolverBuilder = DefaultSolverBuilder<CoefficientPrecision, CoordinatePrecision>;
using DefaultDDMSolverBuilderAddingOverlap = DefaultDDMSolverBuilderAddingOverlap<CoefficientPrecision, CoordinatePrecision>;
using DefaultDDMSolverBuilder = DefaultDDMSolverBuilder<CoefficientPrecision, CoordinatePrecision>;

std::string default_solver_name = prefix + "DefaultSolverBuilder";
std::string default_ddm_solver_adding_overlap_name = prefix + "DefaultDDMSolverBuilderAddingOverlap";
std::string default_ddm_solver_name = prefix + "DefaultDDMSolverBuilder";

py::class_<DefaultSolverBuilder> default_solver_class(m, default_solver_name.c_str());
default_solver_class.def(py::init<DistributedOperator<CoefficientPrecision> &, const HMatrix<CoefficientPrecision, CoordinatePrecision> *>());
default_solver_class.def_property_readonly(
"solver", [](const DefaultSolverBuilder &self) { return &self.solver; }, py::return_value_policy::reference_internal);

py::class_<DefaultDDMSolverBuilderAddingOverlap> default_ddm_solver_adding_overlap_class(m, default_ddm_solver_adding_overlap_name.c_str());
default_ddm_solver_adding_overlap_class.def(py::init<DistributedOperator<CoefficientPrecision> &, const HMatrix<CoefficientPrecision, CoordinatePrecision> *, const VirtualGeneratorWithPermutation<CoefficientPrecision> &, const std::vector<int> &, const std::vector<int> &, const std::vector<int> &, const std::vector<std::vector<int>> &>());
default_ddm_solver_adding_overlap_class.def_property_readonly(
"solver", [](const DefaultDDMSolverBuilderAddingOverlap &self) { return &self.solver; }, py::return_value_policy::reference_internal);
default_ddm_solver_adding_overlap_class.def_property_readonly(
"local_to_global_numbering", [](const DefaultDDMSolverBuilderAddingOverlap &self) { return &self.local_to_global_numbering; }, py::return_value_policy::reference_internal);

py::class_<DefaultDDMSolverBuilder> default_ddm_solver_class(m, default_ddm_solver_name.c_str());
default_ddm_solver_class.def(py::init<DistributedOperator<CoefficientPrecision> &, const HMatrix<CoefficientPrecision, CoordinatePrecision> *, const VirtualGeneratorWithPermutation<CoefficientPrecision> &, const std::vector<int> &, const std::vector<int> &, const std::vector<int> &, const std::vector<std::vector<int>> &>());
default_ddm_solver_class.def(py::init<DistributedOperator<CoefficientPrecision> &, const HMatrix<CoefficientPrecision, CoordinatePrecision> &, const std::vector<int> &, const std::vector<std::vector<int>> &>(), py::keep_alive<1, 2>(), py::keep_alive<1, 4>(), py::keep_alive<1, 5>());
default_ddm_solver_class.def_property_readonly(
"solver", [](const DefaultDDMSolverBuilder &self) { return &self.solver; }, py::return_value_policy::reference_internal);
default_ddm_solver_class.def_property_readonly(
"local_to_global_numbering", [](const DefaultDDMSolverBuilder &self) { return &self.local_to_global_numbering; }, py::return_value_policy::reference_internal);
}
#endif
33 changes: 33 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,29 @@ def build_submatrix(self, J, K, mat):
mat[j, k] = self.get_coef(J[j], K[k])


class LocalGeneratorFromMatrix(Htool.ComplexVirtualGeneratorWithPermutation):
def __init__(
self,
permutation,
local_to_global_numbering,
matrix,
):
super().__init__(permutation, permutation)
self.matrix = matrix
self.local_to_global_numbering = local_to_global_numbering

def get_coef(self, i, j):
return self.matrix[i, j]

def build_submatrix(self, J, K, mat):
for j in range(0, len(J)):
for k in range(0, len(K)):
mat[j, k] = self.get_coef(
self.local_to_global_numbering[J[j]],
self.local_to_global_numbering[K[k]],
)


@pytest.hookimpl(trylast=True)
def pytest_configure(config):
if mpi4py.MPI.COMM_WORLD.Get_rank() != 0:
Expand Down Expand Up @@ -268,6 +291,15 @@ def load_data_solver(symmetry, mu):
A = np.frombuffer(data[8:], dtype=np.dtype("complex128"))
A = np.transpose(A.reshape((m, n)))

# Geometry
with open(
path_to_data / "geometry.bin",
"rb",
) as input:
data = input.read()
geometry = np.frombuffer(data[4:], dtype=np.dtype("double"))
geometry = geometry.reshape(3, m, order="F")

# Right-hand side
with open(
path_to_data / "rhs.bin",
Expand Down Expand Up @@ -348,6 +380,7 @@ def load_data_solver(symmetry, mu):
A,
x_ref,
f,
geometry,
cluster,
neighbors,
intersections,
Expand Down
117 changes: 86 additions & 31 deletions tests/test_ddm_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import pathlib
import struct

import matplotlib.pyplot as plt
import mpi4py
import numpy as np
import pytest
from conftest import GeneratorFromMatrix
from conftest import GeneratorFromMatrix, LocalGeneratorFromMatrix
from mpi4py import MPI

import Htool
Expand All @@ -16,32 +17,44 @@
@pytest.mark.parametrize("eta", [10])
@pytest.mark.parametrize("tol", [1e-6])
@pytest.mark.parametrize(
"mu,symmetry,overlap,hpddm_schwarz_method,hpddm_schwarz_coarse_correction",
"mu,symmetry,ddm_builder,hpddm_schwarz_method,hpddm_schwarz_coarse_correction",
[
(1, "N", False, "none", "none"),
(1, "N", False, "asm", "none"),
(1, "N", False, "ras", "none"),
(1, "N", True, "asm", "none"),
(1, "N", True, "ras", "none"),
(10, "N", False, "none", "none"),
(10, "N", False, "asm", "none"),
(10, "N", False, "ras", "none"),
(10, "N", True, "asm", "none"),
(10, "N", True, "ras", "none"),
(1, "S", False, "none", "none"),
(1, "S", False, "asm", "none"),
(1, "S", False, "ras", "none"),
(1, "S", True, "asm", "none"),
(1, "S", True, "ras", "none"),
(10, "S", False, "none", "none"),
(10, "S", False, "asm", "none"),
(10, "S", False, "ras", "none"),
(10, "S", True, "asm", "none"),
(10, "S", True, "ras", "none"),
(1, "S", True, "asm", "additive"),
(1, "S", True, "ras", "additive"),
(10, "S", True, "asm", "additive"),
(10, "S", True, "ras", "additive"),
(1, "N", "SolverBuilder", "none", "none"),
(1, "N", "SolverBuilder", "asm", "none"),
(1, "N", "SolverBuilder", "ras", "none"),
(1, "N", "DDMSolverBuilderAddingOverlap", "asm", "none"),
(1, "N", "DDMSolverBuilderAddingOverlap", "ras", "none"),
(1, "N", "DDMSolverBuilder", "asm", "none"),
(1, "N", "DDMSolverBuilder", "ras", "none"),
(10, "N", "SolverBuilder", "none", "none"),
(10, "N", "SolverBuilder", "asm", "none"),
(10, "N", "SolverBuilder", "ras", "none"),
(10, "N", "DDMSolverBuilderAddingOverlap", "asm", "none"),
(10, "N", "DDMSolverBuilderAddingOverlap", "ras", "none"),
(10, "N", "DDMSolverBuilder", "asm", "none"),
(10, "N", "DDMSolverBuilder", "ras", "none"),
(1, "S", "SolverBuilder", "none", "none"),
(1, "S", "SolverBuilder", "asm", "none"),
(1, "S", "SolverBuilder", "ras", "none"),
(1, "S", "DDMSolverBuilderAddingOverlap", "asm", "none"),
(1, "S", "DDMSolverBuilderAddingOverlap", "ras", "none"),
(1, "S", "DDMSolverBuilder", "asm", "none"),
(1, "S", "DDMSolverBuilder", "ras", "none"),
(10, "S", "SolverBuilder", "none", "none"),
(10, "S", "SolverBuilder", "asm", "none"),
(10, "S", "SolverBuilder", "ras", "none"),
(10, "S", "DDMSolverBuilderAddingOverlap", "asm", "none"),
(10, "S", "DDMSolverBuilderAddingOverlap", "ras", "none"),
(1, "S", "DDMSolverBuilderAddingOverlap", "asm", "additive"),
(1, "S", "DDMSolverBuilderAddingOverlap", "ras", "additive"),
(10, "S", "DDMSolverBuilderAddingOverlap", "asm", "additive"),
(10, "S", "DDMSolverBuilderAddingOverlap", "ras", "additive"),
(10, "S", "DDMSolverBuilder", "asm", "none"),
(10, "S", "DDMSolverBuilder", "ras", "none"),
(1, "S", "DDMSolverBuilder", "asm", "additive"),
(1, "S", "DDMSolverBuilder", "ras", "additive"),
(10, "S", "DDMSolverBuilder", "asm", "additive"),
(10, "S", "DDMSolverBuilder", "ras", "additive"),
],
# indirect=["setup_solver_dependencies"],
)
Expand All @@ -50,7 +63,7 @@ def test_ddm_solver(
epsilon,
eta,
mu,
overlap,
ddm_builder,
symmetry,
tol,
hpddm_schwarz_method,
Expand All @@ -69,6 +82,7 @@ def test_ddm_solver(
A,
x_ref,
f,
geometry,
cluster,
neighbors,
intersections,
Expand All @@ -90,16 +104,20 @@ def test_ddm_solver(
UPLO,
mpi4py.MPI.COMM_WORLD,
)

# print("Geometry", geometry)
# fig = plt.figure()
# ax = fig.add_subplot(1, 1, 1, projection="3d")
# ax.scatter(geometry[0, :], geometry[1, :], geometry[2, :], marker="o")
# plt.show()
solver = None
if not overlap:
if ddm_builder == "SolverBuilder":
default_solver_builder = Htool.ComplexDefaultSolverBuilder(
default_approximation.distributed_operator,
default_approximation.block_diagonal_hmatrix,
)
solver = default_solver_builder.solver
else:
default_solver_builder = Htool.ComplexDefaultDDMSolverBuilder(
elif ddm_builder == "DDMSolverBuilderAddingOverlap":
default_solver_builder = Htool.ComplexDefaultDDMSolverBuilderAddingOverlap(
default_approximation.distributed_operator,
default_approximation.block_diagonal_hmatrix,
generator,
Expand All @@ -109,6 +127,43 @@ def test_ddm_solver(
intersections,
)
solver = default_solver_builder.solver
elif ddm_builder == "DDMSolverBuilder":
local_numbering_builder = Htool.LocalNumberingBuilder(
ovr_subdomain_to_global,
cluster_to_ovr_subdomain,
intersections,
)
intersections = local_numbering_builder.intersections
local_to_global_numbering = local_numbering_builder.local_to_global_numbering
local_geometry = geometry[:, local_to_global_numbering]

local_cluster_builder = Htool.ClusterBuilder()

local_cluster: Htool.Cluster = local_cluster_builder.create_cluster_tree(
local_geometry, 2, 2
)

local_hmatrix_builder = Htool.ComplexHMatrixBuilder(
local_cluster,
local_cluster,
epsilon,
eta,
symmetry,
UPLO,
-1,
-1,
)
local_generator = LocalGeneratorFromMatrix(
local_cluster.get_permutation(), local_to_global_numbering, A
)
local_hmatrix = local_hmatrix_builder.build(local_generator)
default_solver_builder = Htool.ComplexDefaultDDMSolverBuilder(
default_approximation.distributed_operator,
local_hmatrix,
neighbors,
intersections,
)
solver = default_solver_builder.solver

distributed_operator = default_approximation.distributed_operator

Expand Down

0 comments on commit cd4ee5c

Please sign in to comment.