Skip to content

Commit

Permalink
fix nan in ppocrv4 for benchmark (PaddlePaddle#14072)
Browse files Browse the repository at this point in the history
* fix nan in ppocrv4 for benchmark

* fix config
  • Loading branch information
wangna11BD authored Oct 23, 2024
1 parent 8327f79 commit 661cda1
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
1 change: 1 addition & 0 deletions configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Architecture:
Head:
name: DBHead
k: 50
fix_nan: True

Loss:
name: DBLoss
Expand Down
1 change: 1 addition & 0 deletions configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_teacher.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Architecture:
name: PFHeadLocal
k: 50
mode: "large"
fix_nan: True


Loss:
Expand Down
8 changes: 5 additions & 3 deletions ppocr/modeling/heads/det_db_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def get_bias_attr(k):


class Head(nn.Layer):
def __init__(self, in_channels, kernel_list=[3, 2, 2], **kwargs):
def __init__(self, in_channels, kernel_list=[3, 2, 2], fix_nan=False, **kwargs):
super(Head, self).__init__()

self.conv1 = nn.Conv2D(
Expand Down Expand Up @@ -73,14 +73,16 @@ def __init__(self, in_channels, kernel_list=[3, 2, 2], **kwargs):
bias_attr=get_bias_attr(in_channels // 4),
)

self.fix_nan = fix_nan

def forward(self, x, return_f=False):
x = self.conv1(x)
x = self.conv_bn1(x)
if self.training:
if self.fix_nan and self.training:
x = paddle.where(paddle.isnan(x), paddle.zeros_like(x), x)
x = self.conv2(x)
x = self.conv_bn2(x)
if self.training:
if self.fix_nan and self.training:
x = paddle.where(paddle.isnan(x), paddle.zeros_like(x), x)
if return_f is True:
f = x
Expand Down

0 comments on commit 661cda1

Please sign in to comment.