Commit

gemm doesn't improve performance that much, numpy is really well optimized
CPerezRuiz335 committed Jun 25, 2023
1 parent 6cbc896 commit 224b011
Showing 3 changed files with 14 additions and 106 deletions.
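The motivation is easy to reproduce: for single-precision matrix products, NumPy's @ already dispatches to BLAS, so calling sgemm directly buys little. A rough benchmark sketch; sizes, repeat counts and results depend on the local BLAS build and are only illustrative:

# Rough comparison of numpy matmul vs. a direct BLAS sgemm call.
# Timings are machine- and BLAS-dependent; this is only illustrative.
import timeit

import numpy as np
from scipy.linalg.blas import sgemm

a = np.random.rand(1024, 1024).astype(np.float32)
b = np.random.rand(1024, 1024).astype(np.float32)

t_matmul = timeit.timeit(lambda: a @ b, number=50)
t_sgemm = timeit.timeit(lambda: sgemm(1.0, a, b), number=50)
print(f"a @ b: {t_matmul:.3f}s    sgemm: {t_sgemm:.3f}s")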
79 changes: 0 additions & 79 deletions giagrad/mathops.py
@@ -4,7 +4,6 @@
 from typing import Any, Tuple, Optional, Union
 from giagrad.tensor import Function
 from itertools import zip_longest
-from scipy.linalg.blas import sgemm, dgemm

 def collapse(partial: NDArray, p_shape: Tuple[int, ...]):
     reduce_axis = []
@@ -98,84 +97,6 @@ def backward(self, partial: NDArray):
         if p2.requires_grad:
             p2.grad += p1.data.T.dot(partial)

-class Gemm(Function):
-    def __init__(self, trans_a: bool = False, trans_b: bool = False):
-        super().__init__()
-        self.trans_a = trans_a
-        self.trans_b = trans_b
-
-    def forward(self, *tensors) -> NDArray:
-        self.save_for_backward(*tensors)
-
-        if len(tensors) == 3:
-            alpha, a, b = tensors
-            return sgemm(
-                alpha=alpha.data, a=a.data, b=b.data,
-                trans_a=self.trans_a, trans_b=self.trans_b
-            )
-        elif len(tensors) == 5:
-            alpha, a, b, beta, c = tensors
-            return sgemm(
-                alpha=alpha.data, a=a.data, b=b.data, beta=beta.data, c=c,
-                trans_a=self.trans_a, trans_b=self.trans_b
-            )
-        else:
-            raise ValueError()
-
-    def backward(self, partial: NDArray):
-        if len(self.parents) == 3:
-            alpha, a, b = self.parents
-        else:
-            alpha, a, b, beta, c = self.parents
-
-            if beta.requires_grad:
-                beta.grad += (partial * c).sum()
-
-            if c.requires_grad:
-                c.grad += (partial * beta)
-
-        if alpha.requires_grad:
-            alpha.grad += (
-                sgemm(1., a, b, trans_a=self.trans_a, trans_b=self.trans_b)
-                * partial
-            ).sum()
-
-
-        if a.requires_grad:
-            if not self.trans_a:
-                a.grad += sgemm(
-                    alpha.data, partial, b.data, trans_b=(not self.trans_b)
-                )
-            else:
-                a.grad += sgemm(
-                    alpha.data, b.data, partial,
-                    trans_a=self.trans_b, trans_b=True
-                )
-
-        if b.requires_grad:
-            if not self.trans_b:
-                b.grad += sgemm(
-                    alpha.data, a.data, partial, trans_a=(not self.trans_a)
-                )
-            else:
-                b.grad += sgemm(
-                    alpha.data, partial, a.data,
-                    trans_a=True, trans_b=self.trans_a
-                )
-
-
-class SGemm(Gemm):
-    _fun = sgemm
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-class DGemm(Gemm):
-    _fun = dgemm
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
 # ***** math functions (unary) *****
 class Pow(Function):
     def __init__(self):
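For reference, the removed backward implements the standard GEMM gradients. With Y = alpha * op(A) op(B) + beta * C and upstream gradient dL/dY, the untransposed case reduces to:

\[
\frac{\partial L}{\partial A} = \alpha\,\frac{\partial L}{\partial Y}\,B^{\top},
\qquad
\frac{\partial L}{\partial B} = \alpha\,A^{\top}\,\frac{\partial L}{\partial Y},
\qquad
\frac{\partial L}{\partial C} = \beta\,\frac{\partial L}{\partial Y},
\]
\[
\frac{\partial L}{\partial \alpha} = \sum \left(\frac{\partial L}{\partial Y} \odot A B\right),
\qquad
\frac{\partial L}{\partial \beta} = \sum \left(\frac{\partial L}{\partial Y} \odot C\right).
\]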
28 changes: 14 additions & 14 deletions giagrad/nn/layers/linear.py
@@ -61,44 +61,44 @@ class Linear(Module):
     @overload
     def __init__(
         self,
-        in_features: int,
+        in_features: Optional[int],
         out_features: int,
         bias: bool = True
     ):
         super().__init__()
-        self.w: Tensor
-        self.b: Optional[Tensor] = None
+        self.w: Optional[Tensor] = None  # uninitialized
+        self.b: Optional[Tensor] = None  # uninitialized

         self.bias = bias
         self.out_features = out_features
         self.in_features = in_features

     def __init_tensors(self, batch: int, in_features: int):
-        self.in_features = in_features
+        if self.in_features is None:
+            self.in_features = in_features

         k = 1 / sqrt(in_features)
         self.w = Tensor.empty(self.out_features, in_features, requires_grad=True)
         self.w.uniform(a=-k, b=k)

         if self.bias:
-            b = np.random.uniform(-k, k, size=(batch, self.out_features))
-            b = np.repeat(b, batch, axis=0)
-            self.b = Tensor(b, requires_grad=True)
+            self.b = Tensor.empty(self.out_features, requires_grad=True)
+            self.b.uniform(a=-k, b=k)

     def __call__(self, x: Tensor) -> Tensor:
-        if self.b is None:
-            self.__init_tensors(*x.shape)
+        if self.w is None:
+            self.__init_tensors(*x.shape[-2:])

         if self.bias:
-            return x.gemm(alpha=1., b=self.w, c=self.b, trans_b=True)
+            return x @ self.w.T + self.b
         else:
-            return x.gemm(alpha=1., b=self.w, trans_b=True)
+            return x @ self.w.T

     def __str__(self):
         return (
             "Layer("
-            + f"in_features={self.in_features}, " if self.in_features else ''
-            + f"out_features={self.out_features}, "
-            + f"bias={self.bias})"
+            + (f"in_features={self.in_features}, " if self.in_features else '')
+            + f"out_features={self.out_features}, "
+            + f"bias={self.bias})"
         )
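After this change the layer infers in_features lazily on the first call and the forward pass is a plain matmul. A minimal usage sketch; the import paths are assumptions, not something this diff confirms:

import numpy as np
from giagrad.tensor import Tensor
from giagrad.nn.layers.linear import Linear  # assumed import path

layer = Linear(in_features=None, out_features=8, bias=True)
x = Tensor(np.random.rand(32, 16))   # batch of 32 samples, 16 features each

# the first call infers in_features=16 from x.shape[-2:] and builds w (8, 16), b (8,)
y = layer(x)
print(y.shape)   # expected (32, 8), computed as x @ w.T + b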

13 changes: 0 additions & 13 deletions giagrad/tensor.py
@@ -671,19 +671,6 @@ def matmul(self, other) -> Tensor:
         """
         return self.__matmul__(other)

-    def gemm(
-        self, alpha, b, beta=1., c=None, trans_a=False, trans_b=False
-    ) -> Tensor:
-        alpha = Tensor(alpha) if not isinstance(alpha, Tensor) else alpha
-        b = Tensor(b) if not isinstance(b, Tensor) else b
-
-        if c == None:
-            return Tensor.comm(mops.Gemm(trans_a, trans_b), alpha, self, b)
-        else:
-            beta = Tensor(beta) if not isinstance(beta, Tensor) else beta
-            c = Tensor(c) if not isinstance(c, Tensor) else c
-            return Tensor.comm(mops.Gemm(trans_a, trans_b), alpha, self, b, beta, c)
-
     def div(self, other) -> Tensor:
         """
         Returns a new tensor with the division of `data` to ``other``.
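With the helper gone, the same computation can be composed from the remaining operators. A short sketch, assuming Tensor supports scalar multiplication alongside the @ and + used elsewhere in this commit:

import numpy as np
from giagrad.tensor import Tensor

a = Tensor(np.random.rand(4, 3), requires_grad=True)
b = Tensor(np.random.rand(3, 5), requires_grad=True)
c = Tensor(np.random.rand(4, 5), requires_grad=True)

# roughly what a.gemm(alpha=2., b=b, beta=0.5, c=c) computed before its removal
y = 2.0 * (a @ b) + 0.5 * c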
