Skip to content

Commit

Permalink
fix a bug
Browse files Browse the repository at this point in the history
  • Loading branch information
qlan3 authored Aug 9, 2023
1 parent dfc8e92 commit ab902e3
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions components/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,10 @@ def __init__(self, optimizer_name, cfg, seed):
self.optimizer = Optim4RL(**cfg)
self.is_rnn_output = lambda x: type(x)==tuple and type(x[0])==tuple and type(x[1])!=tuple
elif optimizer_name in ['LinearOptim', 'L2LGD2']:
self.optimizer = LinearOptim(**cfg)
if optimizer_name == 'LinearRNNOptimizer':
self.optimizer = LinearRNNOptimizer(**cfg)
elif optimizer_name == 'L2LGD2':
self.optimizer = L2LGD2(**cfg)
self.is_rnn_output = lambda x: type(x)==tuple and type(x[0])!=tuple and type(x[1])!=tuple
# Initialize param for RNN optimizer
if len(self.param_load_path) > 0:
Expand Down Expand Up @@ -338,4 +341,4 @@ def update_with_param(self, optim_param, grad, optim_state, loss):
"""
self.optimizer.theta = optim_param # This is important, do not remove
optim_state = self.optimizer.update(optim_state, grad, loss)
return optim_state
return optim_state

0 comments on commit ab902e3

Please sign in to comment.