Skip to content

Commit

Permalink
Flatten fix (#178)
Browse files Browse the repository at this point in the history
* Inclusion of interpolation option to rescale_values function

* Added tests for rescale_values function with interpolation

* file: includes the possibility of unzipping only the files contained in a list

* array: fixes flatten and adds the option to return original shapes
- return_shapes is necessary if flattened array should later be reshaped to original shape

* Update _array.py

array: adicionada conversão de dicionário para tupla dentro de `flatten`

* cereja: updated version

Co-authored-by: Denny <[email protected]>
Co-authored-by: Joab Leite S. Neto <[email protected]>
  • Loading branch information
3 people authored May 19, 2022
1 parent 1d9d471 commit 423e609
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 9 deletions.
2 changes: 1 addition & 1 deletion cereja/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
from ._requests import request


VERSION = "1.7.2.final.0"
VERSION = "1.7.3.final.0"


__version__ = get_version_pep440_compliant(VERSION)
Expand Down
40 changes: 32 additions & 8 deletions cereja/array/_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,11 @@ def array_gen(shape: Tuple[int, ...], v: Union[Sequence[Any], Any] = None) -> Li
return v[0]


def flatten(sequence: Union[Sequence[Any], 'Matrix'], depth: Optional[int] = -1, **kwargs) -> Union[List[Any], Any]:
def flatten(sequence: Union[Sequence[Any], 'Matrix'], depth: Optional[int] = -1, return_shapes=False, **kwargs) -> Union[List[Any], Any]:
"""
Receives values, whether arrays of values, regardless of their shape and flatness
:param return_shapes: should return the shapes of the original values?
:param sequence: Is sequence of values.
:param depth: allows you to control a max depth, for example if you send a
sequence=[1,2, [[3]]] and depth=1 your return will be [1, 2, [3]].
Expand Down Expand Up @@ -180,21 +181,35 @@ def flatten(sequence: Union[Sequence[Any], 'Matrix'], depth: Optional[int] = -1,

flattened = []
i = 0
deep = 0
jump = len(sequence)
deep = 0
deep_counter = {deep: jump}
shapes = {deep: [jump]}
while i < len(sequence):
element = sequence[i]
if is_sequence(element) and (depth == -1 or depth > deep):
jump = len(element)
deep += 1
deep_counter[deep] = deep_counter.get(deep, 0) + jump
shapes[deep] = shapes.get(deep, []) + [jump]
sequence = list(element) + list(sequence[i + 1:])
if jump == 0:
deep -= 1
i = 0
else:
flattened.append(element)
deep_counter[deep] -= 1
i += 1
if i >= jump:
deep -= 1
jump = len(sequence)
if i >= jump:
for d in range(deep, 0, -1):
if deep_counter[d] == 0:
deep_counter[d-1] -= 1
deep -= 1
else:
break

if return_shapes:
return flattened, shapes
return flattened


Expand Down Expand Up @@ -403,12 +418,18 @@ def __matmul__(self, other):

def __add__(self, other):
assert self.shape == get_shape(other), "the shape must be the same"
return Matrix([list(map(sum, zip(*t))) for t in zip(self, other)])
if len(self.shape) == 1:
return Matrix([sum(t) for t in zip(self, other)])
else:
return Matrix([list(map(sum, zip(*t))) for t in zip(self, other)])

def __sub__(self, other):
if is_numeric_sequence(other):
assert self.shape == get_shape(other), "the shape must be the same"
return Matrix([list(map(sub, zip(*t))) for t in zip(self, other)])
if len(self.shape) == 1:
return Matrix([sub(t) for t in zip(self, other)])
else:
return Matrix([list(map(sub, zip(*t))) for t in zip(self, other)])
return Matrix(array_gen(self.shape, list(map(lambda x: x - other, self.flatten()))))

def __mul__(self, other):
Expand All @@ -421,7 +442,10 @@ def __truediv__(self, other):
if isinstance(other, (float, int)):
other = Matrix(array_gen(self.shape, other))
assert self.shape == get_shape(other), "the shape must be the same"
result = Matrix([list(map(div, zip(*t))) for t in zip(self, other)])
if len(self.shape) == 1:
result = Matrix([div(t) for t in zip(self, other)])
else:
result = Matrix([list(map(div, zip(*t))) for t in zip(self, other)])
assert self.shape == result.shape, "the shape must be the same"
return result

Expand Down

0 comments on commit 423e609

Please sign in to comment.