Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Readability refactor + normalize classification loss #252

Open
wants to merge 2 commits into
base: master
Choose a base branch
from

Conversation

ggaziv
Copy link
Contributor

@ggaziv ggaziv commented May 4, 2020

Current classification loss is not normalized, enabling it to dominate the total loss.
image
image

@zylo117
Copy link
Owner

zylo117 commented May 5, 2020

Good job on the refactor.

I don't think it's not necessary to do normalization.
Isn't that this line does the normalization?
cls_loss.sum() / num_positive_anchors.to(dtype)

Loss from all anchors divided by the number of the positive anchors. I think it means the cls_loss per positive anchor, it's a mean loss already, right?

@ggaziv
Copy link
Contributor Author

ggaziv commented May 5, 2020

Thanks.

  1. It is clear that the classification loss is not properly normalized - see my plot indicating it has order of magnitude 2e5 before the fix.
  2. It can stem either in the fact that num_positive_anchors is not the proper normalization factor or from classification_losses.append(cls_loss.sum()) (here), where it is not normalized at all.

Note that in the custom dataset I'm working on, I consider images with many annotations each - it is probably what exacerbates the issue.

@zylo117
Copy link
Owner

zylo117 commented May 5, 2020

Thanks.

  1. It is clear that the classification loss is not properly normalized - see my plot indicating it has order of magnitude 2e5 before the fix.
  2. It can stem either in the fact that num_positive_anchors is not the proper normalization factor or from classification_losses.append(cls_loss.sum()) (here), where it is not normalized at all.

Note that in the custom dataset I'm working on, I consider images with many annotations each - it is probably what exacerbates the issue.

No, it's just dataset issue or improper parameters. My cls loss is only 0.25.
You see try running the tutorial. It's definitely not that high.

@ggaziv
Copy link
Contributor Author

ggaziv commented May 5, 2020

Do you agree that with proper normalization the loss should be bounded regardless of the dataset?

Thanks.

  1. It is clear that the classification loss is not properly normalized - see my plot indicating it has order of magnitude 2e5 before the fix.
  2. It can stem either in the fact that num_positive_anchors is not the proper normalization factor or from classification_losses.append(cls_loss.sum()) (here), where it is not normalized at all.

Note that in the custom dataset I'm working on, I consider images with many annotations each - it is probably what exacerbates the issue.

No, it's just dataset issue or improper parameters. My cls loss is only 0.25.
You see try running the tutorial. It's definitely not that high.

@zylo117
Copy link
Owner

zylo117 commented May 6, 2020

I agree, but it's already done. Here.
cls_loss.sum() / num_positive_anchors.to(dtype)

@ggaziv
Copy link
Contributor Author

ggaziv commented May 11, 2020

@zylo117 see new commit. There was exploding classification loss by zero division there (saw other commenting on that as well).
Also I found descrepancies from original implementation re anchors and matching, e.g., see here and here.

Also note that in the original implementation the targets below the unmatched threshold (0.5 there unlike 0.4 here) are considered negative, while here they are ignored.
I currently left this behavior unchanged while enabling their behavior in the code as well.
Thanks.

train.py Outdated Show resolved Hide resolved
efficientdet/loss.py Outdated Show resolved Hide resolved
@dkoguciuk
Copy link

@ggaziv , are you sure your code is correct? In @zylo117 's implementation target can have 3 different values:

  • 1 - this is the positive anchor
  • 0 - this is negative anchor
  • -1 - this anchor will be ignored

So for with self.negatives_lower_than_unmatched=True:

            if self.negatives_lower_than_unmatched:  
                # negative matches are the ones below the unmatched_threshold, whereas ignored matches are in between the matched and unmatched
                targets[torch.lt(IoU_max, self.matched_threshold) & torch.ge(IoU_max, self.unmatched_threshold), :] = 0

You would ignore anchors with IoU < 0.4, which is an undesired behavior IMO. With self.negatives_lower_than_unmatched=False:

            else:  
                # Ignore targets with overlap lower than unmatched_threshold
                targets[torch.lt(IoU_max, self.unmatched_threshold), :] = 0

we get what we need and the comment here (# Ignore targets with overlap lower than unmatched_threshold) is wrong.

BTW. Where did you find in the original implementation that the targets below the 0.5 threshold are considered as negatives?

@ggaziv
Copy link
Contributor Author

ggaziv commented May 12, 2020

@dkoguciuk you are right re the implementation inversion - your legend of the target values helped correcting this (see updated fix).

BTW. Where did you find in the original implementation that the targets below the 0.5 threshold are considered as negatives?
https://github.com/google/automl/blob/c470de8915e320c02c8ed39d0b8fc3ded99cfc64/efficientdet/anchors.py#L480

@dkoguciuk
Copy link

@ggaziv, thanks for the link!

This is the docstring from ArgMaxMatcher class:

negatives_lower_than_unmatched: Boolean which defaults to True. If True
then negative matches are the ones below the unmatched_threshold,
whereas ignored matches are in between the matched and unmatched
threshold. If False, then negative matches are in between the matched
and unmatched threshold, and everything lower than unmatched is ignored

And it's invocation:

matcher = argmax_matcher.ArgMaxMatcher(
        match_threshold,
        unmatched_threshold=match_threshold,
        negatives_lower_than_unmatched=True,
        force_match_for_each_row=True)

So, as far as I understand - this is the exact @zylo117's implementation, or did I miss something?

Nevertheless, the PR is still useful, and yes, the code LGTM 👍

@zylo117
Copy link
Owner

zylo117 commented May 12, 2020

I've just tested on this branch on shape dataset. Low loss, low mAP too. But the refactor part seems ok.

@ggaziv
Copy link
Contributor Author

ggaziv commented May 12, 2020

@zylo117 I suggest you try again with the latest which I just forced pushed (after @dkoguciuk feedback). It should currently not alter the behavior at all other than avoiding zero division (causing classification loss to explode).

Net-net this PR extends code utility and adds a minor fix w.r.t official code.

@zylo117
Copy link
Owner

zylo117 commented May 13, 2020

@zylo117 I suggest you try again with the latest which I just forced pushed (after @dkoguciuk feedback). It should currently not alter the behavior at all other than avoiding zero division (causing classification loss to explode).

Net-net this PR extends code utility and adds a minor fix w.r.t official code.

I still can't get a normal result.

Step: 1287. Epoch: 45/50. Iteration: 28/28. Cls loss: 0.17737. Reg loss: 0.00548. Total loss: 0.18285: 100% 28/28 [00:38<00:00,  1.39s/it]
Val. Epoch: 45/50. Classification loss: 0.16009. Regression loss: 0.00668. Total loss: 0.16677
Step: 1315. Epoch: 46/50. Iteration: 28/28. Cls loss: 0.19635. Reg loss: 0.00514. Total loss: 0.20148: 100% 28/28 [00:38<00:00,  1.38s/it]
Val. Epoch: 46/50. Classification loss: 0.15749. Regression loss: 0.00632. Total loss: 0.16381
Step: 1343. Epoch: 47/50. Iteration: 28/28. Cls loss: 0.17216. Reg loss: 0.00930. Total loss: 0.18146: 100% 28/28 [00:38<00:00,  1.38s/it]
Val. Epoch: 47/50. Classification loss: 0.15515. Regression loss: 0.00642. Total loss: 0.16157
Step: 1371. Epoch: 48/50. Iteration: 28/28. Cls loss: 0.19578. Reg loss: 0.00617. Total loss: 0.20195: 100% 28/28 [00:38<00:00,  1.38s/it]
Val. Epoch: 48/50. Classification loss: 0.15291. Regression loss: 0.00648. Total loss: 0.15939
Step: 1399. Epoch: 49/50. Iteration: 28/28. Cls loss: 0.17510. Reg loss: 0.00524. Total loss: 0.18034: 100% 28/28 [00:38<00:00,  1.38s/it]
Val. Epoch: 49/50. Classification loss: 0.15087. Regression loss: 0.00635. Total loss: 0.15722

 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.001
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.002
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.001
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.001
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.027
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.037
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.042
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.027

And the cls loss is at least 3 times higher than this master branch's.
What did I miss? Can you train well on shape dataset?

@ggaziv
Copy link
Contributor Author

ggaziv commented May 13, 2020

What did I miss? Can you train well on shape dataset?

I didn't try but from looking at the code it should be exactly the same as master other than +1 to num_positive_anchors. I suggest just remove the +1 to verify that it is indeed functionally identical.

Just note that if you are running directly on my branch it is lagged w.r.t current master so only consider the PR changes.

@ludomitch
Copy link

Hello,

Thanks for all the great work you've done @zylo117!

I am having a similar problem. Has this been resolved now in the latest master branch?

I am training on my custom dataset with three classes which I have already trained on with the signatrix implementation and the training was very stable with the loss decreasing consistently from the beginning. However, it doesn't support using different backbone versions so I switched to this repo.

I am getting a huge classification loss oscillating between 3000-40000 and doesn't seem to stabilise or decrease.

Here is the command I am running:
python3 -m train -c 0 -p my_project --batch_size 8 --lr 6e-5 --num_epochs 10 --load_weights /path/to/weights/efficientdet-d0.pth --head_only True

I have tried with both d0 and d4 pretrained weights.

Any help would be much appreciated!

@zylo117
Copy link
Owner

zylo117 commented May 15, 2020

Hello,

Thanks for all the great work you've done @zylo117!

I am having a similar problem. Has this been resolved now in the latest master branch?

I am training on my custom dataset with three classes which I have already trained on with the signatrix implementation and the training was very stable with the loss decreasing consistently from the beginning. However, it doesn't support using different backbone versions so I switched to this repo.

I am getting a huge classification loss oscillating between 3000-40000 and doesn't seem to stabilise or decrease.

Here is the command I am running:
python3 -m train -c 0 -p my_project --batch_size 8 --lr 6e-5 --num_epochs 10 --load_weights /path/to/weights/efficientdet-d0.pth --head_only True

I have tried with both d0 and d4 pretrained weights.

Any help would be much appreciated!

Loss won't decrease with that low lr, try 1e-4 for the first few epochs.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants