Skip to content

Commit 731eb3f

Browse files
authored
Compatible upgrade of sparse_momentum for master param (#107)
1 parent ef0a32b commit 731eb3f

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

dynamic/utils/hybrid_optimizer.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,20 @@ def _append_optimize_op(self, block, param_and_grad):
8585
if getattr(param_and_grad[0], 'is_sparse_grad', None):
8686
index = getattr(param_and_grad[0], 'index', None)
8787
axis = getattr(param_and_grad[0], 'axis', None)
88-
_, _ = paddle._C_ops.sparse_momentum(
89-
param_and_grad[0], param_and_grad[1], velocity_acc, index, lr,
90-
param_and_grad[0], velocity_acc, 'mu', self._momentum,
91-
'use_nesterov', self._use_nesterov, 'regularization_method',
92-
self._regularization_method, 'regularization_coeff',
93-
self._regularization_coeff, 'axis', axis)
88+
try:
89+
_, _ = paddle._C_ops.sparse_momentum(
90+
param_and_grad[0], param_and_grad[1], velocity_acc, index, lr,
91+
param_and_grad[0], velocity_acc, 'mu', self._momentum,
92+
'use_nesterov', self._use_nesterov, 'regularization_method',
93+
self._regularization_method, 'regularization_coeff',
94+
self._regularization_coeff, 'axis', axis)
95+
except:
96+
_, _, _ = paddle._C_ops.sparse_momentum(
97+
param_and_grad[0], param_and_grad[1], velocity_acc, index, lr, master_weight,
98+
param_and_grad[0], velocity_acc, master_weight, 'mu', self._momentum,
99+
'use_nesterov', self._use_nesterov, 'regularization_method',
100+
self._regularization_method, 'regularization_coeff',
101+
self._regularization_coeff, 'axis', axis, 'multi_precision', find_master)
94102
else:
95103
_, _, _ = paddle._C_ops.momentum(
96104
param_and_grad[0], param_and_grad[1], velocity_acc, lr,

0 commit comments

Comments
 (0)