add a simple sgd optimizer and auto-grad usage #12

Open · wants to merge 7 commits into master
2 changes: 1 addition & 1 deletion README.md
@@ -1,5 +1,5 @@

### guide
- [mindspore installation guide](./installation/installation.md)
- mindspore API examples
- [mindspore API examples](./api-examples/README.md)
- [mindspore model examples](./model-examples/README.md)
14 changes: 8 additions & 6 deletions api-examples/README.md
@@ -1,8 +1,10 @@
### Contents:
| Category | Link |
|:--|:-- |
| ops | [operator](./operator) |
| nn | [neural network](./nn) |
| tensor | [tensor](./tensor) |
| runtime | [runtime](./runtime) |
| Category  | Link                                                |
|:----------|:----------------------------------------------------|
| ops       | [operator](./operator)                              |
| nn        | [neural network](./nn)                              |
| optimizer | [optimizer](./optimizer/README.md)                  |
| auto-grad | [automatic differentiation](./auto-grad/README.md)  |
| tensor    | [tensor](./tensor)                                  |
| runtime   | [runtime](./runtime)                                |

6 changes: 6 additions & 0 deletions api-examples/auto-grad/README.md
@@ -0,0 +1,6 @@
# Automatic differentiation APIs

| Example | Description | MindSpore version |
|:--------------------------------------------------------------------------------------------------------|:------------|:----------------|
| [Taking gradients of a function with mindspore](./taking_gradients_of_function_with_mindspore_grad.md)   | -           | MindSpore 2.4.0 |
| [Taking gradients of a model with mindspore](./taking_gradients_of_model_with_mindspore_grad.md)         | -           | MindSpore 2.4.0 |
@@ -0,0 +1,153 @@
# Taking gradients of a function with mindspore

> Note: the examples below use MindSpore 2.4.0

## Contents

- [Differentiating a function with `mindspore.grad`](#section1)
- [Computing the gradients of a linear logistic regression with `mindspore.grad`](#section2)
- [Computing gradients and loss with `mindspore.value_and_grad`](#section3)

<br>

<a id="section1"></a>
## Differentiating a function with `mindspore.grad`

The derivatives of `f(x) = x^3 - 2x + 1` can be computed as:

$$
\begin{aligned}
& f(x) = x^3 - 2x + 1 \\
& f'(x) = 3x^2 - 2 \\
& f''(x) = 6x
\end{aligned}
$$

With MindSpore this can be expressed simply as:

```python
from mindspore import grad, Tensor

f = lambda x: x**3 - 2*x + 1

dfdx = grad(f)
d2fdx = grad(grad(f))
```

Verifying the above at `x = 1` gives:

$$
\begin{aligned}
& f(1) = 0 \\
& f'(1) = 1 \\
& f''(1) = 6
\end{aligned}
$$

Run this in MindSpore:

```python
print(f(Tensor(1.0)))
print(dfdx(Tensor(1.0)))
print(d2fdx(Tensor(1.0)))
```
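
As a quick sanity check (an illustrative addition, not part of the original example), the analytic derivatives can be compared against a central finite-difference approximation; `numerical_grad` below is a hypothetical helper:

```python
from mindspore import Tensor

# Central finite-difference approximation of fn'(x)
def numerical_grad(fn, x, eps=1e-3):
    return (fn(Tensor(x + eps)) - fn(Tensor(x - eps))) / (2 * eps)

print(numerical_grad(f, 1.0))     # should be close to dfdx(Tensor(1.0)), i.e. 1
print(numerical_grad(dfdx, 1.0))  # should be close to d2fdx(Tensor(1.0)), i.e. 6
```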

<a id="section2"></a>
## Computing the gradients of a linear logistic regression with `mindspore.grad`

First, we set up the following definitions:

```python
from mindspore import ops, grad, Tensor, set_seed

set_seed(0)

inputs = ops.randn((4, 3))
targets = Tensor([True, True, False, True])

W, b = ops.randn((3,)), ops.randn((1,))

def loss(inputs, targets, W, b):
preds = ops.sigmoid(ops.matmul(inputs, W) + b)
logit_loss = preds * targets + (1 - preds) * (1 - targets)
return -ops.log(logit_loss).sum()

print(f"inputs: {inputs}")
print(f"targets: {targets}")
print(f"W: {W}")
print(f"b: {b}")
```

Compute the gradient with respect to each of the inputs (including `W` and `b`) separately:

```python
x_grad = grad(loss, grad_position=0)(inputs, targets, W, b)
print(f'x_grad: {x_grad}')

y_grad = grad(loss, 1)(inputs, targets, W, b)
print(f'y_grad: {y_grad}')

W_grad = grad(loss, 2)(inputs, targets, W, b)
print(f'W_grad: {W_grad}')

b_grad = grad(loss, 3)(inputs, targets, W, b)
print(f'b_grad: {b_grad}')
```

Of course, the needed gradients can also be computed in a single call:

```python
(W_grad, b_grad) = grad(loss, (2, 3))(inputs, targets, W, b)
print(f'W_grad: {W_grad}')
print(f'b_grad: {b_grad}')
```
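
As an illustration of how these gradients would typically be used (a sketch that is not part of the original example; the learning rate is an arbitrary assumption), a single gradient-descent step on `W` and `b` looks like this:

```python
lr = 0.1  # assumed learning rate, for illustration only

W_new = W - lr * W_grad
b_new = b - lr * b_grad

# The loss should normally drop after one step in the negative-gradient direction
print(f"loss before: {loss(inputs, targets, W, b)}")
print(f"loss after:  {loss(inputs, targets, W_new, b_new)}")
```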

If the function returns multiple losses, the computed gradients are the derivatives of all the losses with respect to the inputs:

```python
def multi_loss(inputs, targets, W, b):
loss1 = loss(inputs, targets, W, b)
loss2 = (W ** 2).sum()
return loss1, loss2

(W_grad, b_grad) = grad(multi_loss, (2, 3))(inputs, targets, W, b)
print(f'W_grad: {W_grad}, b_grad: {b_grad}')
```

If you only want the gradient of `loss1`, you can cut `loss2` out of the computation with `ops.stop_gradient`:

```python
def multi_loss_2(inputs, targets, W, b):
loss1 = loss(inputs, targets, W, b)
loss2 = (W ** 2).sum()
return loss1, ops.stop_gradient(loss2)

(W_grad, b_grad) = grad(multi_loss_2, (2, 3))(inputs, targets, W, b)
print(f'W_grad: {W_grad}, b_grad: {b_grad}')
```
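
As a quick check that `loss2` no longer contributes (an illustrative addition, not in the original text), the result can be compared with the gradient of `loss` alone:

```python
# Reference gradients from loss1 only; they should match the stop_gradient version above
(W_grad_ref, b_grad_ref) = grad(loss, (2, 3))(inputs, targets, W, b)

print(f"W_grad matches: {(W_grad == W_grad_ref).all()}")
print(f"b_grad matches: {(b_grad == b_grad_ref).all()}")
```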

Alternatively, setting `has_aux=True` excludes every output except the first from the gradient computation:

```python
# With has_aux=True, the auxiliary outputs are returned alongside the gradients
(W_grad, b_grad), aux_loss = grad(multi_loss, (2, 3), has_aux=True)(inputs, targets, W, b)
print(f'W_grad: {W_grad}, b_grad: {b_grad}')
```

For more details on the `mindspore.grad` API, see the [MindSpore Docs](https://www.mindspore.cn/docs/zh-CN/r2.4.0/api_python/mindspore/mindspore.grad.html)


<a id="section3"></a>
## Computing gradients and loss with `mindspore.value_and_grad`

Compute the gradients of the linear logistic-regression function and obtain the `loss` at the same time:

```python
from mindspore import value_and_grad

# Bind the value to a new name so the `loss` function itself is not shadowed
loss_value, (W_grad, b_grad) = value_and_grad(loss, (2, 3))(inputs, targets, W, b)

print(f"loss: {loss_value}")
print(f'W_grad: {W_grad}, b_grad: {b_grad}')
```
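
Putting this together, `value_and_grad` makes it easy to watch the loss while updating the parameters. The loop below is only a sketch; the learning rate and step count are assumptions, not values from the original example:

```python
lr, steps = 0.1, 20  # assumed hyperparameters, for illustration only
loss_and_grad_fn = value_and_grad(loss, (2, 3))

W_t, b_t = W, b  # work on copies so the original tensors stay untouched
for step in range(steps):
    loss_value, (W_grad, b_grad) = loss_and_grad_fn(inputs, targets, W_t, b_t)
    # Plain gradient-descent update on the current copies
    W_t, b_t = W_t - lr * W_grad, b_t - lr * b_grad
    if step % 5 == 0:
        print(f"step {step}: loss {loss_value}")
```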

For more details on the `mindspore.value_and_grad` API, see the [MindSpore Docs](https://www.mindspore.cn/docs/zh-CN/r2.4.0/api_python/mindspore/mindspore.value_and_grad.html)
@@ -0,0 +1,75 @@
# Taking gradients of a model with mindspore

> Note: the examples below use MindSpore 2.4.0

## Computing a model's gradients with respect to its weights

First, define a simple model:

```python
from mindspore import nn, Parameter
import numpy as np

np.random.seed(0)

class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.w = Parameter(np.random.randn(4, 4))
        self.b = Parameter(np.random.randn(4,))

    def construct(self, x, y):
        # x: (batch, 4) @ w: (4, 4) -> (batch, 4), matching the (2, 4) input used below
        pred = x @ self.w + self.b
        loss = (pred - y) ** 2
        return loss.sum()

net = Net()

print(f"net: \n{net}")
# get_parameters() returns a generator, so materialize it for printing
print(f"parameters of net: \n{list(net.get_parameters())}")
```

Use the `mindspore.grad` API to compute the model's gradients with respect to its weights.

> Note: gradients with respect to the inputs are not needed here, so `grad_position=None` is set

```python
from mindspore import grad, Tensor
import numpy as np

x, y = Tensor(np.random.randn(2, 4)), Tensor([1.0])

# trainable_params() returns a list of Parameters (get_parameters() yields a generator)
grad_fn = grad(net, grad_position=None, weights=net.trainable_params())

grads = grad_fn(x, y)

print(f"grads: \n{grads}")
```
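
Each returned gradient lines up positionally with one trainable parameter. A small illustrative addition (not in the original) to show the pairing:

```python
# grads comes back in the same order as net.trainable_params()
for p, g in zip(net.trainable_params(), grads):
    print(f"{p.name}: grad shape {g.shape}")
```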

Of course, we can also use the `mindspore.value_and_grad` API to compute the weight gradients and obtain the `loss` at the same time:

```python
from mindspore import value_and_grad

grad_fn = value_and_grad(net, grad_position=None, weights=net.trainable_params())

loss, grads = grad_fn(x, y)

print(f"loss: {loss}")
print(f"grads: \n{grads}")
```

Accelerate with `JIT (Just-In-Time)` compilation:

```python
from mindspore import jit

# Wrap the gradient call in a jitted function so it runs in graph mode
@jit
def loss_and_grads(x, y):
    return grad_fn(x, y)

loss, grads = loss_and_grads(x, y)

print(f"loss: {loss}")
print(f"grads: \n{grads}")
```
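
To see whether JIT compilation actually helps on a given backend, a rough timing comparison can be done along these lines. This is an illustrative sketch: `time_it` is a hypothetical helper, and the results depend entirely on hardware and backend:

```python
import time

def time_it(fn, *args, repeat=100):
    fn(*args)  # warm-up call, which also triggers compilation for the jitted version
    start = time.perf_counter()
    for _ in range(repeat):
        fn(*args)
    return (time.perf_counter() - start) / repeat

print(f"eager : {time_it(grad_fn, x, y):.6f} s/call")
print(f"jitted: {time_it(loss_and_grads, x, y):.6f} s/call")
```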
6 changes: 6 additions & 0 deletions api-examples/optimizer/README.md
@@ -0,0 +1,6 @@
# Optimizer


| Example | Description | MindSpore version |
|:-----------------------------------------------------------------------|:------------|:----------------|
| [Write a simple sgd optimizer from scratch](./create_a_simple_sgd.md)   | -           | MindSpore 2.4.0 |
113 changes: 113 additions & 0 deletions api-examples/optimizer/create_a_simple_sgd.md
@@ -0,0 +1,113 @@
# Write a simple sgd optimizer from scratch

> Note: the examples below use MindSpore 2.4.0

## Contents

- [Create an sgd optimizer from scratch](#section1)
- [Acceleration](#section2)


<a id="section1"></a>
## Create an `sgd` optimizer from scratch

- Step 1: define the `sgd` update procedure

```python
from mindspore import ops

def sgd_update(weights, grads, lr):
    # Vanilla SGD: w <- w - lr * dw, applied in place to each Parameter
    for w, dw in zip(weights, grads):
        ops.assign(w, w - lr * dw)
```

- Step 2: update the model parameters

> For the definition of the `Net` model, see [this example](../../model-examples/create_a_simple_nerual_network.md)

```python
import mindspore
from mindspore import Tensor
import numpy as np

net = Net()
grad_fn = mindspore.value_and_grad(net, None, net.trainable_params(), has_aux=False)

def train():
    total_step = 10
    x, y = Tensor(np.random.randn(1, 1, 16, 16)), Tensor(np.ones(1))

    for current_step in range(total_step):
        loss, grads = grad_fn(x, y)
        sgd_update(net.trainable_params(), grads, 0.01)
        # Convert the scalar loss Tensor to a Python float for formatting
        print(f"loss: {float(loss):>7f} [{current_step + 1:>5d}/{total_step:>5d}]")
```

- Of course, the `sgd` optimizer can also be defined as a class; a short usage sketch follows the definition below

```python
from mindspore import nn, ops

class SGD(nn.Cell):
    def __init__(self, weights, lr):
        super(SGD, self).__init__()
        self.weights = weights
        self.lr = lr

    def construct(self, grads):
        # Same in-place update as sgd_update above
        for w, dw in zip(self.weights, grads):
            ops.assign(w, w - self.lr * dw)
```
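
A minimal usage sketch, assuming `net`, `grad_fn`, `x` and `y` from the training snippet above; the parameters are wrapped in a `ParameterTuple`, the container MindSpore optimizers conventionally use for their weights:

```python
from mindspore import ParameterTuple

optimizer = SGD(ParameterTuple(net.trainable_params()), lr=0.01)

loss, grads = grad_fn(x, y)
optimizer(grads)  # applies one in-place SGD update to the weights
```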

<a id="section2"></a>
## Acceleration

- Accelerate with `JIT (Just-In-Time)` compilation

```python
import mindspore
from mindspore import nn, ops

class SGD(nn.Cell):
    def __init__(self, weights, lr):
        super(SGD, self).__init__()
        self.weights = weights
        self.lr = lr

    @mindspore.jit
    def construct(self, grads):
        for w, dw in zip(self.weights, grads):
            ops.assign(w, w - self.lr * dw)
```

- Replace the `for` loop with the `mindspore.ops.HyperMap` operation; a usage sketch follows the code below

```python
import mindspore
from mindspore import nn, ops

sgd_update = ops.MultitypeFuncGraph("_sgd_update")

@sgd_update.register("Tensor", "Tensor", "Tensor")
def run_sgd_update(lr, grad, weight):
    """Apply the sgd update to one weight parameter."""
    success = True
    # depend ties the assign into the dataflow so the update is not pruned in graph mode
    success = ops.depend(success, ops.assign(weight, weight - lr * grad))
    return success

class SGD(nn.Cell):
def __init__(self, weights, lr):
super(SGD, self).__init__()
self.weights = weights
self.lr = lr
self.hyper_map = ops.HyperMap()

@mindspore.jit
def construct(self, grads):
return self.hyper_map(
ops.partial(sgd_update, self.lr),
self.weights,
grads
)
```
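
This version drops in for `sgd_update` in the earlier `train()` loop. The snippet below is a sketch under the same assumptions (`net`, `grad_fn`, `x`, `y` defined as in the training snippet); keeping the weights in a `ParameterTuple` means `HyperMap` receives two tuple inputs of matching length:

```python
from mindspore import ParameterTuple

optimizer = SGD(ParameterTuple(net.trainable_params()), lr=0.01)

for current_step in range(10):
    loss, grads = grad_fn(x, y)
    optimizer(grads)  # one HyperMap-driven SGD update across all weights
    print(f"loss: {float(loss):>7f}")
```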