@@ -85,12 +85,20 @@ def _append_optimize_op(self, block, param_and_grad):
85
85
if getattr (param_and_grad [0 ], 'is_sparse_grad' , None ):
86
86
index = getattr (param_and_grad [0 ], 'index' , None )
87
87
axis = getattr (param_and_grad [0 ], 'axis' , None )
88
- _ , _ = paddle ._C_ops .sparse_momentum (
89
- param_and_grad [0 ], param_and_grad [1 ], velocity_acc , index , lr ,
90
- param_and_grad [0 ], velocity_acc , 'mu' , self ._momentum ,
91
- 'use_nesterov' , self ._use_nesterov , 'regularization_method' ,
92
- self ._regularization_method , 'regularization_coeff' ,
93
- self ._regularization_coeff , 'axis' , axis )
88
+ try :
89
+ _ , _ = paddle ._C_ops .sparse_momentum (
90
+ param_and_grad [0 ], param_and_grad [1 ], velocity_acc , index , lr ,
91
+ param_and_grad [0 ], velocity_acc , 'mu' , self ._momentum ,
92
+ 'use_nesterov' , self ._use_nesterov , 'regularization_method' ,
93
+ self ._regularization_method , 'regularization_coeff' ,
94
+ self ._regularization_coeff , 'axis' , axis )
95
+ except :
96
+ _ , _ , _ = paddle ._C_ops .sparse_momentum (
97
+ param_and_grad [0 ], param_and_grad [1 ], velocity_acc , index , lr , master_weight ,
98
+ param_and_grad [0 ], velocity_acc , master_weight , 'mu' , self ._momentum ,
99
+ 'use_nesterov' , self ._use_nesterov , 'regularization_method' ,
100
+ self ._regularization_method , 'regularization_coeff' ,
101
+ self ._regularization_coeff , 'axis' , axis , 'multi_precision' , find_master )
94
102
else :
95
103
_ , _ , _ = paddle ._C_ops .momentum (
96
104
param_and_grad [0 ], param_and_grad [1 ], velocity_acc , lr ,
0 commit comments