Skip to content

Commit

Permalink
Refactor function into Operator class for consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
JulianKarlBauer committed Feb 3, 2022
1 parent cc09d6b commit d25a7ff
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 27 deletions.
75 changes: 51 additions & 24 deletions mechkit/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,61 @@
##########################################################################


def sym(tensor, sym_axes=None):
class Abstract_Operator:
def __call__(self, *args, **kwargs):
return self.function(*args, **kwargs)

def check(self, tensor, *args, **kwargs):
"""Check whether `tensor` has the symmetry specified by Sym operator"""

return np.allclose(tensor, self.function(tensor=tensor, *args, **kwargs))


class Sym(Abstract_Operator):
"""
Symmetrize selected axes of tensor.
If no sym_axes are specified, all axes are symmetrized
Based on the `axes` argument of the class initiation,
the returned instance act as a symmetrization function,
which symmetrices a given tensor with respect to the
specified axes.
If `axes` is `None`, all axes of the tensor are symmetrized
"""
base_axis = np.array(range(len(tensor.shape)))

sym_axes = base_axis if sym_axes is None else sym_axes
def __init__(self, axes=None):
if axes is None:
self.function = self._sym_all_axes
else:
self.function = self._sym_selected_axes
self.axes = axes

def _sym_all_axes(self, tensor):
permutations = list(itertools.permutations(self._get_base_axes(tensor=tensor)))

perms = itertools.permutations(sym_axes)
return self._symmetrize(tensor=tensor, permutations=permutations)

axes = list()
for perm in perms:
axis = base_axis.copy()
axis[sym_axes] = perm
axes.append(axis)
def _sym_selected_axes(self, tensor):
tmp_permutations = itertools.permutations(self.axes)
base_axes = self._get_base_axes(tensor=tensor)

return 1.0 / len(axes) * sum(tensor.transpose(axis) for axis in axes)
permutations = []
for perm in tmp_permutations:
axis = base_axes.copy()
axis[self.axes] = perm
permutations.append(axis)

return self._symmetrize(tensor=tensor, permutations=permutations)

def is_sym(tensor, sym_axes=None):
"""Test whether `tensor` has the index symmetry specified by `sym_axes`"""
return np.allclose(tensor, sym(tensor, sym_axes=sym_axes))
def _get_base_axes(self, tensor):
return np.array(range(len(tensor.shape)))

def _symmetrize(self, tensor, permutations):
return (
1.0
/ len(permutations)
* sum(tensor.transpose(perm) for perm in permutations)
)

class Sym_Fourth_Order_Special(object):

class Sym_Fourth_Order_Special(Abstract_Operator):
"""
Based on the `label` argument of the class initiation,
the returned instance act as a symmetrization function,
Expand All @@ -59,12 +88,6 @@ def __init__(self, label=None):
else:
raise utils.Ex("Please specify a valid symmetry label")

def __call__(self, *args, **kwargs):
return self.function(*args, **kwargs)

def check(self, tensor, *args, **kwargs):
return np.allclose(tensor, self.function(tensor=tensor, *args, **kwargs))

def _set_permutation_lists(self):
base_permutations = {
"identity": (0, 1, 2, 3),
Expand Down Expand Up @@ -142,6 +165,8 @@ def dev_tensor_4th_order_simple(tensor):
"Tensor of fourth order" " has to be in tensor notation"
)

sym = Sym()

tensor_4 = sym(tensor)
tensor_2 = np.einsum("ppij->ij", tensor_4)
trace = np.einsum("ii->", tensor_2)
Expand Down Expand Up @@ -182,11 +207,13 @@ def dev_t4_kanatani1984(self, tensor):
3,
), "Requires tensor 4.order tensor notation"

assert is_sym(tensor), "Only valid for completely symmetric tensor"
assert Sym().check(tensor), "Only valid for completely symmetric tensor"
assert np.isclose(
np.einsum("iijj->", tensor), 1.0
), "Only valid for completely symmetric tensor with complete trace is one"

sym = Sym()

I2 = np.eye(3, dtype="float64")
dev = (
tensor
Expand All @@ -204,7 +231,7 @@ def dev_t4_spencer1970(self, tensor):
"Requires tensor 4.order in " "tensor notation"
)

tensor_sym = sym(tensor)
tensor_sym = Sym()(tensor)

I2 = np.eye(3, dtype="float64")
a2 = np.einsum("ppij->ij", tensor_sym)
Expand Down
50 changes: 47 additions & 3 deletions test/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import mechkit
from mechkit.operators import Sym_Fourth_Order_Special
import pytest
import itertools

con = mechkit.notation.Converter()

Expand Down Expand Up @@ -69,6 +70,18 @@ def has_sym_inner(A):
assert np.isclose(A[iii, jjj, kkk, lll], A[kkk, lll, iii, jjj])


def has_sym_complete(A):
for iii in range(3):
for jjj in range(3):
for kkk in range(3):
for lll in range(3):
print(iii, jjj, kkk, lll)
assert np.isclose(A[iii, jjj, kkk, lll], A[jjj, iii, kkk, lll])
assert np.isclose(A[iii, jjj, kkk, lll], A[iii, jjj, lll, kkk])
assert np.isclose(A[iii, jjj, kkk, lll], A[kkk, lll, iii, jjj])
assert np.isclose(A[iii, jjj, kkk, lll], A[lll, kkk, jjj, iii])


class Test_Sym_Fourth_Order_Special:
def test_check_sym_by_loop_left(self, tensor4):
t_sym = Sym_Fourth_Order_Special(label="left")(tensor4)
Expand Down Expand Up @@ -101,6 +114,37 @@ def test_check_sym_by_loop_inner(self, tensor4):
has_sym_inner(t_sym)


def test_check_sym_by_loop_complete(tensor4):
t_sym = mechkit.operators.Sym()(tensor4)
pprint.pprint(con.to_mandel9(t_sym))
has_sym_complete(t_sym)


def test_check_sym_by_loop_alternative_implementation(tensor4):
t_sym = mechkit.operators.Sym()(tensor4)

def sym(tensor, sym_axes=None):
"""
Symmetrize selected axes of tensor.
If no sym_axes are specified, all axes are symmetrized
"""
base_axis = np.array(range(len(tensor.shape)))

sym_axes = base_axis if sym_axes is None else sym_axes

perms = itertools.permutations(sym_axes)

axes = list()
for perm in perms:
axis = base_axis.copy()
axis[sym_axes] = perm
axes.append(axis)

return 1.0 / len(axes) * sum(tensor.transpose(axis) for axis in axes)

np.allclose(sym(tensor4), t_sym)


def test_compare_sym_inner_inner_mandel(tensor4):
t_sym_inner = Sym_Fourth_Order_Special(label="inner")(tensor4)
t_sym_inner_mandel = Sym_Fourth_Order_Special(label="inner_mandel")(tensor4)
Expand All @@ -126,7 +170,7 @@ def test_sym_minor_mandel(tensor4):

def test_sym_axes_label_left(tensor4):
"""Two implementation should do the same job"""
sym_axes = mechkit.operators.sym(tensor4, sym_axes=[0, 1])
sym_axes = mechkit.operators.Sym(axes=[0, 1])(tensor4)
sym_label = Sym_Fourth_Order_Special(label="left")(tensor4)

print(sym_axes)
Expand All @@ -137,7 +181,7 @@ def test_sym_axes_label_left(tensor4):

def test_sym_axes_label_right(tensor4):
"""Two implementation should do the same job"""
sym_axes = mechkit.operators.sym(tensor4, sym_axes=[2, 3])
sym_axes = mechkit.operators.Sym(axes=[2, 3])(tensor4)
sym_label = Sym_Fourth_Order_Special(label="right")(tensor4)

print(sym_axes)
Expand Down Expand Up @@ -170,7 +214,7 @@ def test_deviator_part_of_4_tensor_is_deviator(self, tensors):
for key in ["hooke", "complete"]:
deviator = alternatives.dev_t4_spencer1970(tensors[key])
# Is symmetric
sym = mechkit.operators.sym
sym = mechkit.operators.Sym()
assert np.allclose(deviator, sym(deviator))
# Is traceless (as it is already symmetric, one trace tested is sufficient)
assert np.allclose(np.einsum("ijkk->ij", deviator), np.zeros((3, 3)))
Expand Down

0 comments on commit d25a7ff

Please sign in to comment.