Commit
🔥 [Remove] & fix typo of momentum schedule
henrytsui000 committed Nov 23, 2024
1 parent 67fbfa0 commit b4dad5e
Showing 2 changed files with 8 additions and 8 deletions.
1 change: 1 addition & 0 deletions yolo/config/config.py
@@ -66,6 +66,7 @@ class DataConfig:
class OptimizerArgs:
    lr: float
    weight_decay: float
+   momentum: float


@dataclass
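For readers following along, the new `momentum` field means the parsed optimizer arguments can be handed to the optimizer directly. A minimal sketch of how such a config might be consumed, assuming SGD and illustrative hyperparameter values (the repository's actual OptimizerConfig wiring is not shown here):

```python
# Minimal sketch, not the repository's actual wiring: the dataclass fields map
# one-to-one onto torch.optim.SGD keyword arguments.
from dataclasses import asdict, dataclass

import torch
from torch import nn


@dataclass
class OptimizerArgs:
    lr: float
    weight_decay: float
    momentum: float  # field added by this commit


args = OptimizerArgs(lr=0.01, weight_decay=5e-4, momentum=0.937)  # illustrative values
model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), **asdict(args))
```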
15 changes: 7 additions & 8 deletions yolo/utils/model_utils.py
@@ -8,7 +8,6 @@
import torch.distributed as dist
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
- from lightning.pytorch.utilities import rank_zero_only
from omegaconf import ListConfig
from torch import Tensor, no_grad
from torch.optim import Optimizer
@@ -77,9 +76,9 @@ def create_optimizer(model: YOLO, optim_cfg: OptimizerConfig) -> Optimizer:
    conv_params = [p for name, p in model.named_parameters() if "weight" in name and "bn" not in name]

    model_parameters = [
-       {"params": bias_params, "momentum": 0.8, "weight_decay": 0},
-       {"params": conv_params, "momentum": 0.8},
-       {"params": norm_params, "momentum": 0.8, "weight_decay": 0},
+       {"params": bias_params, "momentum": 0.937, "weight_decay": 0},
+       {"params": conv_params, "momentum": 0.937},
+       {"params": norm_params, "momentum": 0.937, "weight_decay": 0},
    ]

    def next_epoch(self, batch_num, epoch_idx):
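A note on this hunk: per-group keys override the optimizer's defaults, so biases and norm weights keep `weight_decay=0` while every group now starts at momentum 0.937. A toy sketch of that override mechanism, assuming SGD and made-up default values rather than the repository's real create_optimizer call:

```python
# Toy sketch of per-parameter-group overrides; layer names and defaults are assumptions.
import torch
from torch import nn

conv = nn.Conv2d(3, 8, 3)
bn = nn.BatchNorm2d(8)

optimizer = torch.optim.SGD(
    [
        {"params": [conv.bias, bn.bias], "momentum": 0.937, "weight_decay": 0},
        {"params": [conv.weight], "momentum": 0.937},  # inherits the default weight_decay below
        {"params": [bn.weight], "momentum": 0.937, "weight_decay": 0},
    ],
    lr=0.01,
    momentum=0.8,        # group-level "momentum" above takes precedence
    weight_decay=5e-4,   # only the conv weights end up decayed
)

for group in optimizer.param_groups:
    print(group["momentum"], group["weight_decay"])
```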
@@ -89,8 +88,8 @@ def next_epoch(self, batch_num, epoch_idx):
        # 0.937: Start Momentum
        # 0.8  : Normal Momentum
        # 3    : The warm up epoch num
-       self.min_mom = lerp(0.937, 0.8, max(epoch_idx, 3), 3)
-       self.max_mom = lerp(0.937, 0.8, max(epoch_idx + 1, 3), 3)
+       self.min_mom = lerp(0.937, 0.8, min(epoch_idx, 3), 3)
+       self.max_mom = lerp(0.937, 0.8, min(epoch_idx + 1, 3), 3)
        self.batch_num = batch_num
        self.batch_idx = 0

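The `max` → `min` change is the actual typo fix: with `max(epoch_idx, 3)` the interpolation step could never fall below 3, so the momentum warmup was effectively skipped at epoch 0 and overshot past 0.8 on later epochs. A small sketch of the difference, assuming `lerp(start, end, step, total)` does plain linear interpolation:

```python
# Sketch only: this lerp signature is an assumption based on how it is called above.
def lerp(start: float, end: float, step: int, total: int) -> float:
    return start + (end - start) * step / total

for epoch_idx in range(5):
    broken = lerp(0.937, 0.8, max(epoch_idx, 3), 3)  # old: 0.8 at epoch 0, drifts past 0.8 afterwards
    fixed = lerp(0.937, 0.8, min(epoch_idx, 3), 3)   # new: 0.937 -> 0.8 over the 3 warmup epochs, then holds
    print(f"epoch {epoch_idx}: broken={broken:.4f} fixed={fixed:.4f}")
```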
@@ -100,7 +99,7 @@ def next_batch(self):
        for lr_idx, param_group in enumerate(self.param_groups):
            min_lr, max_lr = self.min_lr[lr_idx], self.max_lr[lr_idx]
            param_group["lr"] = lerp(min_lr, max_lr, self.batch_idx, self.batch_num)
-           param_group["momentum"] = lerp(self.min_mom, self.max_mom, self.batch_idx, self.batch_num)
+           # param_group["momentum"] = lerp(self.min_mom, self.max_mom, self.batch_idx, self.batch_num)
            lr_dict[f"LR/{lr_idx}"] = param_group["lr"]
        return lr_dict

@@ -125,7 +124,7 @@ def create_scheduler(optimizer: Optimizer, schedule_cfg: SchedulerConfig) -> _LR
        lambda1 = lambda epoch: (epoch + 1) / wepoch if epoch < wepoch else 1
        lambda2 = lambda epoch: 10 - 9 * ((epoch + 1) / wepoch) if epoch < wepoch else 1
        warmup_schedule = LambdaLR(optimizer, lr_lambda=[lambda2, lambda1, lambda1])
-       schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[2])
+       schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[wepoch - 1])
    return schedule
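For context, `SequentialLR` hands control from the warmup scheduler to the main scheduler once the milestone epoch is reached, so deriving the milestone from `wepoch` ties the handover to the configured warmup length instead of a hardcoded 2. A simplified, single-parameter-group sketch (the repository uses three groups with per-group lambdas, and the main schedule type comes from the config):

```python
# Simplified sketch with assumed values; MultiStepLR stands in for the configured main schedule.
import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR, MultiStepLR, SequentialLR

wepoch = 3  # assumed warmup epoch count from the scheduler config
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

warmup = LambdaLR(optimizer, lr_lambda=lambda e: (e + 1) / wepoch if e < wepoch else 1)
main = MultiStepLR(optimizer, milestones=[20, 30], gamma=0.1)
schedule = SequentialLR(optimizer, schedulers=[warmup, main], milestones=[wepoch - 1])

for epoch in range(6):
    optimizer.step()   # training step placeholder
    schedule.step()
    print(epoch, optimizer.param_groups[0]["lr"])
```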


