Custom Trained OCR model: Mismatch Error in Character Set Size Leading #1226

geheim01 opened this issue Mar 19, 2024 · 1 comment

I've trained an OCR model on a specialized dataset by following the methodology outlined in the README of the deep-text-recognition-benchmark repository. My setup includes the model's architecture defined in my_model.py, alongside the my_model.pth and my_model.yaml files.

Currently, I'm encountering an issue related to the character set used for training, which consists of 44 characters. Specifying that identical character set in the .yaml file triggers the following RuntimeError when initializing the reader with easyocr.Reader(['en'], recog_network='my_model'), pointing to an off-by-one discrepancy in the torch tensor dimensions:

RuntimeError: Error(s) in loading state_dict for Model:
    size mismatch for Prediction.attention_cell.rnn.weight_ih: copying a param with shape torch.Size([1024, 301]) from checkpoint, the shape in current model is torch.Size([1024, 300]).
    size mismatch for Prediction.generator.weight: copying a param with shape torch.Size([45, 256]) from checkpoint, the shape in current model is torch.Size([44, 256]).
    size mismatch for Prediction.generator.bias: copying a param with shape torch.Size([45]) from checkpoint, the shape in current model is torch.Size([44]).
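For context, the off-by-one is consistent with a converter-token mismatch: in deep-text-recognition-benchmark, AttnLabelConverter adds two special tokens ([GO] and the stop token [s]) to the character set, while a CTC-style converter prepends a single blank token, so the prediction head is sized differently depending on which converter built the model. A minimal sketch (token lists assumed to match that repo's utils.py):

```python
# Hedged sketch: how the two converters in deep-text-recognition-benchmark
# size the prediction head (token lists assumed from that repo's utils.py).
def ctc_num_class(charset):
    # CTC-style converters prepend one blank token.
    return len(['[blank]'] + list(charset))

def attn_num_class(charset):
    # AttnLabelConverter adds [GO] and the stop token [s].
    return len(['[GO]', '[s]'] + list(charset))

charset_44 = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGH'  # 44 characters
print(ctc_num_class(charset_44))   # 45
print(attn_num_class(charset_44))  # 46
```

The attention cell's input width is hidden_size + num_class, so the weight_ih mismatch above (301 vs 300, i.e. 256 + 45 vs 256 + 44) tracks the same one-token difference as the generator shapes.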

I found a workaround by incrementing the character set size by one, adding an additional character (including a leading whitespace), which allows the model to generate predictions. However, this adjustment degrades prediction accuracy compared to using my model directly with the deep-text-recognition-benchmark.

The modification causes the model to produce significantly longer strings than expected most of the time. For instance:

  • The image text AA BB 123 is predicted as AA BB 123 123 123 123 using reader.recognize(img, allowlist=allow_list, detail=1)[0].
  • In contrast, direct prediction with my model outputs the correct AA BB 123.

This issue leads me to believe that the adjustment might be interfering with the recognition of the stop character.
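To illustrate the suspicion: an Attn-style decoder keeps emitting symbols until the stop token [s], and the converter then trims everything after it. If the rebuilt model's index for [s] is mapped to an ordinary character instead, nothing gets trimmed and the tail repeats. A hypothetical sketch, modeled loosely on AttnLabelConverter.decode:

```python
# Hypothetical sketch of Attn-style decoding, modeled on
# AttnLabelConverter.decode in deep-text-recognition-benchmark.
def attn_decode(indices, character):
    # Map predicted indices back to symbols, then cut at the stop token.
    text = ''.join(character[i] for i in indices)
    return text.split('[s]')[0]  # everything after '[s]' is discarded

character = ['[GO]', '[s]'] + list('AB 123')
# The model may keep emitting after the intended stop; trimming at '[s]'
# is what keeps the output to the ground-truth string.
indices = [character.index(c) for c in 'AA BB 123'] + [1]  # index 1 == '[s]'
print(attn_decode(indices, character))  # 'AA BB 123'
```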

Could anyone provide insights or suggestions on how to address this problem?

  • Is there a specific change needed for the character set?
  • Could the issue be related to how the stop character is handled?

Any guidance or advice would be greatly appreciated.

@geheim01 geheim01 changed the title Custom Trained OCR model: Mismatch Error in Character Set Size Leading to Inaccurate Predictions with EasyOCR Integration Custom Trained OCR model: Mismatch Error in Character Set Size Leading Mar 19, 2024

geheim01 commented Mar 19, 2024

Could the error be in the CTCLabelConverter, since I use Attn instead of CTC for prediction and accordingly used the AttnLabelConverter from the deep-text-recognition-benchmark repository during training?

I looked into this further, and it seems that the get_recognizer() function called by easyocr.Reader() does not support any label converter other than CTCLabelConverter:

```python
def get_recognizer(recog_network, network_params, character,
                   separator_list, dict_list, model_path,
                   device='cpu', quantize=True):

    converter = CTCLabelConverter(character, separator_list, dict_list)
    num_class = len(converter.character)

    if recog_network == 'generation1':
        model_pkg = importlib.import_module("easyocr.model.model")
    elif recog_network == 'generation2':
        model_pkg = importlib.import_module("easyocr.model.vgg_model")
    else:
        model_pkg = importlib.import_module(recog_network)
    model = model_pkg.Model(num_class=num_class, **network_params)

    if device == 'cpu':
        state_dict = torch.load(model_path, map_location=device)
        new_state_dict = OrderedDict()
        for key, value in state_dict.items():
            new_key = key[7:]
            new_state_dict[new_key] = value
        model.load_state_dict(new_state_dict)
        if quantize:
            try:
                torch.quantization.quantize_dynamic(model, dtype=torch.qint8, inplace=True)
            except:
                pass
```
Can't I just use my TPS-ResNet-BiLSTM-Attn model, trained with the deep-text-recognition-benchmark, out of the box with the easyocr package via easyocr.Reader(recog_network=...)?
If so, which models can I train and use within easyocr?
