From d25a7ffaeecd93bcebf9f010b862142666e4f053 Mon Sep 17 00:00:00 2001 From: JulianKarlBauer Date: Thu, 3 Feb 2022 16:45:33 +0100 Subject: [PATCH] Refactor function into Operator class for consistency --- mechkit/operators.py | 75 ++++++++++++++++++++++++++++-------------- test/test_operators.py | 50 ++++++++++++++++++++++++++-- 2 files changed, 98 insertions(+), 27 deletions(-) diff --git a/mechkit/operators.py b/mechkit/operators.py index 1a5185e..26b2a2a 100644 --- a/mechkit/operators.py +++ b/mechkit/operators.py @@ -8,32 +8,61 @@ ########################################################################## -def sym(tensor, sym_axes=None): +class Abstract_Operator: + def __call__(self, *args, **kwargs): + return self.function(*args, **kwargs) + + def check(self, tensor, *args, **kwargs): + """Check whether `tensor` has the symmetry specified by Sym operator""" + + return np.allclose(tensor, self.function(tensor=tensor, *args, **kwargs)) + + +class Sym(Abstract_Operator): """ - Symmetrize selected axes of tensor. - If no sym_axes are specified, all axes are symmetrized + Based on the `axes` argument of the class initiation, + the returned instance act as a symmetrization function, + which symmetrices a given tensor with respect to the + specified axes. + If `axes` is `None`, all axes of the tensor are symmetrized """ - base_axis = np.array(range(len(tensor.shape))) - sym_axes = base_axis if sym_axes is None else sym_axes + def __init__(self, axes=None): + if axes is None: + self.function = self._sym_all_axes + else: + self.function = self._sym_selected_axes + self.axes = axes + + def _sym_all_axes(self, tensor): + permutations = list(itertools.permutations(self._get_base_axes(tensor=tensor))) - perms = itertools.permutations(sym_axes) + return self._symmetrize(tensor=tensor, permutations=permutations) - axes = list() - for perm in perms: - axis = base_axis.copy() - axis[sym_axes] = perm - axes.append(axis) + def _sym_selected_axes(self, tensor): + tmp_permutations = itertools.permutations(self.axes) + base_axes = self._get_base_axes(tensor=tensor) - return 1.0 / len(axes) * sum(tensor.transpose(axis) for axis in axes) + permutations = [] + for perm in tmp_permutations: + axis = base_axes.copy() + axis[self.axes] = perm + permutations.append(axis) + return self._symmetrize(tensor=tensor, permutations=permutations) -def is_sym(tensor, sym_axes=None): - """Test whether `tensor` has the index symmetry specified by `sym_axes`""" - return np.allclose(tensor, sym(tensor, sym_axes=sym_axes)) + def _get_base_axes(self, tensor): + return np.array(range(len(tensor.shape))) + def _symmetrize(self, tensor, permutations): + return ( + 1.0 + / len(permutations) + * sum(tensor.transpose(perm) for perm in permutations) + ) -class Sym_Fourth_Order_Special(object): + +class Sym_Fourth_Order_Special(Abstract_Operator): """ Based on the `label` argument of the class initiation, the returned instance act as a symmetrization function, @@ -59,12 +88,6 @@ def __init__(self, label=None): else: raise utils.Ex("Please specify a valid symmetry label") - def __call__(self, *args, **kwargs): - return self.function(*args, **kwargs) - - def check(self, tensor, *args, **kwargs): - return np.allclose(tensor, self.function(tensor=tensor, *args, **kwargs)) - def _set_permutation_lists(self): base_permutations = { "identity": (0, 1, 2, 3), @@ -142,6 +165,8 @@ def dev_tensor_4th_order_simple(tensor): "Tensor of fourth order" " has to be in tensor notation" ) + sym = Sym() + tensor_4 = sym(tensor) tensor_2 = np.einsum("ppij->ij", tensor_4) trace = np.einsum("ii->", tensor_2) @@ -182,11 +207,13 @@ def dev_t4_kanatani1984(self, tensor): 3, ), "Requires tensor 4.order tensor notation" - assert is_sym(tensor), "Only valid for completely symmetric tensor" + assert Sym().check(tensor), "Only valid for completely symmetric tensor" assert np.isclose( np.einsum("iijj->", tensor), 1.0 ), "Only valid for completely symmetric tensor with complete trace is one" + sym = Sym() + I2 = np.eye(3, dtype="float64") dev = ( tensor @@ -204,7 +231,7 @@ def dev_t4_spencer1970(self, tensor): "Requires tensor 4.order in " "tensor notation" ) - tensor_sym = sym(tensor) + tensor_sym = Sym()(tensor) I2 = np.eye(3, dtype="float64") a2 = np.einsum("ppij->ij", tensor_sym) diff --git a/test/test_operators.py b/test/test_operators.py index 393a841..c5ca36f 100644 --- a/test/test_operators.py +++ b/test/test_operators.py @@ -6,6 +6,7 @@ import mechkit from mechkit.operators import Sym_Fourth_Order_Special import pytest +import itertools con = mechkit.notation.Converter() @@ -69,6 +70,18 @@ def has_sym_inner(A): assert np.isclose(A[iii, jjj, kkk, lll], A[kkk, lll, iii, jjj]) +def has_sym_complete(A): + for iii in range(3): + for jjj in range(3): + for kkk in range(3): + for lll in range(3): + print(iii, jjj, kkk, lll) + assert np.isclose(A[iii, jjj, kkk, lll], A[jjj, iii, kkk, lll]) + assert np.isclose(A[iii, jjj, kkk, lll], A[iii, jjj, lll, kkk]) + assert np.isclose(A[iii, jjj, kkk, lll], A[kkk, lll, iii, jjj]) + assert np.isclose(A[iii, jjj, kkk, lll], A[lll, kkk, jjj, iii]) + + class Test_Sym_Fourth_Order_Special: def test_check_sym_by_loop_left(self, tensor4): t_sym = Sym_Fourth_Order_Special(label="left")(tensor4) @@ -101,6 +114,37 @@ def test_check_sym_by_loop_inner(self, tensor4): has_sym_inner(t_sym) +def test_check_sym_by_loop_complete(tensor4): + t_sym = mechkit.operators.Sym()(tensor4) + pprint.pprint(con.to_mandel9(t_sym)) + has_sym_complete(t_sym) + + +def test_check_sym_by_loop_alternative_implementation(tensor4): + t_sym = mechkit.operators.Sym()(tensor4) + + def sym(tensor, sym_axes=None): + """ + Symmetrize selected axes of tensor. + If no sym_axes are specified, all axes are symmetrized + """ + base_axis = np.array(range(len(tensor.shape))) + + sym_axes = base_axis if sym_axes is None else sym_axes + + perms = itertools.permutations(sym_axes) + + axes = list() + for perm in perms: + axis = base_axis.copy() + axis[sym_axes] = perm + axes.append(axis) + + return 1.0 / len(axes) * sum(tensor.transpose(axis) for axis in axes) + + np.allclose(sym(tensor4), t_sym) + + def test_compare_sym_inner_inner_mandel(tensor4): t_sym_inner = Sym_Fourth_Order_Special(label="inner")(tensor4) t_sym_inner_mandel = Sym_Fourth_Order_Special(label="inner_mandel")(tensor4) @@ -126,7 +170,7 @@ def test_sym_minor_mandel(tensor4): def test_sym_axes_label_left(tensor4): """Two implementation should do the same job""" - sym_axes = mechkit.operators.sym(tensor4, sym_axes=[0, 1]) + sym_axes = mechkit.operators.Sym(axes=[0, 1])(tensor4) sym_label = Sym_Fourth_Order_Special(label="left")(tensor4) print(sym_axes) @@ -137,7 +181,7 @@ def test_sym_axes_label_left(tensor4): def test_sym_axes_label_right(tensor4): """Two implementation should do the same job""" - sym_axes = mechkit.operators.sym(tensor4, sym_axes=[2, 3]) + sym_axes = mechkit.operators.Sym(axes=[2, 3])(tensor4) sym_label = Sym_Fourth_Order_Special(label="right")(tensor4) print(sym_axes) @@ -170,7 +214,7 @@ def test_deviator_part_of_4_tensor_is_deviator(self, tensors): for key in ["hooke", "complete"]: deviator = alternatives.dev_t4_spencer1970(tensors[key]) # Is symmetric - sym = mechkit.operators.sym + sym = mechkit.operators.Sym() assert np.allclose(deviator, sym(deviator)) # Is traceless (as it is already symmetric, one trace tested is sufficient) assert np.allclose(np.einsum("ijkk->ij", deviator), np.zeros((3, 3)))