kd_loss implementation issue #74

Open
ZaberKo opened this issue Oct 7, 2023 · 9 comments
@ZaberKo commented Oct 7, 2023

Hello, I found that knowledge_distillation_kl_div_loss() in mmdet/models/losses/kd_loss.py uses an implementation that differs from the standard KL divergence: per sample it is equivalent to F.kl_div(reduction='mean') rather than F.kl_div(reduction='batchmean'), and the F.kl_div documentation notes that 'mean' does not return the true KL divergence value.

kd_loss = F.kl_div(
    F.log_softmax(pred / T, dim=1), target, reduction='none').mean(1) * (
        T * T)

The correct KL divergence would be:

kd_loss = F.kl_div(
    F.log_softmax(pred / T, dim=1), target, reduction='none').sum(1) * (
        T * T)

Is there any reason to use the former implementation? With GFL's reg_max=16 (i.e., 17 bins), the current kl_div is 1/17 of the real kl_div.
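
For reference, here is a quick check of the factor (a minimal sketch with made-up shapes and temperature, not the repo's code), showing that the .mean(1) version equals the per-sample KL divergence divided by the number of bins:

import torch
import torch.nn.functional as F

# Illustrative shapes only: 8 positions x 17 bins (reg_max = 16).
torch.manual_seed(0)
T = 10.0
pred = torch.randn(8, 17)
soft_target = F.softmax(torch.randn(8, 17) / T, dim=1)

elementwise = F.kl_div(
    F.log_softmax(pred / T, dim=1), soft_target, reduction='none')

loss_current = elementwise.mean(1) * (T * T)  # current implementation
loss_kl = elementwise.sum(1) * (T * T)        # true per-sample KL divergence

print(torch.allclose(loss_current, loss_kl / 17))  # True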

@HikariTJU (Owner)

I remember that .mean(1) is equal to reduction='batchmean'?

@ZaberKo (Author) commented Oct 8, 2023

> I remember that .mean(1) is equal to reduction='batchmean'?

Here is the source code of F.kl_div:
https://github.com/pytorch/pytorch/blob/defa0d3a2d230e5d731d5c443c1b9beda2e7fd93/torch/nn/functional.py#L2949-L2958

And the problem here is that the kd_loss is subsequently averaged by the @weighted_loss wrapper anyway.
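
For context, this is roughly what that wrapper does with the per-sample loss (a simplified sketch from my reading of mmdet's weight_reduce_loss, not the actual source):

import torch

def weight_reduce_loss_sketch(loss, weight=None, reduction='mean', avg_factor=None):
    # Apply per-element weights first.
    if weight is not None:
        loss = loss * weight
    if avg_factor is None:
        # Plain reduction over all remaining elements.
        if reduction == 'mean':
            loss = loss.mean()
        elif reduction == 'sum':
            loss = loss.sum()
    elif reduction == 'mean':
        # Normalise by a caller-supplied factor instead of the element count.
        loss = loss.sum() / avg_factor
    return loss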

@HikariTJU (Owner)

So batchmean equals .mean(0)?

@ZaberKo (Author) commented Oct 8, 2023

> So batchmean equals .mean(0)?

No. "batchmean" means .sum()/batch_size, i.e., .sum(1).mean()

@HikariTJU (Owner) commented Oct 8, 2023

OK, I get your point: you mean that mathematically .sum(1) is the correct implementation and .mean(1) = .sum(1)/17.
That's true, but how is it related to batchmean?

@ZaberKo (Author) commented Oct 14, 2023

> OK, I get your point: you mean that mathematically .sum(1) is the correct implementation and .mean(1) = .sum(1)/17. That's true, but how is it related to batchmean?

BTW, I also found that loss_ld uses a weighted sum and is not divided by avg_factor (i.e., the sum of weights). Is this a typo or intended behavior of not using normalization?
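
To make the two options concrete, here is a sketch of the difference being asked about (names and shapes are illustrative, not the repo's code):

import torch

def ld_loss_weighted_sum(per_location_loss, weight):
    # What the current behavior appears to be: weighted sum, no normalisation.
    return (per_location_loss * weight).sum()

def ld_loss_normalized(per_location_loss, weight):
    # The alternative: divide by avg_factor, i.e. the sum of weights.
    avg_factor = weight.sum().clamp(min=1e-6)
    return (per_location_loss * weight).sum() / avg_factor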

@ZaberKo (Author) commented Oct 14, 2023

FYI: I recorded the ratio avg_factor/(self.reg_max+1) during training. Maybe it will help this discussion.

[figure: plot of avg_factor/(self.reg_max+1) recorded during training]

@HikariTJU (Owner)

It's intended behavior, because experiments show that not dividing is better. I don't know the theory behind it, though.

@ZaberKo (Author) commented Oct 14, 2023

> It's intended behavior, because experiments show that not dividing is better. I don't know the theory behind it, though.

I see, thanks for the reply.
