From 230c0665c8f0ea5074264a2aeb121b3b8988f526 Mon Sep 17 00:00:00 2001
From: zhanghuiyao <1814619459@qq.com>
Date: Mon, 13 Jan 2025 11:37:05 +0800
Subject: [PATCH 1/7] add simple sgd tutorial
---
README.md | 2 +-
api-examples/README.md | 13 +-
api-examples/optimizer/README.md | 0
api-examples/optimizer/sgd.md | 112 ++++++++++++++++++
 .../create_a_simple_neural_network.md         |  29 +++++
5 files changed, 149 insertions(+), 7 deletions(-)
create mode 100644 api-examples/optimizer/README.md
create mode 100644 api-examples/optimizer/sgd.md
create mode 100644 model-examples/create_a_simple_neural_network.md
diff --git a/README.md b/README.md
index 160f956..b569470 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,5 @@
### guide
- [mindspore installation guide](./installation/installation.md)
-- mindspore API examples
+- [mindspore API examples](./api-examples/README.md)
- [mindspore model examples](./model-examples/README.md)
diff --git a/api-examples/README.md b/api-examples/README.md
index 5aae743..2b15685 100644
--- a/api-examples/README.md
+++ b/api-examples/README.md
@@ -1,8 +1,9 @@
### Contents:
-| Category | Link |
-|:--|:-- |
-| ops | [operator](./operator) |
-| nn | [neural network](./nn) |
-| tensor | [tensor](./tensor) |
-| runtime | [runtime](./runtime) |
+| Category  | Link                               |
+|:----------|:-----------------------------------|
+| ops       | [operator](./operator)             |
+| nn        | [neural network](./nn)             |
+| optimizer | [optimizer](./optimizer/README.md) |
+| tensor    | [tensor](./tensor)                 |
+| runtime   | [runtime](./runtime)               |
diff --git a/api-examples/optimizer/README.md b/api-examples/optimizer/README.md
new file mode 100644
index 0000000..e69de29
diff --git a/api-examples/optimizer/sgd.md b/api-examples/optimizer/sgd.md
new file mode 100644
index 0000000..1348c1c
--- /dev/null
+++ b/api-examples/optimizer/sgd.md
@@ -0,0 +1,112 @@
+# SGD optimizer
+
+> Based on MindSpore 2.4.0
+
+## guide
+
+- [Create an SGD optimizer from scratch](#create-an-sgd-optimizer-from-scratch)
+- [Acceleration](#acceleration)
+
+## Create an SGD optimizer from scratch
+
+- Step 1: define the SGD update rule
+
+ ```python
+ from mindspore import ops
+
+ def sgd_update(weights, grads, lr):
+     # update each weight in place: w <- w - lr * dw
+     for w, dw in zip(weights, grads):
+         ops.assign(w, w - lr * dw)
+ ```
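+
+ In symbols, this is the plain SGD update applied to each weight $w$ with gradient $\nabla_w L$ and learning rate $\eta$:
+
+ $$
+ w \leftarrow w - \eta \, \nabla_w L
+ $$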
+
+- Step 2: update the model parameters
+
+ > For the definition of `Net`, see [this example](../../model-examples/create_a_simple_neural_network.md)
+
+ ```python
+ import mindspore
+ from mindspore import Tensor
+ import numpy as np
+
+ net = Net()
+ grad_fn = mindspore.value_and_grad(net, None, net.trainable_params(), has_aux=False)
+
+ def train():
+     total_step = 10
+     x, y = Tensor(np.random.randn(1, 1, 16, 16).astype(np.float32)), Tensor(np.ones(1).astype(np.int32))
+
+     for current_step in range(total_step):
+         loss, grads = grad_fn(x, y)
+         sgd_update(net.trainable_params(), grads, 0.01)
+         print(f"loss: {float(loss):>7f} [{current_step + 1:>5d}/{total_step:>5d}]")
+
+ train()
+ ```
+
+- Of course, we can also define the `sgd` optimizer as a class object; a short usage sketch follows the definition below
+
+ ```python
+ from mindspore import nn, ops
+
+ class SGD(nn.Cell):
+     def __init__(self, weights, lr):
+         super(SGD, self).__init__()
+         self.weights = weights
+         self.lr = lr
+
+     def construct(self, grads):
+         for w, dw in zip(self.weights, grads):
+             ops.assign(w, w - self.lr * dw)
+ ```
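+
+ A minimal usage sketch (assuming the `net` and `grad_fn` from step 2, and a batch `x, y` built the same way); calling the cell runs `construct` and updates the weights in place:
+
+ ```python
+ opt = SGD(net.trainable_params(), 0.01)
+
+ loss, grads = grad_fn(x, y)
+ opt(grads)  # one optimizer step
+ ```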
+
+## Acceleration
+
+- Accelerate with JIT (Just-In-Time) compilation
+
+ ```diff
+ import mindspore
+ from mindspore import nn, ops
+
+ class SGD(nn.Cell):
+     def __init__(self, weights, lr):
+         super(SGD, self).__init__()
+         self.weights = weights
+         self.lr = lr
+
+ +    @mindspore.jit
+     def construct(self, grads):
+         for w, dw in zip(self.weights, grads):
+             ops.assign(w, w - self.lr * dw)
+ ```
+
+- Replace the `for` loop with the `mindspore.ops.HyperMap` operation
+
+ ```diff
+ import mindspore
+ from mindspore import nn, ops
+
+ +sgd_update = ops.MultitypeFuncGraph("_sgd_update")
+ +
+ +@sgd_update.register("Tensor", "Tensor", "Tensor")
+ +def run_sgd_update(lr, weight, grad):
+ +    """Apply the sgd update to one weight parameter."""
+ +    success = True
+ +    success = ops.depend(success, ops.assign(weight, weight - lr * grad))
+ +    return success
+
+ class SGD(nn.Cell):
+     def __init__(self, weights, lr):
+         super(SGD, self).__init__()
+         self.weights = weights
+         self.lr = lr
+ +        self.hyper_map = ops.HyperMap()
+
+     @mindspore.jit
+     def construct(self, grads):
+ -        for w, dw in zip(self.weights, grads):
+ -            ops.assign(w, w - self.lr * dw)
+ +        return self.hyper_map(
+ +            ops.partial(sgd_update, self.lr),
+ +            self.weights,
+ +            grads
+ +        )
+ ```
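+
+ Note: compared with the Python-level `for` loop (which is unrolled at trace time), `ops.HyperMap` maps the registered computation over each `(weight, grad)` pair, so the graph compiler can treat the per-parameter updates as one mapped construct and schedule them more efficiently.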
diff --git a/model-examples/create_a_simple_neural_network.md b/model-examples/create_a_simple_neural_network.md
new file mode 100644
index 0000000..8a33bf8
--- /dev/null
+++ b/model-examples/create_a_simple_neural_network.md
@@ -0,0 +1,29 @@
+# Create a simple neural network
+
+```python
+from mindspore import nn, Tensor
+import numpy as np
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.mlp = nn.SequentialCell([
+            nn.Flatten(),  # flatten (N, 1, 16, 16) inputs to (N, 256) for the first Dense layer
+            nn.Dense(16*16, 128),
+            nn.ReLU(),
+            nn.Dense(128, 128),
+            nn.ReLU(),
+            nn.Dense(128, 10),
+        ])
+        self.loss_fn = nn.CrossEntropyLoss()
+
+    def construct(self, x, y):
+        x = self.mlp(x)
+        loss = self.loss_fn(x, y)
+        return loss
+
+net = Net()
+x, y = Tensor(np.random.randn(1, 1, 16, 16).astype(np.float32)), Tensor(np.ones(1).astype(np.int32))
+
+print(net)
+print(net(x, y))
+```
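+
+The second print shows the scalar cross-entropy loss for the random batch; the exact value varies from run to run since no seed is set.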
From fe7ebc5c64c67726203f30d8cbaf8a4726848278 Mon Sep 17 00:00:00 2001
From: zhanghuiyao <1814619459@qq.com>
Date: Mon, 13 Jan 2025 11:41:30 +0800
Subject: [PATCH 2/7] 1
---
api-examples/optimizer/sgd.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/api-examples/optimizer/sgd.md b/api-examples/optimizer/sgd.md
index 1348c1c..37b819a 100644
--- a/api-examples/optimizer/sgd.md
+++ b/api-examples/optimizer/sgd.md
@@ -62,7 +62,7 @@
- Accelerate with JIT (Just-In-Time) compilation
- ```diff
+ ```pycon
import mindspore
from mindspore import nn, ops
From f59cd99cd99a53e4a70e6d7908aca04af47d99bc Mon Sep 17 00:00:00 2001
From: zhanghuiyao <1814619459@qq.com>
Date: Mon, 13 Jan 2025 11:42:14 +0800
Subject: [PATCH 3/7] 1
---
api-examples/optimizer/sgd.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/api-examples/optimizer/sgd.md b/api-examples/optimizer/sgd.md
index 37b819a..055264a 100644
--- a/api-examples/optimizer/sgd.md
+++ b/api-examples/optimizer/sgd.md
@@ -62,7 +62,7 @@
- Accelerate with JIT (Just-In-Time) compilation
- ```pycon
+ ```python
import mindspore
from mindspore import nn, ops
From 2719dba822dc970600387bee6b30ed1eb48b5558 Mon Sep 17 00:00:00 2001
From: zhanghuiyao <1814619459@qq.com>
Date: Mon, 13 Jan 2025 14:44:46 +0800
Subject: [PATCH 4/7] update
---
api-examples/optimizer/README.md | 6 +++
.../{sgd.md => create_a_simple_sgd.md} | 53 ++++++++++---------
2 files changed, 33 insertions(+), 26 deletions(-)
rename api-examples/optimizer/{sgd.md => create_a_simple_sgd.md} (69%)
diff --git a/api-examples/optimizer/README.md b/api-examples/optimizer/README.md
index e69de29..94e07de 100644
--- a/api-examples/optimizer/README.md
+++ b/api-examples/optimizer/README.md
@@ -0,0 +1,6 @@
+# Optimizer
+
+
+| Example                                       | Notes | MindSpore version |
+|:----------------------------------------------|:-----|:-----------------|
+| [Writing a simple SGD optimizer from scratch](./create_a_simple_sgd.md) | - | MindSpore 2.4.0 |
diff --git a/api-examples/optimizer/sgd.md b/api-examples/optimizer/create_a_simple_sgd.md
similarity index 69%
rename from api-examples/optimizer/sgd.md
rename to api-examples/optimizer/create_a_simple_sgd.md
index 055264a..ee25d14 100644
--- a/api-examples/optimizer/sgd.md
+++ b/api-examples/optimizer/create_a_simple_sgd.md
@@ -1,15 +1,17 @@
-# SGD optimizer
+# Writing a simple SGD optimizer from scratch
-> Based on MindSpore 2.4.0
+> Note: the following examples use MindSpore 2.4.0
-## guide
+## Contents
-- [Create an SGD optimizer from scratch](#create-an-sgd-optimizer-from-scratch)
-- [Acceleration](#acceleration)
+- [Create an SGD optimizer from scratch](#section1)
+- [Acceleration](#section2)
-## Create an SGD optimizer from scratch
-- Step 1: define the SGD update rule
+
+## Create an `sgd` optimizer from scratch
+
+- Step 1: define the `sgd` update rule
```python
from mindspore import ops
@@ -58,9 +60,10 @@
ops.assign(w, w - self.lr * dw)
```
+
## Acceleration
-- Accelerate with JIT (Just-In-Time) compilation
+- Accelerate with `JIT (Just-In-Time)` compilation
```python
import mindspore
@@ -72,7 +75,7 @@
self.weights = weights
self.lr = lr
- + @mindspore.jit
+ @mindspore.jit
def construct(self, grads):
for w, dw in zip(self.weights, grads):
ops.assign(w, w - self.lr * dw)
@@ -80,33 +83,31 @@
- Replace the `for` loop with the `mindspore.ops.HyperMap` operation
- ```diff
+ ```python
import mindspore
from mindspore import nn, ops
- +sgd_update = ops.MultitypeFuncGraph("_sgd_update")
- +
- +@sgd_update.register("Tensor", "Tensor", "Tensor")
- +def run_sgd_update(lr, weight, grad):
- +    """Apply the sgd update to one weight parameter."""
- +    success = True
- +    success = ops.depend(success, ops.assign(weight, weight - lr * grad))
- +    return success
+ sgd_update = ops.MultitypeFuncGraph("_sgd_update")
+
+ @sgd_update.register("Tensor", "Tensor", "Tensor")
+ def run_sgd_update(lr, weight, grad):
+     """Apply the sgd update to one weight parameter."""
+     success = True
+     success = ops.depend(success, ops.assign(weight, weight - lr * grad))
+     return success

class SGD(nn.Cell):
    def __init__(self, weights, lr):
        super(SGD, self).__init__()
        self.weights = weights
        self.lr = lr
- +        self.hyper_map = ops.HyperMap()
+         self.hyper_map = ops.HyperMap()

    @mindspore.jit
    def construct(self, grads):
- -        for w, dw in zip(self.weights, grads):
- -            ops.assign(w, w - self.lr * dw)
- +        return self.hyper_map(
- +            ops.partial(sgd_update, self.lr),
- +            self.weights,
- +            grads
- +        )
+         return self.hyper_map(
+             ops.partial(sgd_update, self.lr),
+             self.weights,
+             grads
+         )
```
From b72bbda3867dc8938d96ee7a3a7f9d4e582945c1 Mon Sep 17 00:00:00 2001
From: zhanghuiyao <1814619459@qq.com>
Date: Tue, 14 Jan 2025 16:55:24 +0800
Subject: [PATCH 5/7] add auto-grad with function
---
api-examples/README.md | 15 +-
api-examples/auto-grad/README.md | 5 +
.../taking_gradients_with_mindspore_grad.md | 153 ++++++++++++++++++
3 files changed, 166 insertions(+), 7 deletions(-)
create mode 100644 api-examples/auto-grad/README.md
create mode 100644 api-examples/auto-grad/taking_gradients_with_mindspore_grad.md
diff --git a/api-examples/README.md b/api-examples/README.md
index 2b15685..1c5e1cc 100644
--- a/api-examples/README.md
+++ b/api-examples/README.md
@@ -1,9 +1,10 @@
### Contents:
-| Category  | Link                               |
-|:----------|:-----------------------------------|
-| ops       | [operator](./operator)             |
-| nn        | [neural network](./nn)             |
-| optimizer | [optimizer](./optimizer/README.md) |
-| tensor    | [tensor](./tensor)                 |
-| runtime   | [runtime](./runtime)               |
+| Category  | Link                                                |
+|:----------|:----------------------------------------------------|
+| ops       | [operator](./operator)                              |
+| nn        | [neural network](./nn)                              |
+| optimizer | [optimizer](./optimizer/README.md)                  |
+| auto-grad | [automatic differentiation](./auto-grad/README.md)  |
+| tensor    | [tensor](./tensor)                                  |
+| runtime   | [runtime](./runtime)                                |
diff --git a/api-examples/auto-grad/README.md b/api-examples/auto-grad/README.md
new file mode 100644
index 0000000..f458b0b
--- /dev/null
+++ b/api-examples/auto-grad/README.md
@@ -0,0 +1,5 @@
+# Automatic differentiation APIs
+
+| Example                                                             | Notes | MindSpore version |
+|:--------------------------------------------------------------------|:-----|:-----------------|
+| [Taking gradients of a function with mindspore](./taking_gradients_with_mindspore_grad.md) | - | MindSpore 2.4.0 |
diff --git a/api-examples/auto-grad/taking_gradients_with_mindspore_grad.md b/api-examples/auto-grad/taking_gradients_with_mindspore_grad.md
new file mode 100644
index 0000000..3c80552
--- /dev/null
+++ b/api-examples/auto-grad/taking_gradients_with_mindspore_grad.md
@@ -0,0 +1,153 @@
+# Taking gradients of a function with mindspore
+
+> Note: the following examples use MindSpore 2.4.0
+
+## Contents
+
+- [Differentiating a function with `mindspore.grad`](#section1)
+- [Computing gradients of a linear logistic regression with `mindspore.grad`](#section2)
+- [Computing gradients and values with `mindspore.value_and_grad`](#section3)
+
+
+
+
+## Differentiating a function with `mindspore.grad`
+
+The derivatives of `f(x) = x^3 - 2x + 1` are:
+
+$$
+\begin{aligned}
+& f(x) = x^3 - 2x + 1 \\
+& f'(x) = 3x^2 - 2 \\
+& f''(x) = 6x
+\end{aligned}
+$$
+
+In MindSpore this can be written simply as:
+
+```python
+from mindspore import grad, Tensor
+
+f = lambda x: x**3 - 2*x + 1
+
+dfdx = grad(f)
+d2fdx = grad(grad(f))
+```
+
+Evaluating the above at `x=1` gives:
+
+$$
+\begin{aligned}
+& f(1) = 0 \\
+& f'(1) = 1 \\
+& f''(1) = 6
+\end{aligned}
+$$
+
+With MindSpore:
+
+```python
+print(f(Tensor(1.0)))
+print(dfdx(Tensor(1.0)))
+print(d2fdx(Tensor(1.0)))
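+# the three printed values should be 0, 1 and 6, matching the math above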
+```
+
+
+## Computing gradients of a linear logistic regression with `mindspore.grad`
+
+First, set up the following definitions:
+
+```python
+from mindspore import ops, grad, Tensor, set_seed
+
+set_seed(0)
+
+inputs = ops.randn((4, 3))
+targets = Tensor([True, True, False, True])
+
+W, b = ops.randn((3,)), ops.randn((1,))
+
+def loss(inputs, targets, W, b):
+ preds = ops.sigmoid(ops.matmul(inputs, W) + b)
+ logit_loss = preds * targets + (1 - preds) * (1 - targets)
+ return -ops.log(logit_loss).sum()
+
+print(f"inputs: {inputs}")
+print(f"targets: {targets}")
+print(f"W: {W}")
+print(f"b: {b}")
+```
+
+Compute the gradient with respect to each input separately:
+
+```python
+x_grad = grad(loss, grad_position=0)(inputs, targets, W, b)
+print(f'x_grad: {x_grad}')
+
+y_grad = grad(loss, 1)(inputs, targets, W, b)
+print(f'y_grad: {y_grad}')
+
+W_grad = grad(loss, 2)(inputs, targets, W, b)
+print(f'W_grad: {W_grad}')
+
+b_grad = grad(loss, 3)(inputs, targets, W, b)
+print(f'b_grad: {b_grad}')
+```
+
+Of course, all the required gradients can also be computed in a single call:
+
+```python
+(W_grad, b_grad) = grad(loss, (2, 3))(inputs, targets, W, b)
+print(f'W_grad: {W_grad}')
+print(f'b_grad: {b_grad}')
+```
+
+If the function returns multiple losses, the computed gradients are those of the sum of all the losses with respect to the inputs:
+
+```python
+def multi_loss(inputs, targets, W, b):
+ loss1 = loss(inputs, targets, W, b)
+ loss2 = (W ** 2).sum()
+ return loss1, loss2
+
+(W_grad, b_grad) = grad(multi_loss, (2, 3))(inputs, targets, W, b)
+print(f'W_grad: {W_grad}, b_grad: {b_grad}')
+```
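+
+As a quick sanity check (a sketch; `W_grad_1` and `W_grad_2` are names introduced here for illustration), the multi-loss gradients should equal the sum of the per-loss gradients:
+
+```python
+W_grad_1 = grad(loss, 2)(inputs, targets, W, b)
+W_grad_2 = grad(lambda w: (w ** 2).sum())(W)
+print(ops.isclose(W_grad, W_grad_1 + W_grad_2).all())
+```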
+
+If you only want the gradients of `loss1`, you can cut `loss2` off with `ops.stop_gradient`:
+
+```python
+def multi_loss_2(inputs, targets, W, b):
+ loss1 = loss(inputs, targets, W, b)
+ loss2 = (W ** 2).sum()
+ return loss1, ops.stop_gradient(loss2)
+
+(W_grad, b_grad) = grad(multi_loss_2, (2, 3))(inputs, targets, W, b)
+print(f'W_grad: {W_grad}, b_grad: {b_grad}')
+```
+
+Alternatively, you can set `has_aux=True` to exclude every output except the first from the gradient computation:
+
+```python
+# with has_aux=True the wrapped function returns the gradients plus the auxiliary outputs
+(W_grad, b_grad), _ = grad(multi_loss, (2, 3), has_aux=True)(inputs, targets, W, b)
+print(f'W_grad: {W_grad}, b_grad: {b_grad}')
+```
+
+For more details on the `mindspore.grad` API, see the [MindSpore Docs](https://www.mindspore.cn/docs/zh-CN/r2.4.0/api_python/mindspore/mindspore.grad.html)
+
+
+
+## Computing gradients and values with `mindspore.value_and_grad`
+
+Compute the gradients of the linear logistic regression and get the `loss` as well:
+
+```python
+from mindspore import value_and_grad
+
+loss_value, (W_grad, b_grad) = value_and_grad(loss, (2, 3))(inputs, targets, W, b)
+
+print(f"loss: {loss_value}")
+print(f'W_grad: {W_grad}, b_grad: {b_grad}')
+```
+
+For more details on the `mindspore.value_and_grad` API, see the [MindSpore Docs](https://www.mindspore.cn/docs/zh-CN/r2.4.0/api_python/mindspore/mindspore.value_and_grad.html)
From 443fffae69face6aebe9d01ca2cd7b88edff639d Mon Sep 17 00:00:00 2001
From: zhanghuiyao <1814619459@qq.com>
Date: Tue, 14 Jan 2025 17:53:50 +0800
Subject: [PATCH 6/7] update auto-grad
---
api-examples/auto-grad/README.md | 7 +-
...dients_of_function_with_mindspore_grad.md} | 8 +-
..._gradients_of_model_with_mindspore_grad.md | 75 +++++++++++++++++++
3 files changed, 83 insertions(+), 7 deletions(-)
rename api-examples/auto-grad/{taking_gradients_with_mindspore_grad.md => taking_gradients_of_function_with_mindspore_grad.md} (94%)
create mode 100644 api-examples/auto-grad/taking_gradients_of_model_with_mindspore_grad.md
diff --git a/api-examples/auto-grad/README.md b/api-examples/auto-grad/README.md
index f458b0b..d06d666 100644
--- a/api-examples/auto-grad/README.md
+++ b/api-examples/auto-grad/README.md
@@ -1,5 +1,6 @@
# Automatic differentiation APIs
-| Example                                                             | Notes | MindSpore version |
-|:--------------------------------------------------------------------|:-----|:-----------------|
-| [Taking gradients of a function with mindspore](./taking_gradients_with_mindspore_grad.md) | - | MindSpore 2.4.0 |
+| Example                                                                     | Notes | MindSpore version |
+|:-----------------------------------------------------------------------------|:-----|:-----------------|
+| [Taking gradients of a function with mindspore](./taking_gradients_of_function_with_mindspore_grad.md) | - | MindSpore 2.4.0 |
+| [Taking gradients of a model with mindspore](./taking_gradients_of_model_with_mindspore_grad.md) | - | MindSpore 2.4.0 |
diff --git a/api-examples/auto-grad/taking_gradients_with_mindspore_grad.md b/api-examples/auto-grad/taking_gradients_of_function_with_mindspore_grad.md
similarity index 94%
rename from api-examples/auto-grad/taking_gradients_with_mindspore_grad.md
rename to api-examples/auto-grad/taking_gradients_of_function_with_mindspore_grad.md
index 3c80552..fa88e59 100644
--- a/api-examples/auto-grad/taking_gradients_with_mindspore_grad.md
+++ b/api-examples/auto-grad/taking_gradients_of_function_with_mindspore_grad.md
@@ -6,7 +6,7 @@
- [Differentiating a function with `mindspore.grad`](#section1)
- [Computing gradients of a linear logistic regression with `mindspore.grad`](#section2)
-- [Computing gradients and values with `mindspore.value_and_grad`](#section3)
+- [Computing gradients and loss with `mindspore.value_and_grad`](#section3)
@@ -34,7 +34,7 @@ dfdx = grad(f)
d2fdx = grad(grad(f))
```
-Evaluating the above at `x=1` gives:
+Verifying the above at `x=1` gives:
$$
\begin{aligned}
@@ -44,7 +44,7 @@ $$
\end{aligned}
$$
-With MindSpore:
+Run it in MindSpore:
```python
print(f(Tensor(1.0)))
@@ -137,7 +137,7 @@ print(f'W_grad: {W_grad}, b_grad: {b_grad}')
-## Computing gradients and values with `mindspore.value_and_grad`
+## Computing gradients and loss with `mindspore.value_and_grad`
Compute the gradients of the linear logistic regression and get the `loss` as well:
diff --git a/api-examples/auto-grad/taking_gradients_of_model_with_mindspore_grad.md b/api-examples/auto-grad/taking_gradients_of_model_with_mindspore_grad.md
new file mode 100644
index 0000000..536fb17
--- /dev/null
+++ b/api-examples/auto-grad/taking_gradients_of_model_with_mindspore_grad.md
@@ -0,0 +1,75 @@
+# Taking gradients of a model with mindspore
+
+> Note: the following examples use MindSpore 2.4.0
+
+## Computing gradients of the model output with respect to the weights
+
+First, define a simple model:
+
+```python
+from mindspore import nn, Parameter
+import numpy as np
+
+np.random.seed(0)
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.w = Parameter(np.random.randn(4, 4).astype(np.float32))
+        self.b = Parameter(np.random.randn(4,).astype(np.float32))
+
+    def construct(self, x, y):
+        pred = self.w @ x + self.b
+        loss = (pred - y) ** 2
+        return loss.sum()
+
+net = Net()
+
+print(f"net: \n{net}")
+print(f"parameters of net: \n{net.trainable_params()}")
+```
+
+Use the `mindspore.grad` API to compute the gradients of the model with respect to its weights.
+
+> Note: gradients with respect to the model inputs are not needed here, so `grad_position=None` is set.
+
+```python
+from mindspore import grad, Tensor
+import numpy as np
+
+x, y = Tensor(np.random.randn(4).astype(np.float32)), Tensor([1.0])  # x must be a length-4 vector so that self.w @ x is well-defined
+
+grad_fn = grad(net, grad_position=None, weights=net.trainable_params())
+
+grads = grad_fn(x, y)
+
+print(f"grads: \n{grads}")
+```
+
+Of course, we can also use the `mindspore.value_and_grad` API to compute the gradients with respect to the weights and obtain the `loss` at the same time:
+
+```python
+from mindspore import value_and_grad
+
+grad_fn = value_and_grad(net, grad_position=None, weights=net.trainable_params())
+
+loss, grads = grad_fn(x, y)
+
+print(f"loss: {loss}")
+print(f"grads: \n{grads}")
+```
+
+Accelerate with `JIT (Just-In-Time)` compilation:
+
+```python
+from mindspore import jit
+
+@jit
+def loss_and_grads(x, y):
+    return grad_fn(x, y)
+
+loss, grads = loss_and_grads(x, y)
+
+print(f"loss: {loss}")
+print(f"grads: \n{grads}")
+```
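+
+Note: the first call to a `@jit`-decorated function triggers graph compilation, so the speed-up typically shows up from the second call onward.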
From 9bca803117e25e73d1207ecad81ae1442935df16 Mon Sep 17 00:00:00 2001
From: zhanghuiyao <1814619459@qq.com>
Date: Tue, 14 Jan 2025 17:55:12 +0800
Subject: [PATCH 7/7] update
---
.../auto-grad/taking_gradients_of_model_with_mindspore_grad.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/api-examples/auto-grad/taking_gradients_of_model_with_mindspore_grad.md b/api-examples/auto-grad/taking_gradients_of_model_with_mindspore_grad.md
index 536fb17..cdeedca 100644
--- a/api-examples/auto-grad/taking_gradients_of_model_with_mindspore_grad.md
+++ b/api-examples/auto-grad/taking_gradients_of_model_with_mindspore_grad.md
@@ -2,7 +2,7 @@
> Note: the following examples use MindSpore 2.4.0
-## Computing gradients of the model output with respect to the weights
+## Computing gradients of the model with respect to the weights
First, define a simple model: