diff --git a/.gitignore b/.gitignore deleted file mode 100644 index dfaefa2..0000000 --- a/.gitignore +++ /dev/null @@ -1,17 +0,0 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ - -# Jupyter Notebook -.ipynb_checkpoints - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Environments -.env -.venv - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json \ No newline at end of file diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 261eeb9..0000000 --- a/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. 
For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. 
The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. 
-
-   END OF TERMS AND CONDITIONS
-
-   APPENDIX: How to apply the Apache License to your work.
-
-      To apply the Apache License to your work, attach the following
-      boilerplate notice, with the fields enclosed by brackets "[]"
-      replaced with your own identifying information. (Don't include
-      the brackets!)  The text should be enclosed in the appropriate
-      comment syntax for the file format. We also recommend that a
-      file or class name and description of purpose be included on the
-      same "printed page" as the copyright notice for easier
-      identification within third-party archives.
-
-   Copyright [yyyy] [name of copyright owner]
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-       http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License.
diff --git a/README.md b/README.md
deleted file mode 100644
index cbde2ad..0000000
--- a/README.md
+++ /dev/null
@@ -1,121 +0,0 @@
-## Paper: Towards Diverse and Consistent Typography Generation
-Wataru Shimoda¹, Daichi Haraguchi², Seiichi Uchida², Kota Yamaguchi¹
-¹CyberAgent, Inc., ²Kyushu University
-Accepted to WACV 2024.
-[[Arxiv](https://arxiv.org/abs/2309.02099)]
-[[project-page]()]
-
-## Introduction
-This repository contains the code for ["Towards Diverse and Consistent Typography Generation"](https://arxiv.org/abs/2309.02099).
-
-## Requirements
-We have checked reproducibility under the following environment:
-- Ubuntu (>=20.04)
-- Python3 (>=3.8)
-
-## Install
-Clone this repository and navigate to the folder:
-
-``` sh
-git clone https://github.com/CyberAgentAILab/tdc-typography-generation.git
-cd tdc-typography-generation
-```
-
-We manage the Python dependencies via [pyproject.toml](https://github.com/CyberAgentAILab/tdc-typography-generation/blob/main/pyproject.toml); please install them with pip or Poetry.
-
-If your setuptools version is `>=61.0.0`, the following command installs the dependencies via pip:
-``` sh
-pip install .
-```
-Alternatively, we recommend installing the dependencies with Poetry (see the [official docs](https://python-poetry.org/docs/)):
-``` sh
-poetry install
-```
-Note that we omit the `poetry run` prefix from the commands in the instructions below for brevity.
-
-## Dataset
-Our model is trained and tested on the [Crello dataset](https://huggingface.co/datasets/cyberagent/crello), which is publicly available on [Hugging Face](https://huggingface.co/).
-The dataset can be downloaded through the [Hugging Face Dataset API](https://huggingface.co/docs/datasets/index) (as sketched below), but it does not contain high-resolution background images.
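-For reference, here is a minimal sketch of pulling the dataset with the `datasets` library; the split name and the printed fields are assumptions, so check the dataset card for the actual schema:
-``` python
-from datasets import load_dataset
-
-# Download the Crello dataset from the Hugging Face Hub.
-# Note: the bundled background previews are low-resolution; the
-# high-resolution versions come from the separate archive below.
-dataset = load_dataset("cyberagent/crello", split="train")
-print(len(dataset))       # number of design templates in the split
-print(dataset[0].keys())  # inspect the available fields
-```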
-We provide the background images ([link](https://storage.googleapis.com/ailab-public/tdc_typography_generation/generate_bg_png.tar.gz), 3.6GB).
-Please download the background images and place them under `data/generate_bg_png/`:
-``` sh
-tar zxvf generate_bg_png.tar.gz
-mv generate_bg_png data/
-rm generate_bg_png.tar.gz
-```
-
-We also provide the font files used for rendering designs and computing appropriate text sizes ([link](https://storage.googleapis.com/ailab-public/tdc_typography_generation/font.tar.gz), 43MB).
-Please download the font files and place them under `data/font/`:
-``` sh
-tar zxvf font.tar.gz
-mv font data/
-rm font.tar.gz
-```
-
-## Usage
-We provide scripts for the experiments as follows.
-
-### Preprocessing
-We recommend adding features to the dataset in advance, as this makes training and testing faster.
-We provide a script for preprocessing:
-``` sh
-python -m typography_generation map_features --datadir data
-```
-This script extends the dataset via the [map function](https://huggingface.co/docs/datasets/v2.15.0/en/package_reference/main_classes#datasets.Dataset.map).
-The extended dataset is saved in `data/map_features`, and the `--use_extended_dataset` option controls whether it is used.
-
-### Training
-The following command trains a model; with the preprocessed dataset, it takes about half a day on an NVIDIA T4 machine.
-The details of training are controlled by the configuration files in `data/config/*.yaml`.
-The base configurations are defined in `src/typography_generation/config/*.py`.
-
-``` sh
-python -m typography_generation train_evaluation \
-    --configname bart \
-    --jobdir ${OUTPUT_DIR} \
-    --datadir data \
-    --use_extended_dataset \
-    --gpu
-```
-The outputs are saved in `${OUTPUT_DIR}`.
-
-### Sampling
-The following command samples typographic attributes.
-It requires the `--weight` option, which specifies the path to the weights of a trained model.
-A weight file produced by the above training command is located at `${OUTPUT_DIR}/weight.pth`.
-Please assign the path of a weight file to `${WEIGHT_FILE}` (a weight-loading sketch follows the Contact section below).
-``` sh
-python -m typography_generation structure_preserved_sample \
-    --configname bart \
-    --jobdir ${OUTPUT_DIR} \
-    --datadir data \
-    --weight=${WEIGHT_FILE} \
-    --use_extended_dataset \
-    --gpu
-```
-
-## Visualization
-We provide notebooks for inspecting the results.
-- `notebooks/score.ipynb` shows the scores of the saved results in `${OUTPUT_DIR}`.
-- `notebooks/vis.ipynb` shows the generated graphic designs in `${OUTPUT_DIR}`.
-
-## Reference
-```bibtex
-@misc{shimoda_2024_tdctg,
-  author = {Shimoda, Wataru and Haraguchi, Daichi and Uchida, Seiichi and Yamaguchi, Kota},
-  title = {Towards Diverse and Consistent Typography Generation},
-  publisher = {arXiv:2309.02099},
-  year = {2024},
-}
-```
-
-## Contact
-This repository is maintained by Wataru Shimoda (wataru_shimoda[at]cyberagent.co.jp).
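-
-## Appendix: loading a trained weight file
-As a supplement to the Sampling section, this is a minimal sketch of restoring a saved `weight.pth` in your own script. It mirrors the `loadweight` helper in `src/typography_generation/__main__.py`; the function name and CPU fallback follow that helper, so treat it as illustrative rather than a public API:
-``` python
-import torch
-from torch import nn
-
-def load_weight(model: nn.Module, weight_file: str, gpu: bool = False) -> nn.Module:
-    """Restore trained parameters, loading onto CPU when no GPU is available."""
-    if weight_file:
-        map_location = None if gpu else torch.device("cpu")
-        state_dict = torch.load(weight_file, map_location=map_location)
-        model.load_state_dict(state_dict)
-    return model
-```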
\ No newline at end of file diff --git a/data/cluster/canvas_aspect_ratio_16.pkl b/data/cluster/canvas_aspect_ratio_16.pkl deleted file mode 100644 index 189ec37..0000000 Binary files a/data/cluster/canvas_aspect_ratio_16.pkl and /dev/null differ diff --git a/data/cluster/text_angle_16.pkl b/data/cluster/text_angle_16.pkl deleted file mode 100644 index b1882c8..0000000 Binary files a/data/cluster/text_angle_16.pkl and /dev/null differ diff --git a/data/cluster/text_center_x_64.pkl b/data/cluster/text_center_x_64.pkl deleted file mode 100644 index ed9e334..0000000 Binary files a/data/cluster/text_center_x_64.pkl and /dev/null differ diff --git a/data/cluster/text_center_y_64.pkl b/data/cluster/text_center_y_64.pkl deleted file mode 100644 index 0eec0ea..0000000 Binary files a/data/cluster/text_center_y_64.pkl and /dev/null differ diff --git a/data/cluster/text_font_color_64.pkl b/data/cluster/text_font_color_64.pkl deleted file mode 100644 index 5ed7216..0000000 Binary files a/data/cluster/text_font_color_64.pkl and /dev/null differ diff --git a/data/cluster/text_font_size_16.pkl b/data/cluster/text_font_size_16.pkl deleted file mode 100644 index 3cea5bf..0000000 Binary files a/data/cluster/text_font_size_16.pkl and /dev/null differ diff --git a/data/cluster/text_height_16.pkl b/data/cluster/text_height_16.pkl deleted file mode 100644 index d00a399..0000000 Binary files a/data/cluster/text_height_16.pkl and /dev/null differ diff --git a/data/cluster/text_left_64.pkl b/data/cluster/text_left_64.pkl deleted file mode 100644 index 5e4dc1c..0000000 Binary files a/data/cluster/text_left_64.pkl and /dev/null differ diff --git a/data/cluster/text_letter_spacing_16.pkl b/data/cluster/text_letter_spacing_16.pkl deleted file mode 100644 index d0f2e6e..0000000 Binary files a/data/cluster/text_letter_spacing_16.pkl and /dev/null differ diff --git a/data/cluster/text_line_height_scale_16.pkl b/data/cluster/text_line_height_scale_16.pkl deleted file mode 100644 index 983bde4..0000000 Binary files a/data/cluster/text_line_height_scale_16.pkl and /dev/null differ diff --git a/data/cluster/text_top_64.pkl b/data/cluster/text_top_64.pkl deleted file mode 100644 index e097126..0000000 Binary files a/data/cluster/text_top_64.pkl and /dev/null differ diff --git a/data/cluster/text_width_16.pkl b/data/cluster/text_width_16.pkl deleted file mode 100644 index bedb577..0000000 Binary files a/data/cluster/text_width_16.pkl and /dev/null differ diff --git a/data/config/bart/canvas_embedding_attribute_config.yaml b/data/config/bart/canvas_embedding_attribute_config.yaml deleted file mode 100644 index e69de29..0000000 diff --git a/data/config/bart/canvas_embedding_flag_config.yaml b/data/config/bart/canvas_embedding_flag_config.yaml deleted file mode 100644 index 867fe77..0000000 --- a/data/config/bart/canvas_embedding_flag_config.yaml +++ /dev/null @@ -1,5 +0,0 @@ -canvas_embedding_flag: - canvas_bg_img: False - canvas_bg_img_emb: True - canvas_aspect_ratio: True - canvas_text_num: True diff --git a/data/config/bart/data_config.yaml b/data/config/bart/data_config.yaml deleted file mode 100644 index 097add5..0000000 --- a/data/config/bart/data_config.yaml +++ /dev/null @@ -1,8 +0,0 @@ -data_config: - font_num: 261 - large_spatio_bin: 64 - small_spatio_bin: 16 - color_bin: 64 - max_text_line_count: 50 - font_emb_type: "label" - order_type: "rasterscan_asc" diff --git a/data/config/bart/metainfo.yaml b/data/config/bart/metainfo.yaml deleted file mode 100644 index 91da1d4..0000000 --- a/data/config/bart/metainfo.yaml +++ 
/dev/null @@ -1,3 +0,0 @@ -meta_info: - model_name: "bart" - dataset: "crello" diff --git a/data/config/bart/model_config.yaml b/data/config/bart/model_config.yaml deleted file mode 100644 index 6f5a014..0000000 --- a/data/config/bart/model_config.yaml +++ /dev/null @@ -1,9 +0,0 @@ -model_config: - d_model: 256 - num_encoder_layers: 8 - num_decoder_layers: 8 - n_head: 8 - dropout: 0.1 - bert_dim: 768 - clip_dim: 512 - mlp_dim: 1792 diff --git a/data/config/bart/test_config.yaml b/data/config/bart/test_config.yaml deleted file mode 100644 index 1831047..0000000 --- a/data/config/bart/test_config.yaml +++ /dev/null @@ -1,4 +0,0 @@ -test_config: - sampling_param: 0.9 - sampling_param_geometry: 0.1 - sampling_param_semantic: 0.9 diff --git a/data/config/bart/text_element_embedding_attribute_config.yaml b/data/config/bart/text_element_embedding_attribute_config.yaml deleted file mode 100644 index e69de29..0000000 diff --git a/data/config/bart/text_element_embedding_flag_config.yaml b/data/config/bart/text_element_embedding_flag_config.yaml deleted file mode 100644 index 119d966..0000000 --- a/data/config/bart/text_element_embedding_flag_config.yaml +++ /dev/null @@ -1,19 +0,0 @@ -text_element_embedding_flag: - text_font: False - text_font_size: False - text_font_color: False - text_angle: False - text_letter_spacing: False - text_line_height_scale: False - text_capitalize: False - text_line_count: True - text_char_count: True - text_height: False - text_width: False - text_top: False - text_left: False - text_center_y: True - text_center_x: True - text_align_type: False - text_emb: True - text_local_img_emb: True diff --git a/data/config/bart/text_element_prediction_attribute_config.yaml b/data/config/bart/text_element_prediction_attribute_config.yaml deleted file mode 100644 index e69de29..0000000 diff --git a/data/config/bart/text_element_prediction_flag_config.yaml b/data/config/bart/text_element_prediction_flag_config.yaml deleted file mode 100644 index dda973f..0000000 --- a/data/config/bart/text_element_prediction_flag_config.yaml +++ /dev/null @@ -1,9 +0,0 @@ -text_element_prediction_flag: - text_font: True - text_font_size: True - text_font_color: True - text_angle: True - text_letter_spacing: True - text_line_height_scale: True - text_capitalize: True - text_align_type: True diff --git a/data/config/bart/train_config.yaml b/data/config/bart/train_config.yaml deleted file mode 100644 index 12780fd..0000000 --- a/data/config/bart/train_config.yaml +++ /dev/null @@ -1,7 +0,0 @@ -train_config: - epochs: 31 - save_epoch: 5 - batch_size: 32 - num_worker: 2 - train_only: False - show_interval: 100 diff --git a/data/font2ttf.pkl b/data/font2ttf.pkl deleted file mode 100644 index 9903ce6..0000000 Binary files a/data/font2ttf.pkl and /dev/null differ diff --git a/data/fonttype2fontid_fix.pkl b/data/fonttype2fontid_fix.pkl deleted file mode 100644 index f2ea1a4..0000000 Binary files a/data/fonttype2fontid_fix.pkl and /dev/null differ diff --git a/data/svgid2scaleinfo.pkl b/data/svgid2scaleinfo.pkl deleted file mode 100644 index 68e0d4e..0000000 Binary files a/data/svgid2scaleinfo.pkl and /dev/null differ diff --git a/notebooks/score.ipynb b/notebooks/score.ipynb deleted file mode 100644 index 8363651..0000000 --- a/notebooks/score.ipynb +++ /dev/null @@ -1,114 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "18.186068109361116\n", - "59.10895250864158\n", - 
"89.43196472690855\n", - "99.98559492941517\n", - "29.720623040124977\n", - "0.381246451385764\n", - "2.807381701892659\n", - "0.07559504299962953\n" - ] - } - ], - "source": [ - "import os\n", - "from matplotlib import pylab as plt\n", - "import numpy as np\n", - "import statistics\n", - "import math\n", - "\n", - "def get_score(prefix, score_index, scalar):\n", - " tmppath = f\"{prefix}/logs/score.txt\"\n", - " with open(tmppath) as f:\n", - " scores = f.readlines()\n", - " score = scores[score_index].replace('\\n', '').split(\" \")[-2]\n", - " score = float(score)*scalar\n", - " return score\n", - "\n", - "def get_structure_score(prefix, score_index,scalar,type=None):\n", - " gcspath = f\"gs://ailab-ephemeral/s09275/job_gcp_eval/{prefix}/logs/score.txt\"\n", - " _prefix = prefix.split(\"/\")[-1]\n", - " tmppath = f\"tmp_data/dl/{_prefix}.txt\"\n", - " com = f\"gsutil cp {gcspath} {tmppath}\"\n", - " if os.path.isfile(tmppath) is False:\n", - " os.system(com)\n", - " with open(tmppath) as f:\n", - " scores = f.readlines()\n", - " score0 = scores[score_index].replace('\\n', '').split(\" \")[-2]\n", - " score0 = float(score0)*scalar\n", - " count0 = int(scores[score_index].replace('\\n', '').split(\" \")[-3].split(\":\")[-1])\n", - " score1 = scores[score_index+1].replace('\\n', '').split(\" \")[-2]\n", - " score1 = float(score1)*scalar\n", - " count1 = int(scores[score_index+1].replace('\\n', '').split(\" \")[-3].split(\":\")[-1])\n", - " score2 = scores[score_index+2].replace('\\n', '').split(\" \")[-2]\n", - " score2 = float(score2)*scalar\n", - " count2 = int(scores[score_index+2].replace('\\n', '').split(\" \")[-3].split(\":\")[-1])\n", - " if type is None:\n", - " count = float(count0+count1+count2)\n", - " score = score0*count0/count + score1*count1/count + score2*count2/count\n", - " #print(count0,count1,count2)\n", - " elif type==0:\n", - " score = score0\n", - " elif type==1:\n", - " score = score1\n", - " elif type==2:\n", - " score = score2\n", - " return score\n", - "\n", - "output_dir = \"tmp\" # please set your output directory here\n", - "tarscore2index1={\n", - " \"font_score\": (1,100,\"a\"),\n", - " \"font_color_score\": (25,1,\"a\"),\n", - " \"font_capitalize_score\": (9,100,\"a\"),\n", - " \"font_align_score\": (5,100,\"a\"),\n", - " \"font_size_score\": (13,1,\"a\"),\n", - " \"font_angle_score\": (19,1,\"b\"),\n", - " \"font_letter_score\": (16,1,\"b\"),\n", - " \"font_lline_height_score\": (22,1,\"c\"),\n", - "}\n", - "for att, (index,k, type) in tarscore2index1.items():\n", - "\n", - " font_score = get_score(output_dir, index,k)\n", - " print(font_score)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.1" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index 30a21d2..0000000 --- a/pyproject.toml +++ /dev/null @@ -1,50 +0,0 @@ -[tool.pysen] -version = "0.9" - -[tool.pysen.lint] -enable_black = true -enable_flake8 = true -enable_isort = true -enable_mypy = true -mypy_preset = "strict" -line_length = 88 -[tool.pysen.lint.source] - excludes = [".venv/",".git/", 
".pytest_cache/", ".python-version/","data","tmp_data/"] -[[tool.pysen.lint.mypy_targets]] - paths = ["."] - -[tool.poetry] -name = "typography-generation" -version = "0.1.0" -description = "" -authors = ["shimoda-uec "] -readme = "README.md" -packages = [{include = "typography_generation", from = "src"}] - - -[tool.poetry.dependencies] -python = "^3.9" -skia-python = "^87.5" -einops = "^0.6.1" -hydra-core = "^1.3.2" -logzero = "^1.7.0" -datasets = "^2.12.0" -torch = "^1.13" -scikit-learn = "^1.0" -pytest = "^7.3.1" -pillow = "9.0.1" -matplotlib = "3.5" -transformers = "4.30.2" -openpyxl = "^3.1.2" -tensorboard = "^2.14.1" -gcsfs = "^2023.9.2" -seam-carving = "^1.1.0" - - -[tool.poetry.group.dev.dependencies] -jupyter = "^1.0.0" -notebook = "^6.5.4" - -[build-system] -requires = ["poetry-core"] -build-backend = "poetry.core.masonry.api" diff --git a/src/typography_generation/__init__.py b/src/typography_generation/__init__.py deleted file mode 100644 index 8b13789..0000000 --- a/src/typography_generation/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/typography_generation/__main__.py b/src/typography_generation/__main__.py deleted file mode 100644 index 6303516..0000000 --- a/src/typography_generation/__main__.py +++ /dev/null @@ -1,415 +0,0 @@ -import argparse -import logging -import os -from typing import Any, Dict, Tuple - -import logzero -import torch -from logzero import logger -from typography_generation.config.config_args_util import ( - args2add_data_inputs, - get_global_config, - get_global_config_input, - get_model_config_input, - get_prefix_lists, - get_sampling_config, - get_train_config_input, -) -from typography_generation.config.default import ( - get_datapreprocess_config, - get_font_config, -) -from typography_generation.io.build_dataset import ( - build_test_dataset, - build_train_dataset, -) -from typography_generation.model.model import create_model - -from typography_generation.preprocess.map_features import map_features - -from typography_generation.tools.evaluator import Evaluator - -from typography_generation.tools.sampler import Sampler -from typography_generation.tools.structure_preserved_sampler import ( - StructurePreservedSampler, -) -from typography_generation.tools.train import Trainer - - -def get_save_dir(job_dir: str) -> str: - save_dir = os.path.join(job_dir, "logs") - os.makedirs(save_dir, exist_ok=True) - return save_dir - - -def make_logfile(job_dir: str, debug: bool = False) -> None: - if debug is True: - logzero.loglevel(logging.DEBUG) - else: - logzero.loglevel(logging.INFO) - - os.makedirs(job_dir, exist_ok=True) - file_name = f"{job_dir}/log.log" - logzero.logfile(file_name) - - -def train(args: Any) -> None: - logger.info("training") - data_dir = args.datadir - global_config_input = get_global_config_input(data_dir, args) - config = get_global_config(**global_config_input) - gpu = args.gpu - model_name, model_kwargs = get_model_config_input(config) - - logger.info("model creation") - model = create_model( - model_name, - **model_kwargs, - ) - - logger.info(f"log file location {args.jobdir}/log.log") - make_logfile(args.jobdir, args.debug) - save_dir = get_save_dir(args.jobdir) - logger.info(f"save_dir {save_dir}") - prefix_list_object = get_prefix_lists(config) - font_config = get_font_config(config) - prediction_config_element = config.text_element_prediction_attribute_config - - train_kwargs = get_train_config_input(config, args.debug) - logger.info(f"build trainer") - dataset, dataset_val = build_train_dataset( - data_dir, - 
prefix_list_object, - font_config, - use_extended_dataset=args.use_extended_dataset, - debug=args.debug, - ) - - model_trainer = Trainer( - model, - gpu, - save_dir, - dataset, - dataset_val, - prefix_list_object, - prediction_config_element, - **train_kwargs, - ) - logger.info("start training") - model_trainer.train_model() - - -def train_eval(args: Any) -> None: - logger.info("training") - data_dir = args.datadir - global_config_input = get_global_config_input(data_dir, args) - config = get_global_config(**global_config_input) - gpu = args.gpu - model_name, model_kwargs = get_model_config_input(config) - - logger.info("model creation") - model = create_model( - model_name, - **model_kwargs, - ) - - logger.info(f"log file location {args.jobdir}/log.log") - make_logfile(args.jobdir, args.debug) - save_dir = get_save_dir(args.jobdir) - logger.info(f"save_dir {save_dir}") - prefix_list_object = get_prefix_lists(config) - font_config = get_font_config(config) - prediction_config_element = config.text_element_prediction_attribute_config - - train_kwargs = get_train_config_input(config, args.debug) - logger.info(f"build trainer") - - dataset, dataset_val = build_train_dataset( - data_dir, - prefix_list_object, - font_config, - use_extended_dataset=args.use_extended_dataset, - debug=args.debug, - ) - - model_trainer = Trainer( - model, - gpu, - save_dir, - dataset, - dataset_val, - prefix_list_object, - prediction_config_element, - **train_kwargs, - ) - logger.info("start training") - model_trainer.train_model() - - dataset = build_test_dataset( - data_dir, - prefix_list_object, - font_config, - use_extended_dataset=args.use_extended_dataset, - debug=args.debug, - ) - - evaluator = Evaluator( - model, - gpu, - save_dir, - dataset, - prefix_list_object, - debug=args.debug, - ) - logger.info("start evaluation") - evaluator.eval_model() - - -# def _train_font_embedding(args: Any) -> None: -# train_font_embedding(args.datadir, args.jobdir, args.gpu) - - -def loadweight(weight_file: Any, gpu: bool, model: Any) -> Any: - if weight_file == "": - pass - else: - if gpu is False: - state_dict = torch.load(weight_file, map_location=torch.device("cpu")) - else: - state_dict = torch.load(weight_file) - model.load_state_dict(state_dict) - return model - - -def sample(args: Any) -> None: - data_dir = args.datadir - global_config_input = get_global_config_input(data_dir, args) - config = get_global_config(**global_config_input) - gpu = args.gpu - model_name, model_kwargs = get_model_config_input(config) - - logger.info("model creation") - model = create_model( - model_name, - **model_kwargs, - ) - weight = args.weight - model = loadweight(weight, gpu, model) - - logger.info(f"log file location {args.jobdir}/log.log") - make_logfile(args.jobdir, args.debug) - save_dir = get_save_dir(args.jobdir) - logger.info(f"save_dir {save_dir}") - prefix_list_object = get_prefix_lists(config) - font_config = get_font_config(config) - sampling_config = get_sampling_config(config) - - dataset = build_test_dataset( - data_dir, - prefix_list_object, - font_config, - use_extended_dataset=args.use_extended_dataset, - debug=args.debug, - ) - - sampler = Sampler( - model, - gpu, - save_dir, - dataset, - prefix_list_object, - sampling_config, - debug=args.debug, - ) - logger.info("start sampling") - sampler.sample() - - -def structure_preserved_sample(args: Any) -> None: - data_dir = args.datadir - global_config_input = get_global_config_input(data_dir, args) - config = get_global_config(**global_config_input) - gpu = args.gpu - 
model_name, model_kwargs = get_model_config_input(config) - - logger.info("model creation") - model = create_model( - model_name, - **model_kwargs, - ) - weight = args.weight - model = loadweight(weight, gpu, model) - - logger.info(f"log file location {args.jobdir}/log.log") - make_logfile(args.jobdir, args.debug) - save_dir = get_save_dir(args.jobdir) - logger.info(f"save_dir {save_dir}") - prefix_list_object = get_prefix_lists(config) - font_config = get_font_config(config) - sampling_config = get_sampling_config(config) - - dataset = build_test_dataset( - data_dir, - prefix_list_object, - font_config, - use_extended_dataset=args.use_extended_dataset, - debug=args.debug, - ) - - sampler = StructurePreservedSampler( - model, - gpu, - save_dir, - dataset, - prefix_list_object, - sampling_config, - debug=args.debug, - ) - logger.info("start sampling") - sampler.sample() - - -def evaluation_pattern(args: Any, prefix: str, evaluation_class: Any) -> None: - logger.info(f"{prefix}") - data_dir = args.datadir - global_config_input = get_global_config_input(data_dir, args) - config = get_global_config(**global_config_input) - gpu = args.gpu - model_name, model_kwargs = get_model_config_input(config) - - logger.info("model creation") - model = create_model( - model_name, - **model_kwargs, - ) - weight = args.weight - model = loadweight(weight, gpu, model) - - logger.info(f"log file location {args.jobdir}/log.log") - make_logfile(args.jobdir, args.debug) - save_dir = get_save_dir(args.jobdir) - logger.info(f"save_dir {save_dir}") - prefix_list_object = get_prefix_lists(config) - font_config = get_font_config(config) - - dataset = build_test_dataset( - data_dir, - prefix_list_object, - font_config, - use_extended_dataset=args.use_extended_dataset, - debug=args.debug, - ) - evaluator = evaluation_class( - model, - gpu, - save_dir, - dataset, - prefix_list_object, - dataset_division="test", - debug=args.debug, - ) - logger.info("start evaluation") - evaluator.eval_model() - - -def evaluation(args: Any) -> None: - evaluation_pattern(args, "evaluation", Evaluator) - - -def _map_features(args: Any) -> None: - logger.info(f"map_features") - map_features(args.datadir) - - -COMMANDS = { - "train": train, - "train_evaluation": train_eval, - "sample": sample, - "structure_preserved_sample": structure_preserved_sample, - "evaluation": evaluation, - "map_features": _map_features, -} - - -if __name__ == "__main__": - logger.info("start") - parser = argparse.ArgumentParser() - parser.add_argument("command", help="job option") - parser.add_argument( - "--configname", - type=str, - default="bart", - help="config option", - ) - parser.add_argument( - "--testconfigname", - type=str, - default="test_config", - help="test config option", - ) - parser.add_argument( - "--modelconfigname", - type=str, - default="model_config", - help="test config option", - ) - parser.add_argument( - "--canvasembeddingflagconfigname", - type=str, - default="canvas_embedding_flag_config", - help="canvas embedding flag config option", - ) - parser.add_argument( - "--elementembeddingflagconfigname", - type=str, - default="text_element_embedding_flag_config", - help="element embedding flag config option", - ) - parser.add_argument( - "--elementpredictionflagconfigname", - type=str, - default="text_element_prediction_flag_config", - help="element prediction flag config option", - ) - parser.add_argument( - "--datadir", - type=str, - default="data", - help="data location", - ) - parser.add_argument( - "--jobdir", - type=str, - default=".", - 
help="results location", - ) - parser.add_argument( - "--job-dir", - type=str, - default=".", - help="dummy", - ) - parser.add_argument( - "--weight", - type=str, - default="", - help="weight file location", - ) - parser.add_argument( - "--gpu", - action="store_true", - help="gpu option", - ) - parser.add_argument( - "--use_extended_dataset", - action="store_true", - help="dataset option", - ) - parser.add_argument( - "--debug", - action="store_true", - help="debug option", - ) - args = parser.parse_args() - module = COMMANDS[args.command] - module(args) diff --git a/src/typography_generation/config/__init__.py b/src/typography_generation/config/__init__.py deleted file mode 100644 index 8b13789..0000000 --- a/src/typography_generation/config/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/typography_generation/config/attribute_config.py b/src/typography_generation/config/attribute_config.py deleted file mode 100644 index 8e15c20..0000000 --- a/src/typography_generation/config/attribute_config.py +++ /dev/null @@ -1,533 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Union - - -@dataclass -class EmbeddingConfig: - flag: bool = True - inp_space: int = 256 - input_prefix: str = "prefix" - emb_layer: Union[str, None] = "nn.Embedding" - emb_layer_kwargs: Any = None - specific_build: Union[str, None] = None - specific_func: Union[str, None] = None - - -@dataclass -class PredictionConfig: - flag: bool = True - out_dim: int = 256 - layer: Union[str, None] = "nn.Linear" - loss_type: str = "cre" - loss_weight: float = 1.0 - ignore_label: int = -1 - decode_format: str = "cl" - att_type: str = "semantic" - - -@dataclass -class TextElementContextPredictionAttributeConfig: - text_font: PredictionConfig = PredictionConfig( - "${text_element_prediction_flag.text_font}", - "${data_config.font_num}", - "nn.Linear", - "cre", - 1.0, - -1, - "cl", - "semantic", - ) - text_font_emb: PredictionConfig = PredictionConfig( - "${text_element_prediction_flag.text_font_emb}", - "${data_config.font_emb_dim}", - "nn.Linear", - "mfc_gan", - 1.0, - -10000, - "emb", - "semantic", - ) - text_font_size: PredictionConfig = PredictionConfig( - "${text_element_prediction_flag.text_font_size}", - "${data_config.small_spatio_bin}", - "nn.Linear", - "cre", - 1.0, - -1, - "cl", - "geometry", - ) - text_font_size_raw: PredictionConfig = PredictionConfig( - "${text_element_prediction_flag.text_font_size_raw}", - "1", - "nn.Linear", - "l1", - 1.0, - -1, - "scalar", - "geometry", - ) - text_font_color: PredictionConfig = PredictionConfig( - "${text_element_prediction_flag.text_font_color}", - "${data_config.color_bin}", - "nn.Linear", - "cre", - 1.0, - -1, - "cl", - "semantic", - ) - text_angle: PredictionConfig = PredictionConfig( - "${text_element_prediction_flag.text_angle}", - "${data_config.small_spatio_bin}", - "nn.Linear", - "cre", - 1.0, - -1, - "cl", - "geometry", - ) - text_letter_spacing: PredictionConfig = PredictionConfig( - "${text_element_prediction_flag.text_letter_spacing}", - "${data_config.small_spatio_bin}", - "nn.Linear", - "cre", - 1.0, - -1, - "cl", - "geometry", - ) - text_line_height_scale: PredictionConfig = PredictionConfig( - "${text_element_prediction_flag.text_line_height_scale}", - "${data_config.small_spatio_bin}", - "nn.Linear", - "cre", - 1.0, - -1, - "cl", - "geometry", - ) - text_line_height_size: PredictionConfig = PredictionConfig( - "${text_element_prediction_flag.text_line_height_size}", - "${data_config.small_spatio_bin}", - "nn.Linear", - "cre", - 1.0, - -1, - 
"cl", - "geometry", - ) - text_capitalize: PredictionConfig = PredictionConfig( - "${text_element_prediction_flag.text_capitalize}", - 2, - "nn.Linear", - "cre", - 1.0, - -1, - "cl", - "semantic", - ) - text_align_type: PredictionConfig = PredictionConfig( - "${text_element_prediction_flag.text_align_type}", - 3, - "nn.Linear", - "cre", - 1.0, - -1, - "cl", - "semantic", - ) - text_center_y: PredictionConfig = PredictionConfig( - "${text_element_prediction_flag.text_center_y}", - "${data_config.large_spatio_bin}", - "nn.Linear", - "cre", - 1.0, - -1, - "cl", - "geometry", - ) - text_center_x: PredictionConfig = PredictionConfig( - "${text_element_prediction_flag.text_center_x}", - "${data_config.large_spatio_bin}", - "nn.Linear", - "cre", - 1.0, - -1, - "cl", - "geometry", - ) - text_height: PredictionConfig = PredictionConfig( - "${text_element_prediction_flag.text_height}", - "${data_config.small_spatio_bin}", - "nn.Linear", - "cre", - 1.0, - -1, - "cl", - "geometry", - ) - text_width: PredictionConfig = PredictionConfig( - "${text_element_prediction_flag.text_width}", - "${data_config.large_spatio_bin}", - "nn.Linear", - "cre", - 1.0, - -1, - "cl", - "geometry", - ) - - -@dataclass -class TextElementContextEmbeddingAttributeConfig: - text_font: EmbeddingConfig = EmbeddingConfig( - "${text_element_embedding_flag.text_font}", - "${data_config.font_num}", - "text_font", - "nn.Embedding", - { - "num_embeddings": "${data_config.font_num}", - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - text_font_size: EmbeddingConfig = EmbeddingConfig( - "${text_element_embedding_flag.text_font_size}", - "${data_config.small_spatio_bin}", - "text_font_size", - "nn.Embedding", - { - "num_embeddings": "${data_config.small_spatio_bin}", - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - text_font_size_raw: EmbeddingConfig = EmbeddingConfig( - "${text_element_embedding_flag.text_font_size_raw}", - "${data_config.small_spatio_bin}", - "text_font_size", - "nn.Embedding", - { - "num_embeddings": "${data_config.small_spatio_bin}", - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - text_font_color: EmbeddingConfig = EmbeddingConfig( - "${text_element_embedding_flag.text_font_color}", - "${data_config.color_bin}", - "text_font_color", - "nn.Embedding", - { - "num_embeddings": "${data_config.color_bin}", - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - text_line_count: EmbeddingConfig = EmbeddingConfig( - "${text_element_embedding_flag.text_line_count}", - "${data_config.max_text_line_count}", - "text_line_count", - "nn.Embedding", - { - "num_embeddings": "${data_config.max_text_line_count}", - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - text_char_count: EmbeddingConfig = EmbeddingConfig( - "${text_element_embedding_flag.text_char_count}", - "${data_config.max_text_char_count}", - "text_char_count", - "nn.Embedding", - { - "num_embeddings": "${data_config.max_text_char_count}", - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - text_height: EmbeddingConfig = EmbeddingConfig( - "${text_element_embedding_flag.text_height}", - "${data_config.large_spatio_bin}", - "text_height", - "nn.Embedding", - { - "num_embeddings": "${data_config.large_spatio_bin}", - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - text_width: EmbeddingConfig = EmbeddingConfig( - "${text_element_embedding_flag.text_width}", - "${data_config.large_spatio_bin}", - "text_width", - 
"nn.Embedding", - { - "num_embeddings": "${data_config.large_spatio_bin}", - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - text_top: EmbeddingConfig = EmbeddingConfig( - "${text_element_embedding_flag.text_top}", - "${data_config.large_spatio_bin}", - "text_top", - "nn.Embedding", - { - "num_embeddings": "${data_config.large_spatio_bin}", - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - text_left: EmbeddingConfig = EmbeddingConfig( - "${text_element_embedding_flag.text_left}", - "${data_config.large_spatio_bin}", - "text_left", - "nn.Embedding", - { - "num_embeddings": "${data_config.large_spatio_bin}", - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - text_center_y: EmbeddingConfig = EmbeddingConfig( - "${text_element_embedding_flag.text_center_y}", - "${data_config.large_spatio_bin}", - "text_center_y", - "nn.Embedding", - { - "num_embeddings": "${data_config.large_spatio_bin}", - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - text_center_x: EmbeddingConfig = EmbeddingConfig( - "${text_element_embedding_flag.text_center_x}", - "${data_config.large_spatio_bin}", - "text_center_x", - "nn.Embedding", - { - "num_embeddings": "${data_config.large_spatio_bin}", - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - text_align_type: EmbeddingConfig = EmbeddingConfig( - "${text_element_embedding_flag.text_align_type}", - "3", - "text_align_type", - "nn.Embedding", - { - "num_embeddings": "3", - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - text_angle: EmbeddingConfig = EmbeddingConfig( - "${text_element_embedding_flag.text_angle}", - "${data_config.small_spatio_bin}", - "text_angle", - "nn.Embedding", - { - "num_embeddings": "${data_config.small_spatio_bin}", - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - text_letter_spacing: EmbeddingConfig = EmbeddingConfig( - "${text_element_embedding_flag.text_letter_spacing}", - "${data_config.small_spatio_bin}", - "text_letter_spacing", - "nn.Embedding", - { - "num_embeddings": "${data_config.small_spatio_bin}", - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - text_line_height_scale: EmbeddingConfig = EmbeddingConfig( - "${text_element_embedding_flag.text_line_height_scale}", - "${data_config.small_spatio_bin}", - "text_line_height_scale", - "nn.Embedding", - { - "num_embeddings": "${data_config.small_spatio_bin}", - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - text_line_height_size: EmbeddingConfig = EmbeddingConfig( - "${text_element_embedding_flag.text_line_height_size}", - "${data_config.small_spatio_bin}", - "text_line_height_size", - "nn.Embedding", - { - "num_embeddings": "${data_config.small_spatio_bin}", - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - - text_capitalize: EmbeddingConfig = EmbeddingConfig( - "${text_element_embedding_flag.text_capitalize}", - "2", - "text_capitalize", - "nn.Embedding", - { - "num_embeddings": "2", - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - text_emb: EmbeddingConfig = EmbeddingConfig( - "${text_element_embedding_flag.text_emb}", - "${model_config.clip_dim}", - "text_emb", - "nn.Linear", - { - "in_features": "${model_config.clip_dim}", - "out_features": "${model_config.d_model}", - }, - None, - None, - ) - text_local_img_emb: EmbeddingConfig = EmbeddingConfig( - "${text_element_embedding_flag.text_local_img_emb}", - "${model_config.clip_dim}", - 
"text_local_img_emb", - "nn.Linear", - { - "in_features": "${model_config.clip_dim}", - "out_features": "${model_config.d_model}", - }, - None, - None, - ) - - -@dataclass -class CanvasContextEmbeddingAttributeConfig: - canvas_bg_img: EmbeddingConfig = EmbeddingConfig( - "${canvas_embedding_flag.canvas_bg_img}", - 2048, - "canvas_bg_img", - None, - None, - "build_resnet_feat_extractor", - "get_feat", - ) - canvas_aspect_ratio: EmbeddingConfig = EmbeddingConfig( - "${canvas_embedding_flag.canvas_aspect_ratio}", - "${data_config.small_spatio_bin}", - "canvas_aspect_ratio", - "nn.Embedding", - { - "num_embeddings": "${data_config.small_spatio_bin}", - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - canvas_text_num: EmbeddingConfig = EmbeddingConfig( - "${canvas_embedding_flag.canvas_text_num}", - 50, - "canvas_text_num", - "nn.Embedding", - { - "num_embeddings": 50, - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - canvas_group: EmbeddingConfig = EmbeddingConfig( - "${canvas_embedding_flag.canvas_group}", - 6, - "canvas_group", - "nn.Embedding", - { - "num_embeddings": 6, - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - canvas_format: EmbeddingConfig = EmbeddingConfig( - "${canvas_embedding_flag.canvas_format}", - 67, - "canvas_format", - "nn.Embedding", - { - "num_embeddings": 67, - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - canvas_category: EmbeddingConfig = EmbeddingConfig( - "${canvas_embedding_flag.canvas_category}", - 23, - "canvas_category", - "nn.Embedding", - { - "num_embeddings": 23, - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - canvas_height: EmbeddingConfig = EmbeddingConfig( - "${canvas_embedding_flag.canvas_height}", - 46, - "canvas_height", - "nn.Embedding", - { - "num_embeddings": 46, - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - canvas_width: EmbeddingConfig = EmbeddingConfig( - "${canvas_embedding_flag.canvas_width}", - 41, - "canvas_width", - "nn.Embedding", - { - "num_embeddings": 41, - "embedding_dim": "${model_config.d_model}", - }, - None, - None, - ) - canvas_bg_img_emb: EmbeddingConfig = EmbeddingConfig( - "${canvas_embedding_flag.canvas_bg_img_emb}", - "${model_config.clip_dim}", - "canvas_bg_img_emb", - "nn.Linear", - { - "in_features": "${model_config.clip_dim}", - "out_features": "${model_config.d_model}", - }, - None, - "canvas_bg_img_emb_layer", - ) diff --git a/src/typography_generation/config/base_config_object.py b/src/typography_generation/config/base_config_object.py deleted file mode 100644 index ece0512..0000000 --- a/src/typography_generation/config/base_config_object.py +++ /dev/null @@ -1,147 +0,0 @@ -from dataclasses import dataclass -from typography_generation.config.attribute_config import ( - CanvasContextEmbeddingAttributeConfig, - TextElementContextEmbeddingAttributeConfig, - TextElementContextPredictionAttributeConfig, -) - - -@dataclass -class MetaInfo: - model_name: str = "bart" - dataset: str = "crello" - data_dir: str = "crello" - - -@dataclass -class DataConfig: - font_num: int = 288 - large_spatio_bin: int = 64 - small_spatio_bin: int = 16 - color_bin: int = 64 - max_text_char_count: int = 50 - max_text_line_count: int = 50 - font_emb_type: str = "label" - font_emb_weight: float = 1.0 - font_emb_dim: int = 40 - font_emb_name: str = "mfc" - order_type: str = "raster_scan_order" - seq_length: int = 50 - - -@dataclass -class ModelConfig: - d_model: int = 256 - num_encoder_layers: int = 4 - 
num_decoder_layers: int = 4 - n_head: int = 8 - dropout: float = 0.1 - bert_dim: int = 768 - clip_dim: int = 512 - mlp_dim: int = 3584 - mfc_dim: int = 3584 - bypass: bool = True - seq_length: int = 50 - std_ratio: float = 1.0 - - -@dataclass -class TrainConfig: - epochs: int = 31 - save_epoch: int = 5 - batch_size: int = 32 - num_worker: int = 2 - train_only: bool = False - show_interval: int = 100 - learning_rate: float = 0.0002 - weight_decay: float = 0.01 - optimizer: str = "adam" - - -@dataclass -class TestConfig: - autoregressive_prediction: bool = False - sampling_mode: str = "topp" - sampling_param: float = 0 - sampling_param_geometry: float = 0 - sampling_param_semantic: float = 0 - sampling_num: int = 10 - - -@dataclass -class TextElementEmbeddingFlag: - text_font: bool = True - text_font_size: bool = True - text_font_size_raw: bool = False - text_font_color: bool = True - text_line_count: bool = True - text_char_count: bool = True - text_height: bool = False - text_width: bool = False - text_top: bool = False - text_left: bool = False - text_center_y: bool = True - text_center_x: bool = True - text_align_type: bool = False - text_angle: bool = False - text_letter_spacing: bool = False - text_line_height_scale: bool = False - text_line_height_size: bool = False - text_capitalize: bool = False - text_emb: bool = True - text_local_img_emb: bool = True - - -@dataclass -class CanvasEmbeddingFlag: - canvas_bg_img: bool = False - canvas_bg_img_emb: bool = True - canvas_aspect_ratio: bool = True - canvas_text_num: bool = True - canvas_group: bool = False - canvas_format: bool = False - canvas_category: bool = False - canvas_width: bool = False - canvas_height: bool = False - - -@dataclass -class TextElementPredictionTargetFlag: - text_font: bool = True - text_font_emb: bool = False - text_font_size: bool = True - text_font_size_raw: bool = False - text_font_color: bool = True - text_angle: bool = True - text_letter_spacing: bool = True - text_line_height_scale: bool = True - text_line_height_size: bool = False - text_capitalize: bool = True - text_align_type: bool = True - text_center_y: bool = False - text_center_x: bool = False - text_height: bool = False - text_width: bool = False - - -@dataclass -class GlobalConfig: - meta_info: MetaInfo = MetaInfo() - model_config: ModelConfig = ModelConfig() - data_config: DataConfig = DataConfig() - train_config: TrainConfig = TrainConfig() - test_config: TestConfig = TestConfig() - text_element_embedding_flag: TextElementEmbeddingFlag = TextElementEmbeddingFlag() - canvas_embedding_flag: CanvasEmbeddingFlag = CanvasEmbeddingFlag() - text_element_prediction_flag: TextElementPredictionTargetFlag = ( - TextElementPredictionTargetFlag() - ) - canvas_embedding_attribute_config: CanvasContextEmbeddingAttributeConfig = ( - CanvasContextEmbeddingAttributeConfig() - ) - text_element_embedding_attribute_config: TextElementContextEmbeddingAttributeConfig = ( - TextElementContextEmbeddingAttributeConfig() - ) - text_element_prediction_attribute_config: TextElementContextPredictionAttributeConfig = ( - TextElementContextPredictionAttributeConfig() - ) diff --git a/src/typography_generation/config/config_args_util.py b/src/typography_generation/config/config_args_util.py deleted file mode 100644 index fbe8adb..0000000 --- a/src/typography_generation/config/config_args_util.py +++ /dev/null @@ -1,199 +0,0 @@ -from typing import Any, Dict, Tuple - -import omegaconf -from logzero import logger - -from typography_generation.config.base_config_object import 
GlobalConfig -from typography_generation.config.default import ( - build_config, - get_bindata, - get_datapreprocess_config, - get_font_config, - get_model_input_prefix_list, - get_target_prefix_list, -) -from typography_generation.io.data_object import ( - BinsData, - DataPreprocessConfig, - FontConfig, - PrefixListObject, - SamplingConfig, -) - - -def args2add_data_inputs(args: Any) -> Tuple: - add_data_inputs = {} - add_data_inputs["data_dir"] = args.datadir - add_data_inputs["global_config_input"] = get_global_config_input(args.datadir, args) - add_data_inputs["gpu"] = args.gpu - add_data_inputs["weight"] = args.weight - add_data_inputs["jobdir"] = args.jobdir - add_data_inputs["debug"] = args.debug - return add_data_inputs - - -def get_conf(yaml_file: str) -> omegaconf: - logger.info(f"load {yaml_file}") - conf = omegaconf.OmegaConf.load(yaml_file) - return conf - - -def get_global_config( - data_dir: str, - model_name: str, - test_config_name: str = "test_config", - model_config_name: str = "model_config", - elementembeddingflag_config_name: str = "text_element_embedding_flag_config", - elementpredictionflag_config_name: str = "text_element_prediction_flag_config", - canvasembeddingflag_config_name: str = "canvas_embedding_flag_config", - elementembeddingatt_config_name: str = "text_element_embedding_attribute_config", - elementpredictionatt_config_name: str = "text_element_prediction_attribute_config", - canvasembeddingatt_config_name: str = "canvas_embedding_attribute_config", -) -> GlobalConfig: - metainfo_conf = get_conf(f"{data_dir}/config/{model_name}/metainfo.yaml") - data_conf = get_conf(f"{data_dir}/config/{model_name}/data_config.yaml") - model_conf = get_conf(f"{data_dir}/config/{model_name}/{model_config_name}.yaml") - train_conf = get_conf(f"{data_dir}/config/{model_name}/train_config.yaml") - test_conf = get_conf(f"{data_dir}/config/{model_name}/{test_config_name}.yaml") - text_element_emb_flag_conf = get_conf( - f"{data_dir}/config/{model_name}/{elementembeddingflag_config_name}.yaml" - ) - canvas_emb_flag_conf = get_conf( - f"{data_dir}/config/{model_name}/{canvasembeddingflag_config_name}.yaml" - ) - text_element_pred_flag_conf = get_conf( - f"{data_dir}/config/{model_name}/{elementpredictionflag_config_name}.yaml" - ) - text_element_emb_attribute_conf = get_conf( - f"{data_dir}/config/{model_name}/{elementembeddingatt_config_name}.yaml" - ) - canvas_emb_attribute_conf = get_conf( - f"{data_dir}/config/{model_name}/{canvasembeddingatt_config_name}.yaml" - ) - text_element_pred_attribute_conf = get_conf( - f"{data_dir}/config/{model_name}/{elementpredictionatt_config_name}.yaml" - ) - config = build_config( - metainfo_conf, - data_conf, - model_conf, - train_conf, - test_conf, - text_element_emb_flag_conf, - canvas_emb_flag_conf, - text_element_pred_flag_conf, - text_element_emb_attribute_conf, - canvas_emb_attribute_conf, - text_element_pred_attribute_conf, - ) - return config - - -def get_data_config( - config: GlobalConfig, -) -> Tuple[BinsData, FontConfig, DataPreprocessConfig]: - bin_data = get_bindata(config) - font_config = get_font_config(config) - data_preprocess_config = get_datapreprocess_config(config) - return bin_data, font_config, data_preprocess_config - - -def get_sampling_config( - config: GlobalConfig, -) -> SamplingConfig: - sampling_param = config.test_config.sampling_param - sampling_param_geometry = config.test_config.sampling_param_geometry - sampling_param_semantic = config.test_config.sampling_param_semantic - sampling_num = 
config.test_config.sampling_num - return SamplingConfig( - sampling_param, sampling_param_geometry, sampling_param_semantic, sampling_num - ) - - -def get_global_config_input(data_dir: str, args: Any) -> Dict: - global_config_input = {} - global_config_input["data_dir"] = data_dir - global_config_input["model_name"] = args.configname - global_config_input["test_config_name"] = args.testconfigname - global_config_input["model_config_name"] = args.modelconfigname - global_config_input[ - "elementembeddingflag_config_name" - ] = args.elementembeddingflagconfigname - global_config_input[ - "canvasembeddingflag_config_name" - ] = args.canvasembeddingflagconfigname - global_config_input[ - "elementpredictionflag_config_name" - ] = args.elementpredictionflagconfigname - - return global_config_input - - -def get_model_config_input(config: GlobalConfig) -> Tuple[str, Dict]: - model_kwargs = {} - - model_kwargs["prefix_list_element"] = get_target_prefix_list( - config.text_element_embedding_attribute_config - ) - model_kwargs["prefix_list_canvas"] = get_target_prefix_list( - config.canvas_embedding_attribute_config - ) - model_kwargs["prefix_list_target"] = get_target_prefix_list( - config.text_element_prediction_attribute_config - ) - model_kwargs[ - "embedding_config_element" - ] = config.text_element_embedding_attribute_config - model_kwargs["embedding_config_canvas"] = config.canvas_embedding_attribute_config - model_kwargs[ - "prediction_config_element" - ] = config.text_element_prediction_attribute_config - model_kwargs["d_model"] = config.model_config.d_model - model_kwargs["n_head"] = config.model_config.n_head - model_kwargs["dropout"] = config.model_config.dropout - model_kwargs["num_encoder_layers"] = config.model_config.num_encoder_layers - model_kwargs["num_decoder_layers"] = config.model_config.num_decoder_layers - model_kwargs["seq_length"] = config.model_config.seq_length - model_kwargs["std_ratio"] = config.model_config.std_ratio - model_kwargs["bypass"] = config.model_config.bypass - - model_name = config.meta_info.model_name - return model_name, model_kwargs - - -def get_train_config_input(config: GlobalConfig, debug: bool) -> Dict: - train_kwargs = {} - if debug is True: - train_kwargs["epochs"] = 1 - train_kwargs["batch_size"] = 2 - else: - train_kwargs["epochs"] = config.train_config.epochs - train_kwargs["batch_size"] = config.train_config.batch_size - train_kwargs["save_epoch"] = config.train_config.save_epoch - train_kwargs["num_worker"] = config.train_config.num_worker - train_kwargs["learning_rate"] = config.train_config.learning_rate - train_kwargs["show_interval"] = config.train_config.show_interval - train_kwargs["optimizer_option"] = config.train_config.optimizer - train_kwargs["weight_decay"] = config.train_config.weight_decay - train_kwargs["debug"] = debug - return train_kwargs - - -def get_prefix_lists(config: GlobalConfig) -> PrefixListObject: - prefix_list_textelement = get_target_prefix_list( - config.text_element_embedding_attribute_config - ) - prefix_list_canvas = get_target_prefix_list( - config.canvas_embedding_attribute_config - ) - prefix_list_model_input = get_model_input_prefix_list(config) - prefix_list_target = get_target_prefix_list( - config.text_element_prediction_attribute_config - ) - prefix_list_object = PrefixListObject( - prefix_list_textelement, - prefix_list_canvas, - prefix_list_model_input, - prefix_list_target, - ) - return prefix_list_object diff --git a/src/typography_generation/config/default.py 
b/src/typography_generation/config/default.py deleted file mode 100644 index c10b5f1..0000000 --- a/src/typography_generation/config/default.py +++ /dev/null @@ -1,128 +0,0 @@ -from typing import Any, List, Union -from logzero import logger -from omegaconf import OmegaConf -from typography_generation.config.attribute_config import ( - CanvasContextEmbeddingAttributeConfig, - TextElementContextEmbeddingAttributeConfig, - TextElementContextPredictionAttributeConfig, -) -from typography_generation.config.base_config_object import ( - CanvasEmbeddingFlag, - DataConfig, - GlobalConfig, - MetaInfo, - ModelConfig, - TestConfig, - TextElementEmbeddingFlag, - TextElementPredictionTargetFlag, - TrainConfig, -) -from typography_generation.io.data_object import ( - BinsData, - DataPreprocessConfig, - FontConfig, -) - - -def get_bindata(config: GlobalConfig) -> BinsData: - bin_data = BinsData( - config.data_config.color_bin, - config.data_config.small_spatio_bin, - config.data_config.large_spatio_bin, - ) - return bin_data - - -def get_font_config(config: GlobalConfig) -> FontConfig: - font_config = FontConfig( - config.data_config.font_num, - config.data_config.font_emb_type, - config.data_config.font_emb_weight, - config.data_config.font_emb_dim, - config.data_config.font_emb_name, - ) - return font_config - - -def get_datapreprocess_config(config: GlobalConfig) -> DataPreprocessConfig: - datapreprocess_config = DataPreprocessConfig( - config.data_config.order_type, - config.data_config.seq_length, - ) - return datapreprocess_config - - -def get_target_prefix_list(tar_cls: Any) -> List: - all_text_prefix_list = dir(tar_cls) - target_prefix_list = [] - for prefix in all_text_prefix_list: - elm = getattr(tar_cls, prefix) - if elm.flag is True: - target_prefix_list.append(prefix) - return target_prefix_list - - -def get_model_input_prefix_list(conf: GlobalConfig) -> List: - input_prefix_list = [] - input_prefix_list += get_target_prefix_list( - conf.text_element_embedding_attribute_config - ) - input_prefix_list += get_target_prefix_list(conf.canvas_embedding_attribute_config) - input_prefix_list += get_target_prefix_list( - conf.text_element_prediction_attribute_config - ) - if "canvas_text_num" in input_prefix_list: - pass - else: - input_prefix_list.append("canvas_text_num") - return list(set(input_prefix_list)) - - -def plusone_num_embeddings(config: GlobalConfig) -> None: - prefix_lists = dir(config.text_element_embedding_attribute_config) - for prefix in prefix_lists: - elm = getattr(config.text_element_embedding_attribute_config, f"{prefix}") - if elm.emb_layer == "nn.Embedding": - elm.emb_layer_kwargs["num_embeddings"] = ( - int(elm.emb_layer_kwargs["num_embeddings"]) + 1 - ) - - -def show_class_attributes(_class: Any): - class_dict = _class.__dict__["_content"] - logger.info(f"{class_dict}") - - -def build_config( - metainfo_conf: MetaInfo, - data_conf: DataConfig, - model_conf: ModelConfig, - train_conf: TrainConfig, - test_conf: TestConfig, - text_element_emb_flag_conf: TextElementEmbeddingFlag, - canvas_emb_flag_conf: CanvasEmbeddingFlag, - text_element_pred_flag_conf: TextElementPredictionTargetFlag, - text_element_emb_attribute_conf: TextElementContextEmbeddingAttributeConfig, - canvas_emb_attribute_conf: CanvasContextEmbeddingAttributeConfig, - text_element_pred_attribute_conf: TextElementContextPredictionAttributeConfig, -) -> Union[Any, GlobalConfig]: - conf = OmegaConf.structured(GlobalConfig) - show_class_attributes(text_element_emb_flag_conf) - 
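(For reference: `build_config`, whose body continues below, follows OmegaConf's structured-config pattern — the typed `GlobalConfig` schema is instantiated first, then the per-model YAML fragments are merged on top so every override is type-checked against the dataclass defaults. A minimal, self-contained sketch of that pattern; `Schema`, `TrainConfig`, and the values here are illustrative stand-ins, not the repository's classes:)

```python
from dataclasses import dataclass, field
from omegaconf import OmegaConf

@dataclass
class TrainConfig:          # stand-in for one of the config dataclasses
    epochs: int = 31
    batch_size: int = 32

@dataclass
class Schema:               # stand-in for GlobalConfig
    train_config: TrainConfig = field(default_factory=TrainConfig)

base = OmegaConf.structured(Schema)                           # typed defaults
override = OmegaConf.create({"train_config": {"epochs": 5}})  # stands in for a YAML file
merged = OmegaConf.merge(base, override)                      # validated merge
assert merged.train_config.epochs == 5                        # override applied
assert merged.train_config.batch_size == 32                   # default preserved
```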
show_class_attributes(canvas_emb_flag_conf) - show_class_attributes(text_element_pred_flag_conf) - conf = OmegaConf.merge( - conf, - metainfo_conf, - data_conf, - model_conf, - train_conf, - test_conf, - text_element_emb_flag_conf, - canvas_emb_flag_conf, - text_element_pred_flag_conf, - text_element_emb_attribute_conf, - canvas_emb_attribute_conf, - text_element_pred_attribute_conf, - ) - plusone_num_embeddings(conf) - return conf diff --git a/src/typography_generation/io/__init__.py b/src/typography_generation/io/__init__.py deleted file mode 100644 index 8b13789..0000000 --- a/src/typography_generation/io/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/typography_generation/io/build_dataset.py b/src/typography_generation/io/build_dataset.py deleted file mode 100644 index 14bacd3..0000000 --- a/src/typography_generation/io/build_dataset.py +++ /dev/null @@ -1,85 +0,0 @@ -from typing import Tuple -from typography_generation.io.data_loader import CrelloLoader -from typography_generation.io.data_object import ( - FontConfig, - PrefixListObject, -) -from typography_generation.tools.tokenizer import Tokenizer -import datasets -from logzero import logger -import os - - -def build_train_dataset( - data_dir: str, - prefix_list_object: PrefixListObject, - font_config: FontConfig, - use_extended_dataset: bool = True, - dataset_name: str = "crello", - debug: bool = False, -) -> Tuple[CrelloLoader, CrelloLoader]: - tokenizer = Tokenizer(data_dir) - - if dataset_name == "crello": - logger.info("load hugging dataset start") - if use_extended_dataset is True: - _dataset = datasets.load_from_disk( - os.path.join(data_dir, "crello_map_features") - ) - else: - _dataset = datasets.load_dataset("cyberagent/crello", revision="3.1") - logger.info("load hugging dataset done") - dataset = CrelloLoader( - data_dir, - tokenizer, - _dataset["train"], - prefix_list_object, - font_config, - use_extended_dataset=use_extended_dataset, - debug=debug, - ) - dataset_val = CrelloLoader( - data_dir, - tokenizer, - _dataset["validation"], - prefix_list_object, - font_config, - use_extended_dataset=use_extended_dataset, - debug=debug, - ) - else: - raise NotImplementedError() - return dataset, dataset_val - - -def build_test_dataset( - data_dir: str, - prefix_list_object: PrefixListObject, - font_config: FontConfig, - use_extended_dataset: bool = True, - dataset_name: str = "crello", - debug: bool = False, -) -> CrelloLoader: - tokenizer = Tokenizer(data_dir) - - if dataset_name == "crello": - logger.info("load hugging dataset start") - if use_extended_dataset is True: - _dataset = datasets.load_from_disk( - os.path.join(data_dir, "crello_map_features") - ) - else: - _dataset = datasets.load_dataset("cyberagent/crello", revision="3.1") - logger.info("load hugging dataset done") - dataset = CrelloLoader( - data_dir, - tokenizer, - _dataset["test"], - prefix_list_object, - font_config, - use_extended_dataset=use_extended_dataset, - debug=debug, - ) - else: - raise NotImplementedError() - return dataset diff --git a/src/typography_generation/io/crello_util.py b/src/typography_generation/io/crello_util.py deleted file mode 100644 index 85f1e7d..0000000 --- a/src/typography_generation/io/crello_util.py +++ /dev/null @@ -1,490 +0,0 @@ -import math -import os -import pickle -from typing import Any, Dict, List, Tuple -from einops import repeat -import skia -import numpy as np -import PIL -from PIL import Image -import torch -from typography_generation.io.data_object import FontConfig -from 
typography_generation.tools.tokenizer import Tokenizer - -from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer - -from typography_generation.visualization.renderer_util import ( - get_skia_font, - get_text_actual_width, - get_texts, -) - -fontmgr = skia.FontMgr() - - -class CrelloProcessor: - def __init__( - self, - data_dir: str, - tokenizer: Tokenizer, - dataset: Any, - font_config: FontConfig, - use_extended_dataset: bool = True, - seq_length: int = 50, - ) -> None: - self.data_dir = data_dir - self.tokenizer = tokenizer - self.font_config = font_config - self.dataset = dataset - self.seq_length = seq_length - if font_config is not None: - fn = os.path.join( - self.data_dir, "font_emb", f"{font_config.font_emb_name}.pkl" - ) - self.fontid2fontemb = pickle.load(open(fn, "rb")) - self.use_extended_dataset = use_extended_dataset - if not use_extended_dataset: - fn = os.path.join(data_dir, "font2ttf.pkl") - _font2ttf = pickle.load(open(fn, "rb")) - font2ttf = {} - for key in _font2ttf.keys(): - tmp = _font2ttf[key].split("/data/dataset/crello/")[1] - fn = os.path.join(data_dir, tmp) - font2ttf[key] = fn - self.font2ttf = font2ttf - - fn = os.path.join(data_dir, "svgid2scaleinfo.pkl") - self.svgid2scaleinfo = pickle.load(open(fn, "rb")) - self.fontlabel2fontname = self.dataset.features["font"].feature.int2str - - self.processor = CLIPProcessor.from_pretrained( - "openai/clip-vit-base-patch32" - ) - self.text_tokenizer = CLIPTokenizer.from_pretrained( - "openai/clip-vit-base-patch32" - ) - self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") - # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.device = torch.device("cpu") - self.model.to(self.device) - - def get_canvas_text_num( - self, element_data: dict, **kwargs: Any - ) -> Tuple[int, List]: - text_num = 0 - for k in range(len(element_data["text"])): - if element_data["text"][k] == "": - pass - else: - text_num += 1 - text_num = min(text_num, self.seq_length) - return text_num - - def get_canvas_text_ids( - self, element_data: dict, **kwargs: Any - ) -> Tuple[int, List]: - text_ids = [] - for k in range(len(element_data["text"])): - if element_data["text"][k] == "": - pass - else: - text_ids.append(k) - return text_ids - - def get_canvas_bg_size(self, element_data: dict) -> Tuple[int, int]: - return element_data["canvas_bg_size"] - - def get_scale_box(self, element_data: dict) -> List: - if self.use_extended_dataset: - return tuple(element_data["scale_box"]) - else: - svgid = element_data["id"] - return self.svgid2scaleinfo[svgid] - - def get_text_font(self, element_data: dict, text_index: int, **kwargs: Any) -> int: - font = element_data["font"][text_index] - 1 - return int(font) - - def denorm_text_font(self, val: int, **kwargs: Any) -> int: - val += 1 - return val - - def get_text_font_emb( - self, - element_data: dict, - text_index: int, - **kwargs: Any, - ) -> np.array: - font = element_data["font"][text_index] - font_emb = self.fontid2fontemb[font - 1] - return font_emb - - def raw2token_text_font_emb( - self, - val: np.array, - **kwargs: Any, - ) -> int: - vec = repeat(val, "c -> n c", n=len(self.fontid2fontemb)) - diff = (vec - self.fontid2fontemb) ** 2 - diff = diff.sum(1) - font = int(np.argsort(diff)[0]) - return font - - def denorm_text_font_emb(self, val: int, **kwargs: Any) -> int: - val += 1 - return val - - def get_skia_font( - self, element_data: dict, text_index: int, scaleinfo: Tuple, **kwargs: Any - ) -> int: - font_label = element_data["font"][text_index] - 
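(A note on `raw2token_text_font_emb` above: it maps a predicted font-embedding vector back to a discrete font id by nearest neighbour over the stored `fontid2fontemb` table. A runnable sketch of that lookup, assuming a random stand-in table of shape `(num_fonts, emb_dim)`:)

```python
import numpy as np
from einops import repeat

fontid2fontemb = np.random.rand(256, 40).astype(np.float32)  # stand-in table

def nearest_font_id(val: np.ndarray) -> int:
    # tile the query so it can be compared row-wise against the table
    vec = repeat(val, "c -> n c", n=len(fontid2fontemb))
    dist = ((vec - fontid2fontemb) ** 2).sum(1)  # squared L2 distance per font
    return int(np.argmin(dist))

assert nearest_font_id(fontid2fontemb[13]) == 13  # a stored vector recovers its own id
```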
font_name = self.fontlabel2fontname(int(font_label)) - font_name = font_name.replace(" ", "_") - scale_h, _ = scaleinfo - font_skia, _ = get_skia_font( - self.font2ttf, - fontmgr, - element_data, - text_index, - font_name, - scale_h, - ) - return font_skia - - def get_text_font_size( - self, - element_data: dict, - text_index: int, - img_size: Tuple, - scaleinfo: Tuple, - **kwargs: Any, - ) -> float: - fs = element_data["font_size"][text_index] - scale_h, _ = scaleinfo - h, _ = img_size - val = fs * scale_h / h - return float(val) - - def denorm_text_font_size( - self, val: float, img_height: int, scale_h: float, **kwargs: Any - ) -> float: - val = val * img_height / scale_h - return val - - def get_text_font_size_raw( - self, - element_data: dict, - text_index: int, - img_size: Tuple, - scaleinfo: Tuple, - **kwargs: Any, - ) -> float: - val = self.get_text_font_size(element_data, text_index, img_size, scaleinfo) - return val - - def raw2token_text_font_size_raw( - self, - val: float, - **kwargs: Any, - ) -> int: - val = self.tokenizer.tokenize("text_font_size", float(val)) - return val - - def denorm_text_font_size_raw( - self, val: float, img_height: int, scale_h: float, **kwargs: Any - ) -> float: - val = val * img_height / scale_h - return val - - def get_text_font_color( - self, element_data: dict, text_index: int, **kwargs: Any - ) -> Tuple[int, int, int]: - B, G, R = element_data["color"][text_index] - return (R, G, B) - - def get_text_height( - self, element_data: dict, text_index: int, **kwargs: Any - ) -> float: - height = element_data["height"][text_index] - return float(height) - - def denorm_text_height(self, val: float, img_height: int, **kwargs: Any) -> float: - val = val * img_height - return val - - def get_text_width( - self, element_data: dict, text_index: int, **kwargs: Any - ) -> float: - width = element_data["width"][text_index] - return float(width) - - def denorm_text_width(self, val: float, img_width: int, **kwargs: Any) -> float: - val = val * img_width - return val - - def get_text_top(self, element_data: dict, text_index: int, **kwargs: Any) -> float: - top = element_data["top"][text_index] - return float(top) - - def denorm_text_top(self, val: float, img_height: int, **kwargs: Any) -> float: - val = val * img_height - return val - - def get_text_left( - self, element_data: dict, text_index: int, **kwargs: Any - ) -> float: - left = element_data["left"][text_index] - return float(left) - - def denorm_text_left(self, val: float, img_width: int, **kwargs: Any) -> float: - val = val * img_width - return val - - def get_text_actual_width( - self, - element_data: dict, - text_index: int, - texts: List[str], - font_skia: skia.Font, - scaleinfo: Tuple[float, float], - **kwargs: Any, - ) -> float: - _, scale_w = scaleinfo - text_width = get_text_actual_width( - element_data, text_index, texts, font_skia, scale_w - ) - return text_width - - def get_text_center_y( - self, element_data: dict, text_index: int, **kwargs: Any - ) -> float: - if self.use_extended_dataset: - return element_data["text_center_y"][text_index] - else: - text_height = element_data["height"][text_index] - top = element_data["top"][text_index] - center_y = (top + top + text_height) / 2.0 - return float(center_y) - - def get_text_center_x( - self, - element_data: dict, - text_index: int, - scaleinfo: Tuple, - **kwargs: Any, - ) -> float: - if self.use_extended_dataset: - return element_data["text_center_x"][text_index] - else: - left = element_data["left"][text_index] - w = 
element_data["width"][text_index] - textAlign = element_data["text_align"][text_index] - texts = get_texts(element_data, text_index) - font_skia = self.get_skia_font(element_data, text_index, scaleinfo) - text_actual_width = self.get_text_actual_width( - element_data, text_index, texts, font_skia, scaleinfo - ) - right = left + w - if textAlign == 1: - center_x = (left + right) / 2.0 - elif textAlign == 3: - center_x = right - text_actual_width / 2.0 - elif textAlign == 2: - center_x = left + text_actual_width / 2.0 - return float(center_x) - - def get_text_align_type( - self, element_data: dict, text_index: int, **kwargs: Any - ) -> int: - align_type = element_data["text_align"][text_index] - 1 - return int(align_type) - - def denorm_text_align_type(self, val: int, **kwargs: Any) -> int: - val += 1 - return val - - def get_text_capitalize( - self, element_data: dict, text_index: int, **kwargs: Any - ) -> int: - capitalize = element_data["capitalize"][text_index] - return int(capitalize) - - def get_text_angle( - self, element_data: dict, text_index: int, **kwargs: Any - ) -> float: - angle = element_data["angle"][text_index] - angle = (float(angle) * 180 / math.pi) / 360.0 - angle = math.modf(angle)[0] - return float(angle) - - def denorm_text_angle(self, val: float, **kwargs: Any) -> float: - val = val * 360 / 180.0 * math.pi - return val - - def get_text_letter_spacing( - self, - element_data: dict, - text_index: int, - scaleinfo: Tuple, - img_size: Tuple, - **kwargs: Any, - ) -> float: - _, scale_w = scaleinfo - _, W = img_size - letter_space = element_data["letter_spacing"][text_index] * scale_w / W - return float(letter_space) - - def denorm_text_letter_spacing( - self, val: float, img_width: int, scale_w: float, **kwargs: Any - ) -> float: - val = val * img_width / scale_w - return val - - def get_text_line_height_scale( - self, element_data: dict, text_index: int, **kwargs: Any - ) -> float: - line_height_scale = element_data["line_height"][text_index] - return float(line_height_scale) - - def get_text_char_count( - self, element_data: dict, text_index: int, **kwargs: Any - ) -> int: - text = self.get_text(element_data, text_index) - texts = text.split(os.linesep) - max_char_count = 0 - for t in texts: - max_char_count = max(max_char_count, len(t)) - return min(max_char_count, 50 - 1) - - def get_text_line_count( - self, element_data: dict, text_index: int, **kwargs: Any - ) -> int: - texts = element_data["text"][text_index].split(os.linesep) - line_count = 0 - for t in texts: - if t == "": - pass - else: - line_count += 1 - return min(line_count, 49) - - def get_text(self, element_data: dict, text_index: int) -> str: - text = element_data["text"][text_index] - return str(text) - - def get_canvas_aspect_ratio( - self, element_data: dict, bg_img: Any, **kwargs: Any - ) -> float: - h, w = bg_img.size[1], bg_img.size[0] - ratio = float(h) / float(w) - return ratio - - def get_canvas_group(self, element_data: dict, **kwargs: Any) -> int: - return element_data["group"] - - def get_canvas_format(self, element_data: dict, **kwargs: Any) -> int: - return element_data["format"] - - def get_canvas_category(self, element_data: dict, **kwargs: Any) -> int: - return element_data["category"] - - def get_canvas_width(self, element_data: dict, **kwargs: Any) -> int: - return element_data["canvas_width"] - - def get_canvas_height(self, element_data: dict, **kwargs: Any) -> int: - return element_data["canvas_height"] - - def get_canvas_bg_img_emb( - self, element_data: dict, bg_img: PIL.Image, 
**kwargs: Any - ) -> np.array: - if self.use_extended_dataset: - return np.array(element_data["canvas_bg_img_emb"]) - else: - inputs = self.processor(images=[bg_img], return_tensors="pt") - inputs["pixel_values"] = inputs["pixel_values"].to(self.device) - image_feature = self.model.get_image_features(**inputs) - return image_feature.data.numpy() - - def get_text_emb( - self, - element_data: dict, - text_index: int, - **kwargs: Any, - ) -> np.array: - if self.use_extended_dataset: - return np.array(element_data["text_emb"][text_index]) - else: - text = element_data["text"][text_index] - inputs = self.text_tokenizer([text], padding=True, return_tensors="pt") - if inputs["input_ids"].shape[1] > 77: - inp = inputs["input_ids"][:, :77] - else: - inp = inputs["input_ids"] - text_features = self.model.get_text_features(inp).data.numpy()[0] - return text_features - - def get_text_local_img( - self, - img: Any, - text_center_y: float, - text_center_x: float, - H: int, - W: int, - ) -> np.array: - text_center_y = text_center_y * H - text_center_x = text_center_x * W - - text_center_y = min(max(text_center_y, 0), H) - text_center_x = min(max(text_center_x, 0), W) - img = img.resize((640, 640)) - img = np.array(img) - local_img_size = 64 * 5 - local_img_size_half = local_img_size // 2 - img_pad = np.zeros((640 + local_img_size, 640 + local_img_size, 3)) - img_pad[ - local_img_size_half : 640 + local_img_size_half, - local_img_size_half : 640 + local_img_size_half, - ] = img - h_rate = 640 / float(H) - w_rate = 640 / float(W) - text_center_y = int(np.round(text_center_y * h_rate + local_img_size_half)) - text_center_x = int(np.round(text_center_x * w_rate + local_img_size_half)) - local_img = img_pad[ - text_center_y - local_img_size_half : text_center_y + local_img_size_half, - text_center_x - local_img_size_half : text_center_x + local_img_size_half, - ] - return local_img - - def get_text_local_img_emb( - self, - element_data: dict, - text_index: int, - scaleinfo: Tuple, - img_size: Tuple, - bg_img: Any, - **kwargs: Any, - ) -> np.array: - if self.use_extended_dataset: - return np.array(element_data["text_local_img_emb"][text_index]) - else: - H, W = img_size - text_center_x = self.get_text_center_x(element_data, text_index, scaleinfo) - text_center_y = self.get_text_center_y(element_data, text_index) - local_img = self.get_text_local_img( - bg_img.copy(), text_center_y, text_center_x, H, W - ) - local_img = Image.fromarray(local_img.astype(np.uint8)).resize((224, 224)) - inputs = self.processor(images=[local_img], return_tensors="pt") - inputs["pixel_values"] = inputs["pixel_values"].to(self.device) - image_feature = self.model.get_image_features(**inputs) - return image_feature.data.numpy() - - def load_samples(self, index: int) -> Tuple[Dict, Any, str, int]: - element_data = self.dataset[index] - svg_id = element_data["id"] - fn = os.path.join(self.data_dir, "generate_bg_png", f"{svg_id}.png") - bg = Image.open(fn).convert("RGB") # background image - return element_data, bg, svg_id, index - - def __len__(self) -> int: - return len(self.dataset) diff --git a/src/typography_generation/io/data_loader.py b/src/typography_generation/io/data_loader.py deleted file mode 100644 index 89b1771..0000000 --- a/src/typography_generation/io/data_loader.py +++ /dev/null @@ -1,123 +0,0 @@ -import time -from typing import Any, Dict, List, Tuple -import numpy as np - -import torch -from logzero import logger -from typography_generation.io.crello_util import CrelloProcessor -from 
typography_generation.io.data_object import ( - DesignContext, - FontConfig, - PrefixListObject, -) -from typography_generation.io.data_utils import get_canvas_context, get_element_context -from typography_generation.tools.tokenizer import Tokenizer - - -class CrelloLoader(torch.utils.data.Dataset): - def __init__( - self, - data_dir: str, - tokenizer: Tokenizer, - dataset: Any, - prefix_list_object: PrefixListObject, - font_config: FontConfig, - use_extended_dataset: bool = True, - seq_length: int = 50, - debug: bool = False, - ) -> None: - super().__init__() - self.data_dir = data_dir - self.prefix_list_object = prefix_list_object - self.debug = debug - logger.debug("create crello dataset processor") - self.dataset = CrelloProcessor( - data_dir, - tokenizer, - dataset, - font_config, - use_extended_dataset=use_extended_dataset, - ) - logger.debug("create crello dataset processor done") - self.seq_length = seq_length - - def get_ordered_text_ids(self, element_data, order_list) -> List: - text_ids = [] - for i in order_list: - if element_data["text"][i] == "": - pass - else: - text_ids.append(i) - return text_ids - - def get_order_list(self, elm: Dict[str, Any]) -> List[int]: - if self.dataset.use_extended_dataset: - return elm["order_list"] - else: - """ - Sort elments based on the raster scan order. - """ - center_y = [] - center_x = [] - scaleinfo = self.dataset.get_scale_box(elm) - for text_id in range(len(elm["text"])): - center_y.append(self.dataset.get_text_center_y(elm, text_id)) - center_x.append(self.dataset.get_text_center_x(elm, text_id, scaleinfo)) - center_y = np.array(center_y) - center_x = np.array(center_x) - sortedid = np.argsort(center_y * 1000 + center_x) - return list(sortedid) - - def load_data(self, index: int) -> Tuple: - logger.debug("load samples") - element_data, bg_img, svg_id, index = self.dataset.load_samples(index) - - # extract text element indexes - text_num = self.dataset.get_canvas_text_num(element_data) - - logger.debug("order elements") - order_list = self.get_order_list(element_data) - text_ids = self.get_ordered_text_ids(element_data, order_list) - elment_prefix_list = ( - self.prefix_list_object.textelement + self.prefix_list_object.target - ) - logger.debug("get_element_context") - text_context = get_element_context( - element_data, - bg_img, - self.dataset, - elment_prefix_list, - text_num, - text_ids, - ) - - logger.debug("get_canvas_context") - canvas_context = get_canvas_context( - element_data, - self.dataset, - self.prefix_list_object.canvas, - bg_img, - text_num, - ) - logger.debug("build design context object") - design_context = DesignContext(text_context, canvas_context) - return design_context, svg_id, element_data - - def __getitem__(self, index: int) -> Tuple[DesignContext, List, str]: - logger.debug("load data") - start = time.time() - design_context, svg_id, element_data = self.load_data(index) - logger.debug(f"load data {time.time() -start}") - logger.debug("get model input list") - model_input_list = design_context.get_model_inputs_from_prefix_list( - self.prefix_list_object.model_input - ) - logger.debug(f"get model input list {time.time() -start}") - logger.debug("get model input list done") - return design_context, model_input_list, svg_id, index - - def __len__(self) -> int: - if self.debug is True: - return 2 - else: - return len(self.dataset) diff --git a/src/typography_generation/io/data_object.py b/src/typography_generation/io/data_object.py deleted file mode 100644 index f0fb0eb..0000000 --- 
a/src/typography_generation/io/data_object.py +++ /dev/null @@ -1,266 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Dict, List, Tuple, Union - -import numpy as np -import torch -from logzero import logger -from torch import Tensor - -from typography_generation.config.attribute_config import ( - TextElementContextEmbeddingAttributeConfig, -) - - -@dataclass -class PrefixListObject: - textelement: List - canvas: List - model_input: List - target: List - - -@dataclass -class BinsData: - color_bin: int - small_spatio_bin: int - large_spatio_bin: int - - -@dataclass -class FontConfig: - font_num: int - font_emb_type: str = "label" - font_emb_weight: float = 1.0 - font_emb_dim: int = 40 - font_emb_name: str = "mfc" - - -@dataclass -class DataPreprocessConfig: - order_type: str - seq_length: int - - -@dataclass -class SamplingConfig: - sampling_param: str - sampling_param_geometry: float - sampling_param_semantic: float - sampling_num: int - - -class ModelInput: - def __init__(self, design_context_list: List, model_input: List, gpu: bool) -> None: - self.design_context_list = design_context_list - self.model_input = model_input - self.prefix_list = self.design_context_list[0].model_input_prefix_list - self.gpu = gpu - self.reset() - - def reset(self) -> None: - self.setgt() - if self.gpu is True: - self.cuda() - - def setgt(self) -> None: - if len(self.prefix_list) != len(self.model_input): - raise ValueError("The length between list and input is different.") - self.batch_num = self.model_input[0].shape[0] - for prefix, elm in zip(self.prefix_list, self.model_input): - setattr(self, f"{prefix}", elm.clone()) - if prefix == "canvas_text_num": - self.canvas_text_num = elm.clone() - - def cuda(self) -> None: - for prefix in self.prefix_list: - tar = getattr(self, f"{prefix}") - if type(tar) == Tensor: - setattr(self, f"{prefix}", tar.cuda()) - - def target_register(self, prefix: str, elm: Any) -> None: - setattr(self, f"{prefix}", elm) - - def zeroinitialize_style_attributes(self, prefix_list: List) -> None: - for prefix in prefix_list: - if prefix == "canvas_text_num": - continue - tar = getattr(self, f"{prefix}") - tar_rep = torch.zeros_like(tar) - setattr(self, f"{prefix}", tar_rep) - - def zeroinitialize_specific_attribute(self, prefix: str) -> None: - tar = getattr(self, f"{prefix}") - if prefix != "canvas_text_num": - setattr(self, f"{prefix}", torch.zeros_like(tar)) - - def setgt_specific_attribute(self, prefix_tar: str) -> None: - for prefix, elm in zip(self.prefix_list, self.model_input): - if prefix == prefix_tar: - tar = elm.clone() - if self.gpu is True and type(tar) == Tensor: - tar = tar.cuda() - setattr(self, f"{prefix}", tar) - - def update_th_style_attributes( - self, - embedding_config_element: TextElementContextEmbeddingAttributeConfig, - prefix_list: List, - model_output: Tensor, - text_index: int, - batch_index: int = 0, - ) -> None: - for prefix in prefix_list: - target_embedding_config = getattr(embedding_config_element, prefix) - tar = getattr(self, f"{prefix}") - out = model_output[f"{prefix}"] - tar[batch_index, text_index] = out - setattr(self, f"{target_embedding_config.input_prefix}", tar) - - def zeroinitialize_th_style_attributes( - self, - prefix_list: List, - text_index: int, - batch_index: int = 0, - ) -> None: - for prefix in prefix_list: - tar = getattr(self, f"{prefix}") - tar[batch_index, text_index] = 0 - setattr(self, f"{prefix}", tar) - - def additional_input_from_design_context_list( - self, design_context_list: List - ) -> None: - 
self.texts = [] - for design_context in design_context_list: - self.texts.append(design_context.text_context.texts) - - -@dataclass -class ElementContext: - prefix_list: List - rawdata2token: Dict - img_size: Tuple - scaleinfo: Tuple - seq_length: int - - def __post_init__(self) -> None: - for prefix in self.prefix_list: - setattr(self, prefix, []) - if prefix == "text_local_img": - self.text_local_img_model_input = np.zeros( - (self.seq_length, 3, 224, 224), dtype=np.float32 - ) - elif prefix == "text_local_img_emb": - self.text_local_img_emb_model_input = np.zeros( - (self.seq_length, 512), dtype=np.float32 - ) - elif prefix == "text_emb": - self.text_emb_model_input = np.zeros( - (self.seq_length, 512), dtype=np.float32 - ) - elif prefix == "text_font_emb": - self.text_font_emb_model_input = ( - np.zeros((self.seq_length, 40), dtype=np.float32) - 10000 - ) - else: - setattr( - self, - f"{prefix}_model_input", - np.zeros((self.seq_length), dtype=np.float32) - 1, - ) - - if prefix in self.rawdata2token.keys(): - setattr( - self, - f"{self.rawdata2token[prefix]}_model_input", - np.zeros((self.seq_length), dtype=np.float32) - 1, - ) - - -@dataclass -class CanvasContext: - canvas_bg_img: np.array - canvas_text_num: int - img_size: Tuple - scale_box: Tuple - prefix_list: List - - -@dataclass -class DesignContext: - element_context: ElementContext - canvas_context: CanvasContext - - def __post_init__(self) -> None: - self.prepare_keys() - - def prepare_keys(self) -> None: - self.canvas_context_keys = dir(self.canvas_context) - self.element_context_keys = dir(self.element_context) - - def get_text_num(self) -> int: - return self.canvas_context.canvas_text_num - - def convert_target_to_torch_format( - self, - tar: Any, - ) -> Any: - if type(tar) == np.ndarray: - tar = torch.from_numpy(tar) - elif type(tar) == float or type(tar) == int: - tar = torch.Tensor([tar]) - elif type(tar) == str or type(tar) == list: - pass - else: - logger.info(tar) - raise NotImplementedError() - return tar - - def search_class(self, prefix: str) -> Union[ElementContext, CanvasContext]: - if prefix in self.canvas_context_keys: - return self.canvas_context - elif prefix in self.element_context_keys: - return self.element_context - else: - logger.info( - f"{prefix}, {self.canvas_context_keys}, {self.element_context_keys}" - ) - raise NotImplementedError() - - def get_model_inputs_from_prefix_list(self, prefix_list: List) -> List: - self.model_input_prefix_list = prefix_list - model_inputs = [] - for prefix in prefix_list: - logger.debug(f"convert_target_to_torch_format {prefix}") - tar_cls = self.search_class(f"{prefix}") - tar = getattr(tar_cls, f"{prefix}_model_input") - tar = self.convert_target_to_torch_format(tar) - model_inputs.append(tar) - return model_inputs - - def get_data(self, prefix: str) -> Any: - tar_cls = self.search_class(f"{prefix}") - tar = getattr(tar_cls, f"{prefix}") - return tar - - def get_canvas_size(self) -> Tuple: - canvas_size = ( - self.canvas_context.canvas_img_size_h, - self.canvas_context.canvas_img_size_w, - ) - return canvas_size - - def get_text_context(self) -> ElementContext: - return self.element_context - - def get_bg(self) -> np.array: - return self.canvas_context.canvas_bg_img - - def get_scaleinfo(self) -> Tuple: - return (self.canvas_context.canvas_h_scale, self.canvas_context.canvas_w_scale) - - def convert_torch_format(self, prefix: str) -> None: - tar_cls = self.search_class(prefix) - tar = getattr(tar_cls, prefix) - tar = self.convert_target_to_torch_format(tar) - 
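(`DesignContext.get_model_inputs_from_prefix_list` above resolves each prefix to a `<prefix>_model_input` attribute and normalises the value to a tensor. A minimal sketch of that getattr-driven conversion; `Ctx` and `text_font` are illustrative names, not the repository's:)

```python
from typing import Any
import numpy as np
import torch

def to_torch(tar: Any) -> Any:
    # mirrors convert_target_to_torch_format: arrays and scalars become
    # tensors, while strings and lists pass through untouched
    if isinstance(tar, np.ndarray):
        return torch.from_numpy(tar)
    if isinstance(tar, (int, float)):
        return torch.Tensor([tar])
    return tar

class Ctx:
    text_font_model_input = np.zeros(50, dtype=np.float32) - 1  # -1 marks empty slots

ctx = Ctx()
inputs = [to_torch(getattr(ctx, f"{p}_model_input")) for p in ("text_font",)]
assert inputs[0].shape == (50,)
```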
setattr(tar_cls, prefix, tar) diff --git a/src/typography_generation/io/data_utils.py b/src/typography_generation/io/data_utils.py deleted file mode 100644 index 3d76d71..0000000 --- a/src/typography_generation/io/data_utils.py +++ /dev/null @@ -1,89 +0,0 @@ -import time -from typing import List, Tuple -import PIL -import numpy as np -from typography_generation.io.crello_util import CrelloProcessor -from typography_generation.io.data_object import CanvasContext, ElementContext -from logzero import logger - - -def get_canvas_context( - element_data: dict, - dataset: CrelloProcessor, - canvas_prefix_list: List, - bg_img: np.array, - text_num: int, -) -> CanvasContext: - img_size = (bg_img.size[1], bg_img.size[0]) # (h,w) - scale_box = dataset.get_scale_box(element_data) - canvas_context = CanvasContext( - bg_img, text_num, img_size, scale_box, canvas_prefix_list - ) - for prefix in canvas_prefix_list: - data_info = {"element_data": element_data, "bg_img": bg_img} - data = getattr(dataset, f"get_{prefix}")(**data_info) - setattr(canvas_context, f"{prefix}", data) - setattr(canvas_context, f"{prefix}_model_input", data) - - return canvas_context - - -def get_element_context( - element_data: dict, - bg_img: PIL.Image, - dataset: CrelloProcessor, - element_prefix_list: List, - text_num: int, - text_ids: List, -) -> ElementContext: - scaleinfo = dataset.get_scale_box(element_data) - img_size = (bg_img.size[1], bg_img.size[0]) # (h,w) - element_context = ElementContext( - element_prefix_list, - dataset.tokenizer.rawdata2token, - img_size, - scaleinfo, - dataset.seq_length, - ) - for i in range(text_num): - text_index = text_ids[i] - for prefix in element_prefix_list: - start = time.time() - - data_info = { - "element_data": element_data, - "text_index": text_index, - "img_size": img_size, - "scaleinfo": scaleinfo, - "bg_img": bg_img, - } - if prefix == "text_emb": - data = getattr(dataset, f"get_{prefix}")(**data_info) - element_context.text_emb_model_input[i] = data - elif prefix == "text_local_img_emb": - data = getattr(dataset, f"get_{prefix}")(**data_info) - element_context.text_local_img_emb_model_input[i] = data - elif prefix == "text_font_emb": - data = dataset.get_text_font_emb(element_data, text_index) - getattr(element_context, prefix).append(data) - element_context.text_font_emb_model_input[i] = data - else: - data = getattr(dataset, f"get_{prefix}")(**data_info) - getattr(element_context, prefix).append(data) - if prefix in dataset.tokenizer.prefix_list: - model_input = dataset.tokenizer.tokenize(prefix, data) - else: - model_input = data - getattr(element_context, f"{prefix}_model_input")[i] = model_input - - if prefix in dataset.tokenizer.rawdata_list: - token = getattr(dataset, f"raw2token_{prefix}")(data) - getattr( - element_context, - f"{dataset.tokenizer.rawdata2token[prefix]}_model_input", - )[i] = token - - end = time.time() - logger.debug(f"{prefix} {end - start}") - - return element_context diff --git a/src/typography_generation/model/__init__.py b/src/typography_generation/model/__init__.py deleted file mode 100644 index 8b13789..0000000 --- a/src/typography_generation/model/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/typography_generation/model/bart.py b/src/typography_generation/model/bart.py deleted file mode 100644 index 77a8fe7..0000000 --- a/src/typography_generation/model/bart.py +++ /dev/null @@ -1,522 +0,0 @@ -import random -import time -from typing import Any, Dict, List, Tuple, Union - -import numpy as np -import torch -from logzero import logger 
-from torch import Tensor, nn -from torch.functional import F -from typography_generation.config.attribute_config import ( - CanvasContextEmbeddingAttributeConfig, - TextElementContextEmbeddingAttributeConfig, - TextElementContextPredictionAttributeConfig, -) -from typography_generation.io.crello_util import CrelloProcessor -from typography_generation.io.data_object import ModelInput -from typography_generation.model.decoder import Decoder, MultiTask -from typography_generation.model.embedding import Embedding -from typography_generation.model.encoder import Encoder - -FILTER_VALUE = -float("Inf") - - -def top_p( - prior: Tensor, - text_index: int, - sampling_param: float, -) -> int: - prior = F.softmax(prior, 1) - sorted_prob, sorted_label = torch.sort(input=prior, dim=1, descending=True) - prior = sorted_prob[text_index] - sum_p = 0 - for k in range(len(prior)): - sum_p += prior[k].item() - if sum_p > sampling_param: # prior - break - - range_class = k - if range_class == 0: - index = 0 - else: - index = random.randint(0, range_class) - out_label = sorted_label[text_index][index].item() - return out_label - - -def top_p_weight( - logits: Tensor, - text_index: int, - sampling_param: float, -) -> int: - sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=1) - cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=1), dim=1) - S = logits.size(1) - indices = torch.arange(S).view(1, S).to(logits.device) - # make sure to keep the first logit (most likely one) - sorted_logits[(cumulative_probs > sampling_param) & (indices > 0)] = FILTER_VALUE - logits = sorted_logits.gather(dim=1, index=sorted_indices.argsort(dim=1)) - probs = F.softmax(logits, dim=1) - output = torch.multinomial(probs, num_samples=1) # (B, 1) - logger.debug(f"{output.shape}=") - return output[text_index][0].item() - - -sampler_dict = {"top_p": top_p, "top_p_weight": top_p_weight} - - -def sample_label( - logits: Tensor, text_index: int, sampling_param: float, mode: str = "top_p" -) -> int: - return sampler_dict[mode](logits, text_index, sampling_param) - - -def get_structure(pred: np.array, indexes: List) -> Dict: - index2samestructureindexes: Dict[str, List] - index2samestructureindexes = {} - for i in indexes: - index2samestructureindexes[i] = [] - label_i = pred[i] - for j in indexes: - label_j = pred[j] - if label_i == label_j: - index2samestructureindexes[i].append(j) - return index2samestructureindexes - - -def get_structure_dict(prefix_list: List, preds: Dict, indexes: List) -> Dict: - index2samestructureindexes = {} - for prefix in prefix_list: - index2samestructureindexes[prefix] = get_structure( - preds[prefix], - indexes, - ) - return index2samestructureindexes - - -def get_init_label_link(indexes: List) -> Dict: - label_link: Dict[str, Union[int, None]] - label_link = {} - for i in indexes: - label_link[i] = None - return label_link - - -def initialize_link(prefix_list: List, indexes: List) -> Tuple[Dict, Dict]: - label_link: Dict[str, Dict] - label_link = {} - used_labels: Dict[str, List] - used_labels = {} - for prefix in prefix_list: - label_link[prefix] = get_init_label_link(indexes) - used_labels[prefix] = [] - return label_link, used_labels - - -def label_linkage( - prefix_list: List, - text_num: int, - preds: Dict, -) -> Tuple[Dict, Dict, Dict]: - indexes = list(range(text_num)) - index2samestructureindexes = get_structure_dict(prefix_list, preds, indexes) - label_link, used_labels = initialize_link(prefix_list, indexes) - return index2samestructureindexes, label_link, used_labels - 
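(The `top_p` routine above implements nucleus sampling: class probabilities are sorted, the smallest prefix whose cumulative mass exceeds the threshold is kept, and a label is drawn uniformly from that prefix. A compact, self-contained restatement of the same idea — a sketch, not the deleted function verbatim:)

```python
import random
import torch
import torch.nn.functional as F

def top_p_pick(logits: torch.Tensor, p: float = 0.9) -> int:
    probs = F.softmax(logits, dim=0)
    sorted_prob, sorted_label = torch.sort(probs, descending=True)
    cum = torch.cumsum(sorted_prob, dim=0)
    # first position where the cumulative mass exceeds p (clamped for p near 1.0)
    k = int(torch.searchsorted(cum, torch.tensor(p)).clamp(max=len(cum) - 1))
    return int(sorted_label[random.randint(0, k)])  # uniform draw over the nucleus

label = top_p_pick(torch.randn(10), p=0.9)
assert 0 <= label < 10
```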
- -def update_label_link(label_link: Dict, samestructureindexes: List, label: int) -> Dict: - for i in samestructureindexes: - label_link[i] = label - return label_link - - -class BART(nn.Module): - def __init__( - self, - prefix_list_element: List, - prefix_list_canvas: List, - prefix_list_target: List, - embedding_config_element: TextElementContextEmbeddingAttributeConfig, - embedding_config_canvas: CanvasContextEmbeddingAttributeConfig, - prediction_config_element: TextElementContextPredictionAttributeConfig, - d_model: int = 256, - n_head: int = 8, - dropout: float = 0.1, - num_encoder_layers: int = 4, - num_decoder_layers: int = 4, - seq_length: int = 50, - bypass: bool = True, - **kwargs: Any, - ) -> None: - super().__init__() - logger.info(f"BART settings") - logger.info(f"d_model: {d_model}") - logger.info(f"n_head: {n_head}") - logger.info(f"num_encoder_layers: {num_encoder_layers}") - logger.info(f"num_decoder_layers: {num_decoder_layers}") - logger.info(f"seq_length: {seq_length}") - self.embedding_config_element = embedding_config_element - - self.emb = Embedding( - prefix_list_element, - prefix_list_canvas, - embedding_config_element, - embedding_config_canvas, - d_model=d_model, - dropout=dropout, - seq_length=seq_length, - ) - self.enc = Encoder( - d_model=d_model, - n_head=n_head, - dropout=dropout, - num_encoder_layers=num_encoder_layers, - ) - self.dec = Decoder( - prefix_list_target, - embedding_config_element, - d_model=d_model, - n_head=n_head, - dropout=dropout, - num_decoder_layers=num_decoder_layers, - seq_length=seq_length, - ) - self.head = MultiTask( - prefix_list_target, - prediction_config_element, - d_model, - bypass=bypass, - ) - self.initialize_weights() - - for prefix in prefix_list_target: - target_prediction_config = getattr(prediction_config_element, prefix) - setattr(self, f"{prefix}_att_type", target_prediction_config.att_type) - - def forward(self, model_inputs: ModelInput) -> Tensor: - start = time.time() - ( - src, - text_mask_src, - feat_cat, - ) = self.emb(model_inputs) - logger.debug(f"{time.time()-start} sec emb") - start = time.time() - z = self.enc(src, text_mask_src) - logger.debug(f"{time.time()-start} sec enc") - start = time.time() - zd = self.dec(feat_cat, z, model_inputs) - logger.debug(f"{time.time()-start} sec dec") - start = time.time() - outs = self.head(zd, feat_cat) - logger.debug(f"{time.time()-start} sec head") - return outs - - def tokenize_model_out( - self, - dataset: CrelloProcessor, - prefix: str, - model_out: Tensor, - batch_index, - text_index, - ) -> int: - if prefix in dataset.tokenizer.rawdata_list: - data = model_out[batch_index][text_index].data.cpu().numpy() - out_label = getattr(dataset, f"raw2token_{prefix}")(data) - else: - sorted_label = torch.sort( - input=model_out[batch_index], dim=1, descending=True - )[1] - target_label = 0 # top1 - out_label = sorted_label[text_index][target_label].item() - return out_label - - def get_labels( - self, - model_outs: Dict, - dataset: CrelloProcessor, - target_prefix_list: List, - text_index: int, - batch_index: int = 0, - ) -> Dict: - out_labels = {} - for prefix in target_prefix_list: - out = model_outs[f"{prefix}"] - out_label = self.tokenize_model_out( - dataset, prefix, out, batch_index, text_index - ) - - out_labels[f"{prefix}"] = out_label - - return out_labels - - def get_outs( - self, - model_outs: Dict, - dataset: CrelloProcessor, - target_prefix_list: List, - text_index: int, - batch_index: int = 0, - ) -> Dict: - outs = {} - for prefix in target_prefix_list: - out 
= model_outs[f"{prefix}"] - if prefix in dataset.tokenizer.rawdata_list: - data = out[batch_index][text_index].data.cpu().numpy() - else: - sorted_label = torch.sort( - input=out[batch_index], dim=1, descending=True - )[1] - target_label = 0 # top1 - data = sorted_label[text_index][target_label].item() - - outs[f"{prefix}"] = data - - return outs - - def store( - self, - outs_all: Dict, - outs: Dict, - target_prefix_list: List, - text_index: int, - ) -> Dict: - for prefix in target_prefix_list: - outs_all[f"{prefix}"][text_index] = outs[f"{prefix}"] - return outs_all - - def prediction( - self, - model_inputs: ModelInput, - dataset: CrelloProcessor, - target_prefix_list: List, - start_index: int = 0, - ) -> Tensor: - target_text_num = int(model_inputs.canvas_text_num[0].item()) - start_index = min(start_index, target_text_num) - for t in range(start_index, target_text_num): - model_inputs.zeroinitialize_th_style_attributes(target_prefix_list, t) - src, text_mask_src, feat_cat = self.emb(model_inputs) - z = self.enc(src, text_mask_src) - - outs_all = {} - for prefix in target_prefix_list: - outs_all[prefix] = {} - for t in range(0, start_index): - tar = getattr(model_inputs, f"{prefix}")[0, t].item() - outs_all[prefix][t] = tar - for t in range(start_index, target_text_num): - zd = self.dec(feat_cat, z, model_inputs) - model_outs = self.head(zd, feat_cat) - out_labels = self.get_labels(model_outs, dataset, target_prefix_list, t) - outs = self.get_outs(model_outs, dataset, target_prefix_list, t) - model_inputs.update_th_style_attributes( - self.embedding_config_element, target_prefix_list, out_labels, t - ) - outs_all = self.store(outs_all, outs, target_prefix_list, t) - return outs_all - - def get_transformer_weight( - self, - model_inputs: ModelInput, - dataset: CrelloProcessor, - target_prefix_list: List, - start_index: int = 0, - ) -> Tensor: - target_text_num = int(model_inputs.canvas_text_num[0].item()) - start_index = min(start_index, target_text_num) - for t in range(start_index, target_text_num): - model_inputs.zeroinitialize_th_style_attributes(target_prefix_list, t) - src, text_mask_src, feat_cat = self.emb(model_inputs) - z = self.enc(src, text_mask_src) - out_labels_all = {} - for prefix in target_prefix_list: - out_labels_all[prefix] = np.zeros((target_text_num, 1)) - for t in range(0, start_index): - tar = getattr(model_inputs, f"{prefix}")[0, t].item() - out_labels_all[prefix][t, 0] = tar - weights = [] - for t in range(start_index, target_text_num): - zd, _weights = self.dec.get_transformer_weight(feat_cat, z, model_inputs) - weights.append(_weights[0, t]) - model_outs = self.head(zd, feat_cat) - out_labels = self.get_labels(model_outs, dataset, target_prefix_list, t) - model_inputs.update_th_style_attributes( - self.embedding_config_element, target_prefix_list, out_labels, t - ) - out_labels_all = self.store( - out_labels_all, out_labels, target_prefix_list, t - ) - if len(weights) > 0: - weights = torch.stack(weights, dim=0) - return weights - else: - dummy_weights = torch.zeros((1, 1)).to(src.device) - return dummy_weights - - def sample_labels( - self, - model_outs: Dict, - target_prefix_list: List, - text_index: int, - sampling_param_geometry: float = 0.5, - sampling_param_semantic: float = 0.9, - batch_index: int = 0, - ) -> Dict: - out_labels = {} - for prefix in target_prefix_list: - out = model_outs[f"{prefix}"][batch_index] - if getattr(self, f"{prefix}_att_type") == "semantic": - sampling_param = sampling_param_semantic - elif getattr(self, f"{prefix}_att_type") 
== "geometry": - sampling_param = sampling_param_geometry - out_label = sample_label(out, text_index, sampling_param) - out_labels[f"{prefix}"] = out_label - - return out_labels - - def sample( - self, - model_inputs: ModelInput, - target_prefix_list: List, - sampling_param_geometry: float = 0.7, - sampling_param_semantic: float = 0.7, - start_index: int = 0, - **kwargs: Any, - ) -> Tensor: - target_text_num = int(model_inputs.canvas_text_num[0].item()) - start_index = min(start_index, target_text_num) - for t in range(start_index, target_text_num): - model_inputs.zeroinitialize_th_style_attributes(target_prefix_list, t) - src, text_mask_src, feat_cat = self.emb(model_inputs) - z = self.enc(src, text_mask_src) - - outs_all = {} - for prefix in target_prefix_list: - outs_all[prefix] = {} - for t in range(start_index, target_text_num): - zd = self.dec(feat_cat, z, model_inputs) - model_outs = self.head(zd, feat_cat) - out_labels = self.sample_labels( - model_outs, - target_prefix_list, - t, - sampling_param_geometry, - sampling_param_semantic, - ) - model_inputs.update_th_style_attributes( - self.embedding_config_element, target_prefix_list, out_labels, t - ) - outs_all = self.store(outs_all, out_labels, target_prefix_list, t) - return outs_all - - def sample_labels_with_structure( - self, - model_outs: Dict, - target_prefix_list: List, - text_index: int, - label_link: Dict, - used_labels: Dict, - index2samestructureindexes: Dict, - sampling_param_geometry: float = 0.5, - sampling_param_semantic: float = 0.9, - batch_index: int = 0, - ) -> Dict: - outs_all = {} - for prefix in target_prefix_list: - _label_link = label_link[prefix][text_index] - _used_labels = used_labels[prefix] - if _label_link is None: - out = model_outs[f"{prefix}"][batch_index] - cnt = 0 - sampling_type = getattr(self, f"{prefix}_att_type") - if sampling_type == "semantic": - sampling_param = sampling_param_semantic - elif sampling_type == "geometry": - sampling_param = sampling_param_geometry - - out_label = sample_label(out, text_index, sampling_param) - max_val = max( - torch.sum(F.softmax(out, 1)[text_index]).item(), sampling_param - ) - while out_label in _used_labels: - out_label = sample_label(out, text_index, sampling_param) - cnt += 1 - if cnt > 10: - sampling_param += abs((max_val - sampling_param) * 0.1) - if cnt > 1000: - sampling_param *= 2 - - samestructureindexes = index2samestructureindexes[prefix][text_index] - label_link[prefix] = update_label_link( - label_link[prefix], samestructureindexes, out_label - ) - used_labels[prefix].append(out_label) - else: - out_label = _label_link - outs_all[f"{prefix}"] = out_label - - return outs_all - - def structure_preserved_sample( - self, - model_inputs: ModelInput, - dataset: CrelloProcessor, - target_prefix_list: List, - sampling_param_geometry: float = 0.7, - sampling_param_semantic: float = 0.7, - start_index: int = 0, - ) -> Tensor: - target_text_num = int(model_inputs.canvas_text_num[0].item()) - preds = self.prediction(model_inputs, dataset, target_prefix_list) - index2samestructureindexes, label_link, used_labels = label_linkage( - target_prefix_list, target_text_num, preds - ) - - for t in range(start_index, target_text_num): - model_inputs.zeroinitialize_th_style_attributes(target_prefix_list, t) - src, text_mask_src, feat_cat = self.emb(model_inputs) - z = self.enc(src, text_mask_src) - - outs_all = {} - for prefix in target_prefix_list: - outs_all[prefix] = {} - for t in range(start_index, target_text_num): - zd = self.dec(feat_cat, z, model_inputs) - 
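(`structure_preserved_sample` combines the pieces above: a greedy prediction pass groups elements that received identical labels, then sampling draws one shared label per group while rejecting labels already taken by other groups, loosening the top-p threshold when rejection stalls. A toy distillation of just the grouping-and-rejection idea, ignoring the adaptive threshold; the function and argument names are illustrative:)

```python
import random

def sample_with_structure(groups, num_labels):
    # groups: lists of element indexes that must end up with the same label
    assert len(groups) <= num_labels, "need at least one label per group"
    used, out = set(), {}
    for members in groups:
        label = random.randrange(num_labels)
        while label in used:          # reject labels taken by earlier groups
            label = random.randrange(num_labels)
        used.add(label)
        for idx in members:
            out[idx] = label
    return out

print(sample_with_structure([[0, 2], [1]], num_labels=8))  # e.g. {0: 3, 2: 3, 1: 5}
```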
model_outs = self.head(zd, feat_cat) - out_labels = self.sample_labels_with_structure( - model_outs, - target_prefix_list, - t, - label_link, - used_labels, - index2samestructureindexes, - sampling_param_geometry, - sampling_param_semantic, - ) - model_inputs.update_th_style_attributes( - self.embedding_config_element, target_prefix_list, out_labels, t - ) - outs_all = self.store(outs_all, out_labels, target_prefix_list, t) - return outs_all - - def initialize_weights(self) -> None: - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.xavier_normal_(m.weight) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - m.weight.data.normal_(mean=0.0, std=0.02) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.LayerNorm): - if m.weight is not None: - m.weight.data.fill_(1) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.Embedding): - m.weight.data.normal_(mean=0.0, std=0.02) diff --git a/src/typography_generation/model/baseline.py b/src/typography_generation/model/baseline.py deleted file mode 100644 index 2fd67a2..0000000 --- a/src/typography_generation/model/baseline.py +++ /dev/null @@ -1,343 +0,0 @@ -import pickle -import random -from typing import Any, Dict, List, Tuple, Union -import numpy as np - -import torch -from logzero import logger -from torch import Tensor, nn -from typography_generation.config.attribute_config import ( - CanvasContextEmbeddingAttributeConfig, - TextElementContextEmbeddingAttributeConfig, - TextElementContextPredictionAttributeConfig, -) -from typography_generation.io.data_object import ModelInput -from typography_generation.model.bottleneck import ImlevelLF -from typography_generation.model.decoder import Decoder, MultiTask -from typography_generation.model.embedding import Embedding -from typography_generation.model.encoder import Encoder - - -class Baseline(nn.Module): - def __init__( - self, - prefix_list_element: List, - prefix_list_canvas: List, - prefix_list_target: List, - embedding_config_element: TextElementContextEmbeddingAttributeConfig, - embedding_config_canvas: CanvasContextEmbeddingAttributeConfig, - prediction_config_element: TextElementContextPredictionAttributeConfig, - d_model: int = 256, - n_head: int = 8, - dropout: float = 0.1, - num_encoder_layers: int = 4, - num_decoder_layers: int = 4, - seq_length: int = 50, - std_ratio: float = 1.0, - **kwargs: Any, - ) -> None: - super().__init__() - logger.info(f"CanvasVAE settings") - logger.info(f"d_model: {d_model}") - logger.info(f"n_head: {n_head}") - logger.info(f"num_encoder_layers: {num_encoder_layers}") - logger.info(f"num_decoder_layers: {num_decoder_layers}") - logger.info(f"seq_length: {seq_length}") - - self.emb = Embedding( - prefix_list_element, - prefix_list_canvas, - embedding_config_element, - embedding_config_canvas, - d_model, - dropout, - seq_length, - ) - self.enc = Encoder( - d_model=d_model, - n_head=n_head, - dropout=dropout, - num_encoder_layers=num_encoder_layers, - ) - self.lf = ImlevelLF(vae=True, std_ratio=std_ratio) - - self.dec = Decoder( - prefix_list_target, - embedding_config_element, - d_model=d_model, - n_head=n_head, - dropout=dropout, - num_decoder_layers=num_decoder_layers, - seq_length=seq_length, - autoregressive_scheme=False, - ) - self.head = MultiTask( - prefix_list_target, prediction_config_element, d_model, bypass=False - ) - self.initialize_weights() - - def forward(self, model_inputs: 
ModelInput) -> Tensor: - ( - src, - text_mask_src, - feat_cat, - ) = self.emb(model_inputs) - z = self.enc(src, text_mask_src) - z, vae_data = self.lf(z, text_mask_src) - zd = self.dec(feat_cat, z, model_inputs) - outs = self.head(zd, feat_cat) - outs["vae_data"] = vae_data - return outs - - def store( - self, - out_labels_all: Dict, - out_labels: Dict, - target_prefix_list: List, - text_index: int, - ) -> Dict: - for prefix in target_prefix_list: - out_labels_all[f"{prefix}"][text_index, 0] = out_labels[f"{prefix}"] - return out_labels_all - - def prediction( - self, - model_inputs: ModelInput, - target_prefix_list: List, - start_index: int = 0, - ) -> Tensor: - target_text_num = int(model_inputs.canvas_text_num[0].item()) - start_index = min(start_index, target_text_num) - for t in range(start_index, target_text_num): - model_inputs.zeroinitialize_th_style_attributes(target_prefix_list, t) - ( - src, - text_mask_src, - feat_cat, - ) = self.emb(model_inputs) - z = self.enc(src, text_mask_src) - z = self.lf.prediction(z, text_mask_src) - zd = self.dec(feat_cat, z, model_inputs) - model_outs = self.head(zd, feat_cat) - - out_labels_all = {} - for prefix in target_prefix_list: - out_labels_all[prefix] = np.zeros((target_text_num, 1)) - for t in range(0, start_index): - tar = getattr(model_inputs, f"{prefix}")[0, t].item() - out_labels_all[prefix][t, 0] = tar - - for t in range(start_index, target_text_num): - out_labels = self.get_labels(model_outs, target_prefix_list, t) - out_labels_all = self.store( - out_labels_all, out_labels, target_prefix_list, t - ) - return out_labels_all - - def sample( - self, - model_inputs: ModelInput, - target_prefix_list: List, - start_index: int = 0, - **kwargs: Any, - ) -> Tensor: - target_text_num = int(model_inputs.canvas_text_num[0].item()) - start_index = min(start_index, target_text_num) - for t in range(start_index, target_text_num): - model_inputs.zeroinitialize_th_style_attributes(target_prefix_list, t) - ( - src, - text_mask_src, - feat_cat, - ) = self.emb(model_inputs) - z = self.enc(src, text_mask_src) - z = self.lf.sample(z, text_mask_src) - zd = self.dec(feat_cat, z, model_inputs) - model_outs = self.head(zd, feat_cat) - out_labels_all = {} - for prefix in target_prefix_list: - out_labels_all[prefix] = np.zeros((target_text_num, 1)) - for t in range(0, start_index): - tar = getattr(model_inputs, f"{prefix}")[0, t].item() - out_labels_all[prefix][t, 0] = tar - - for t in range(start_index, target_text_num): - out_labels = self.get_labels(model_outs, target_prefix_list, t) - out_labels_all = self.store( - out_labels_all, out_labels, target_prefix_list, t - ) - return out_labels_all - - def initialize_weights(self) -> None: - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.xavier_normal_(m.weight) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - m.weight.data.normal_(mean=0.0, std=0.02) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.LayerNorm): - if m.weight is not None: - m.weight.data.fill_(1) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.Embedding): - m.weight.data.normal_(mean=0.0, std=0.02) - - -class AllZero(Baseline): - def __init__( - self, - prefix_list_element: List, - prefix_list_canvas: List, - prefix_list_target: List, - embedding_config_element: TextElementContextEmbeddingAttributeConfig, - embedding_config_canvas: CanvasContextEmbeddingAttributeConfig, 
- prediction_config_element: TextElementContextPredictionAttributeConfig, - d_model: int = 256, - n_head: int = 8, - dropout: float = 0.1, - num_encoder_layers: int = 4, - num_decoder_layers: int = 4, - seq_length: int = 50, - std_ratio: float = 1.0, - **kwargs: Any, - ) -> None: - super().__init__( - prefix_list_element, - prefix_list_canvas, - prefix_list_target, - embedding_config_element, - embedding_config_canvas, - prediction_config_element, - d_model, - n_head, - dropout, - num_encoder_layers, - num_decoder_layers, - seq_length, - std_ratio, - ) - - def get_labels( - self, - model_outs: Union[Dict, None], - target_prefix_list: List, - text_index: int, - batch_index: int = 0, - ) -> Dict: - out_labels = {} - for prefix in target_prefix_list: - out_labels[f"{prefix}"] = 0 - - return out_labels - - -class AllRandom(Baseline): - def __init__( - self, - prefix_list_element: List, - prefix_list_canvas: List, - prefix_list_target: List, - embedding_config_element: TextElementContextEmbeddingAttributeConfig, - embedding_config_canvas: CanvasContextEmbeddingAttributeConfig, - prediction_config_element: TextElementContextPredictionAttributeConfig, - d_model: int = 256, - n_head: int = 8, - dropout: float = 0.1, - num_encoder_layers: int = 4, - num_decoder_layers: int = 4, - seq_length: int = 50, - std_ratio: float = 1.0, - **kwargs: Any, - ) -> None: - super().__init__( - prefix_list_element, - prefix_list_canvas, - prefix_list_target, - embedding_config_element, - embedding_config_canvas, - prediction_config_element, - d_model, - n_head, - dropout, - num_encoder_layers, - num_decoder_layers, - seq_length, - std_ratio, - ) - - def get_labels( - self, - model_outs: Dict, - target_prefix_list: List, - text_index: int, - batch_index: int = 0, - ) -> Dict: - out_labels = {} - for prefix in target_prefix_list: - out = model_outs[f"{prefix}"] - label_range = len(out[batch_index][text_index]) - out_labels[f"{prefix}"] = random.randint(0, label_range - 1) - - return out_labels - - -class Mode(Baseline): - def __init__( - self, - prefix_list_element: List, - prefix_list_canvas: List, - prefix_list_target: List, - embedding_config_element: TextElementContextEmbeddingAttributeConfig, - embedding_config_canvas: CanvasContextEmbeddingAttributeConfig, - prediction_config_element: TextElementContextPredictionAttributeConfig, - d_model: int = 256, - n_head: int = 8, - dropout: float = 0.1, - num_encoder_layers: int = 4, - num_decoder_layers: int = 4, - seq_length: int = 50, - std_ratio: float = 1.0, - **kwargs: Any, - ) -> None: - super().__init__( - prefix_list_element, - prefix_list_canvas, - prefix_list_target, - embedding_config_element, - embedding_config_canvas, - prediction_config_element, - d_model, - n_head, - dropout, - num_encoder_layers, - num_decoder_layers, - seq_length, - std_ratio, - ) - self.prefix2mode = pickle.load( - open( - f"prefix2mode.pkl", - "rb", - ) - ) - - def get_labels( - self, - model_outs: Dict, - target_prefix_list: List, - text_index: int, - batch_index: int = 0, - ) -> Dict: - out_labels = {} - for prefix in target_prefix_list: - out = model_outs[f"{prefix}"] - out_labels[f"{prefix}"] = self.prefix2mode[prefix] - - return out_labels diff --git a/src/typography_generation/model/bottleneck.py b/src/typography_generation/model/bottleneck.py deleted file mode 100644 index 0c7569e..0000000 --- a/src/typography_generation/model/bottleneck.py +++ /dev/null @@ -1,100 +0,0 @@ -from typing import Tuple - -import torch -from torch import Tensor, nn - - -class VAE(nn.Module): - def 
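# `Mode` above expects a prefix2mode.pkl that maps each attribute prefix to its
# most frequent training label. How that file is produced is not shown in this
# diff; a plausible (hypothetical) construction from per-element training labels:
import pickle
from collections import Counter

def build_prefix2mode(train_labels, path="prefix2mode.pkl"):
    """train_labels: {prefix: iterable of integer labels over the training set}."""
    prefix2mode = {p: Counter(v).most_common(1)[0][0] for p, v in train_labels.items()}
    with open(path, "wb") as f:  # Mode.__init__ loads this file with pickle.load
        pickle.dump(prefix2mode, f)
    return prefix2mode

build_prefix2mode({"font": [1, 1, 2], "text_align": [2, 3, 2]})
# -> {"font": 1, "text_align": 2}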
__init__( - self, - d_model: int = 256, - reparametric: bool = True, - std_ratio: float = 1.0, - ): - super(VAE, self).__init__() - self.d_model = d_model - self.enc_mu_fcn = nn.Linear(d_model, d_model) - self.enc_sigma_fcn = nn.Linear(d_model, d_model) - self.reparametoric = reparametric - self.stdrate = std_ratio - - def _init_embeddings(self) -> None: - nn.init.normal_(self.enc_mu_fcn.weight, std=0.001) - nn.init.constant_(self.enc_mu_fcn.bias, 0) - nn.init.normal_(self.enc_sigma_fcn.weight, std=0.001) - nn.init.constant_(self.enc_sigma_fcn.bias, 0) - - def forward(self, z: Tensor) -> Tuple: - mu, logsigma = self.enc_mu_fcn(z), self.enc_sigma_fcn(z) - sigma = torch.exp(logsigma / 2.0) - z = mu + sigma * torch.randn_like(sigma) * self.stdrate - return z, (mu, logsigma) - - def prediction(self, z: Tensor) -> Tensor: - mu = self.enc_mu_fcn(z) - z = mu - return z - - def sample(self, z: Tensor) -> Tensor: - mu, logsigma = self.enc_mu_fcn(z), self.enc_sigma_fcn(z) - sigma = torch.exp(logsigma / 2.0) - z = mu + sigma * torch.randn_like(sigma) * self.stdrate - return z - - -class ImlevelLF(nn.Module): - def __init__( - self, - vae: bool = False, - std_ratio: float = 1.0, - ): - super().__init__() - self.vae_flag = vae - if vae is True: - self.vae = VAE(std_ratio=std_ratio) - - def forward(self, z: Tensor, text_mask: Tensor) -> Tensor: - text_mask = ( - text_mask.permute(1, 0) - .view(z.shape[0], z.shape[1], 1) - .repeat(1, 1, z.shape[2]) - ) - z_tmp = torch.sum(z * text_mask, dim=0) / (torch.sum(text_mask, dim=0) + 1e-20) - - vae_item = [] - if self.vae_flag is True: - z_tmp, vae_item_iml = self.vae(z_tmp) - vae_item.append(vae_item_iml) - return z_tmp.unsqueeze(0), vae_item - - def prediction(self, z: Tensor, text_mask: Tensor) -> Tensor: - text_mask = ( - text_mask.permute(1, 0) - .view(z.shape[0], z.shape[1], 1) - .repeat(1, 1, z.shape[2]) - ) - z_tmp = torch.sum(z * text_mask, dim=0) / (torch.sum(text_mask, dim=0) + 1e-20) - - if self.vae_flag is True: - z_tmp = self.vae.prediction(z_tmp) - return z_tmp.unsqueeze(0) - - def sample(self, z: Tensor, text_mask: Tensor) -> Tensor: - text_mask = ( - text_mask.permute(1, 0) - .view(z.shape[0], z.shape[1], 1) - .repeat(1, 1, z.shape[2]) - ) - z_tmp = torch.sum(z * text_mask, dim=0) / (torch.sum(text_mask, dim=0) + 1e-20) - - z_tmp = self.vae.sample(z_tmp) - return z_tmp.unsqueeze(0) - - -class Bottleneck(nn.Module): - def __init__(self) -> None: - super(Bottleneck, self).__init__() - pass - - def forward(self, z: Tensor) -> Tensor: - return z diff --git a/src/typography_generation/model/canvas_vae.py b/src/typography_generation/model/canvas_vae.py deleted file mode 100644 index b05bb8f..0000000 --- a/src/typography_generation/model/canvas_vae.py +++ /dev/null @@ -1,207 +0,0 @@ -from typing import Any, Dict, List, Tuple -import numpy as np - -import torch -from logzero import logger -from torch import Tensor, nn -from typography_generation.config.attribute_config import ( - CanvasContextEmbeddingAttributeConfig, - TextElementContextEmbeddingAttributeConfig, - TextElementContextPredictionAttributeConfig, -) -from typography_generation.io.crello_util import CrelloProcessor -from typography_generation.io.data_object import ModelInput -from typography_generation.model.bottleneck import ImlevelLF -from typography_generation.model.decoder import Decoder, MultiTask -from typography_generation.model.embedding import Embedding -from typography_generation.model.encoder import Encoder - - -class CanvasVAE(nn.Module): - def __init__( - self, - 
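# The VAE above returns (mu, logsigma) so that a training loop can add a KL
# penalty; the loop itself is outside this diff. A standard KL term against a
# unit Gaussian prior, consistent with sigma = exp(logsigma / 2) above (a
# sketch, not the project's training code):
import torch

def kl_unit_gaussian(mu: torch.Tensor, logsigma: torch.Tensor) -> torch.Tensor:
    # KL(N(mu, sigma^2) || N(0, I)); logsigma holds log(sigma^2) here.
    return (-0.5 * (1 + logsigma - mu.pow(2) - logsigma.exp()).sum(dim=-1)).mean()

print(kl_unit_gaussian(torch.zeros(4, 256), torch.zeros(4, 256)))  # tensor(0.)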
prefix_list_element: List, - prefix_list_canvas: List, - prefix_list_target: List, - embedding_config_element: TextElementContextEmbeddingAttributeConfig, - embedding_config_canvas: CanvasContextEmbeddingAttributeConfig, - prediction_config_element: TextElementContextPredictionAttributeConfig, - d_model: int = 256, - n_head: int = 8, - dropout: float = 0.1, - num_encoder_layers: int = 4, - num_decoder_layers: int = 4, - seq_length: int = 50, - std_ratio: float = 1.0, - **kwargs: Any, - ) -> None: - super().__init__() - logger.info(f"CanvasVAE settings") - logger.info(f"d_model: {d_model}") - logger.info(f"n_head: {n_head}") - logger.info(f"num_encoder_layers: {num_encoder_layers}") - logger.info(f"num_decoder_layers: {num_decoder_layers}") - logger.info(f"seq_length: {seq_length}") - - self.emb = Embedding( - prefix_list_element, - prefix_list_canvas, - embedding_config_element, - embedding_config_canvas, - d_model, - dropout, - seq_length, - ) - self.enc = Encoder( - d_model=d_model, - n_head=n_head, - dropout=dropout, - num_encoder_layers=num_encoder_layers, - ) - self.lf = ImlevelLF(vae=True, std_ratio=std_ratio) - - self.dec = Decoder( - prefix_list_target, - embedding_config_element, - d_model=d_model, - n_head=n_head, - dropout=dropout, - num_decoder_layers=num_decoder_layers, - seq_length=seq_length, - autoregressive_scheme=False, - ) - self.head = MultiTask( - prefix_list_target, prediction_config_element, d_model, bypass=False - ) - self.initialize_weights() - - def forward(self, model_inputs: ModelInput) -> Tensor: - ( - src, - text_mask_src, - feat_cat, - ) = self.emb(model_inputs) - z = self.enc(src, text_mask_src) - z, vae_data = self.lf(z, text_mask_src) - zd = self.dec(feat_cat, z, model_inputs) - outs = self.head(zd, feat_cat) - outs["vae_data"] = vae_data - return outs - - def get_labels( - self, - model_outs: Dict, - target_prefix_list: List, - text_index: int, - batch_index: int = 0, - ) -> Dict: - out_labels = {} - for prefix in target_prefix_list: - out = model_outs[f"{prefix}"] - sorted_label = torch.sort(input=out[batch_index], dim=1, descending=True)[1] - target_label = 0 # top1 - out_label = sorted_label[text_index][target_label] - out_labels[f"{prefix}"] = out_label - - return out_labels - - def store( - self, - out_labels_all: Dict, - out_labels: Dict, - target_prefix_list: List, - text_index: int, - ) -> Dict: - for prefix in target_prefix_list: - out_labels_all[f"{prefix}"][text_index] = out_labels[f"{prefix}"].item() - return out_labels_all - - def prediction( - self, - model_inputs: ModelInput, - dataset: CrelloProcessor, - target_prefix_list: List, - start_index: int = 0, - ) -> Tensor: - target_text_num = int(model_inputs.canvas_text_num[0].item()) - start_index = min(start_index, target_text_num) - for t in range(start_index, target_text_num): - model_inputs.zeroinitialize_th_style_attributes(target_prefix_list, t) - ( - src, - text_mask_src, - feat_cat, - ) = self.emb(model_inputs) - z = self.enc(src, text_mask_src) - z = self.lf.prediction(z, text_mask_src) - zd = self.dec(feat_cat, z, model_inputs) - model_outs = self.head(zd, feat_cat) - - out_labels_all = {} - for prefix in target_prefix_list: - out_labels_all[prefix] = np.zeros((target_text_num, 1)) - for t in range(0, start_index): - tar = getattr(model_inputs, f"{prefix}")[0, t].item() - out_labels_all[prefix][t, 0] = tar - - for t in range(start_index, target_text_num): - out_labels = self.get_labels(model_outs, target_prefix_list, t) - out_labels_all = self.store( - out_labels_all, out_labels, 
target_prefix_list, t - ) - return out_labels_all - - def sample( - self, - model_inputs: ModelInput, - target_prefix_list: List, - start_index: int = 0, - **kwargs: Any, - ) -> Tensor: - target_text_num = int(model_inputs.canvas_text_num[0].item()) - start_index = min(start_index, target_text_num) - for t in range(start_index, target_text_num): - model_inputs.zeroinitialize_th_style_attributes(target_prefix_list, t) - ( - src, - text_mask_src, - feat_cat, - ) = self.emb(model_inputs) - z = self.enc(src, text_mask_src) - z = self.lf.sample(z, text_mask_src) - zd = self.dec(feat_cat, z, model_inputs) - model_outs = self.head(zd, feat_cat) - out_labels_all = {} - for prefix in target_prefix_list: - out_labels_all[prefix] = np.zeros((target_text_num, 1)) - for t in range(0, start_index): - tar = getattr(model_inputs, f"{prefix}")[0, t].item() - out_labels_all[prefix][t, 0] = tar - - for t in range(start_index, target_text_num): - out_labels = self.get_labels(model_outs, target_prefix_list, t) - out_labels_all = self.store( - out_labels_all, out_labels, target_prefix_list, t - ) - return out_labels_all - - def initialize_weights(self) -> None: - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.xavier_normal_(m.weight) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - m.weight.data.normal_(mean=0.0, std=0.02) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.LayerNorm): - if m.weight is not None: - m.weight.data.fill_(1) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.Embedding): - m.weight.data.normal_(mean=0.0, std=0.02) diff --git a/src/typography_generation/model/common.py b/src/typography_generation/model/common.py deleted file mode 100644 index 30296b2..0000000 --- a/src/typography_generation/model/common.py +++ /dev/null @@ -1,446 +0,0 @@ -import copy -from collections import OrderedDict -from typing import Any, Callable, Optional, Union - -import torch -from torch import Tensor, nn -from torch.functional import F - - -def _get_clones(module: Any, N: int) -> nn.ModuleList: - return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) - - -def _get_activation_fn(activation: Any) -> Any: - if activation == "relu": - return F.relu - elif activation == "gelu": - return F.gelu - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) - - -def _get_seq_len(src: Tensor, batch_first: bool) -> Optional[int]: - if src.is_nested: - return None - else: - src_size = src.size() - if len(src_size) == 2: - # unbatched: S, E - return src_size[0] - else: - # batched: B, S, E if batch_first else S, B, E - seq_len_pos = 1 if batch_first else 0 - return src_size[seq_len_pos] - - -def _generate_square_subsequent_mask( - sz: int, - device: torch.device = torch.device( - torch._C._get_default_device() - ), # torch.device('cpu'), - dtype: torch.dtype = torch.get_default_dtype(), -) -> Tensor: - r"""Generate a square causal mask for the sequence. The masked positions are filled with float('-inf'). - Unmasked positions are filled with float(0.0). 
- """ - return torch.triu( - torch.full((sz, sz), float("-inf"), dtype=dtype, device=device), - diagonal=1, - ) - - -def _detect_is_causal_mask( - mask: Optional[Tensor], - is_causal: Optional[bool] = None, - size: Optional[int] = None, -) -> bool: - # Prevent type refinement - make_causal = is_causal is True - - if is_causal is None and mask is not None: - sz = size if size is not None else mask.size(-2) - causal_comparison = _generate_square_subsequent_mask( - sz, device=mask.device, dtype=mask.dtype - ) - - # Do not use `torch.equal` so we handle batched masks by - # broadcasting the comparison. - if mask.size() == causal_comparison.size(): - make_causal = bool((mask == causal_comparison).all()) - else: - make_causal = False - - return make_causal - - -class MyTransformerDecoder(nn.Module): - __constants__ = ["norm"] - - def __init__(self, decoder_layer, num_layers, norm=None): - super().__init__() - torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}") - self.layers = _get_clones(decoder_layer, num_layers) - self.num_layers = num_layers - self.norm = norm - - def forward( - self, - tgt: Tensor, - memory: Tensor, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - get_weight: bool = False, - ) -> Tensor: - output = tgt - - seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first) - - for mod in self.layers: - output, w = mod( - output, - memory, - tgt_mask=tgt_mask, - memory_mask=memory_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - ) - - if self.norm is not None: - output = self.norm(output) - - if get_weight is True: - return output, w - else: - return output - - -class MyTransformerDecoderLayer(nn.Module): - __constants__ = ["batch_first"] - - def __init__( - self, - d_model: float = 256, - nhead: int = 8, - dim_feedforward: int = 2048, - dropout: float = 0.1, - activation: str = "relu", - layer_norm_eps: float = 1e-5, - batch_first: bool = False, - device: Any = None, - dtype: Any = None, - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.self_attn = nn.MultiheadAttention( - d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs - ) - self.multihead_attn = nn.MultiheadAttention( - d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs - ) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) - self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) - self.dropout = nn.Dropout(dropout) - self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.dropout3 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - def __setstate__(self, state: Any) -> None: - if "activation" not in state: - state["activation"] = F.relu - super().__setstate__(state) - - def forward( - self, - tgt: torch.Tensor, - memory: torch.Tensor, - tgt_mask: Optional[torch.Tensor] = None, - memory_mask: Optional[torch.Tensor] = None, - tgt_key_padding_mask: Optional[torch.Tensor] = None, - memory_key_padding_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - tgt2 = self.self_attn( - 
tgt, tgt, tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask - )[0] - tgt = tgt + self.dropout1(tgt2) - tgt = self.norm1(tgt) - tgt2, weight = self.multihead_attn( - tgt, - memory, - memory, - attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask, - ) - tgt = tgt + self.dropout2(tgt2) - tgt = self.norm2(tgt) - tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) - tgt = tgt + self.dropout3(tgt2) - tgt = self.norm3(tgt) - return tgt, weight - - -class MyTransformerEncoder(nn.Module): - __constants__ = ["norm"] - - def __init__( - self, encoder_layer: nn.Module, num_layers: int, norm: Any = None - ) -> None: - super(MyTransformerEncoder, self).__init__() - self.layers = _get_clones(encoder_layer, num_layers) - self.num_layers = num_layers - self.norm = norm - - def forward( - self, - src: Tensor, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - get_weight: bool = False, - ) -> Tensor: - output = src - for mod in self.layers: - output, w = mod( - output, - src_mask=mask, - src_key_padding_mask=src_key_padding_mask, - ) - - if self.norm is not None: - output = self.norm(output) - if get_weight is True: - return output, w - else: - return output - - -class MyTransformerEncoderLayer(nn.Module): - __constants__ = ["batch_first"] - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, - layer_norm_eps: float = 1e-5, - batch_first: bool = False, - norm_first: bool = False, - device: Any = None, - dtype: Any = None, - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super(MyTransformerEncoderLayer, self).__init__() - self.self_attn = nn.MultiheadAttention( - d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs - ) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs) - - self.norm_first = norm_first - self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - # Legacy string support for activation function. 
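# These My* variants of torch.nn.Transformer{Encoder,Decoder} exist largely to
# surface attention weights, which the stock modules discard. The underlying
# mechanism is nn.MultiheadAttention's need_weights flag; in isolation:
import torch
from torch import nn

attn = nn.MultiheadAttention(embed_dim=256, num_heads=8)
x = torch.randn(10, 2, 256)                      # (seq, batch, channel)
out, weights = attn(x, x, x, need_weights=True)  # weights averaged over heads
print(out.shape, weights.shape)  # torch.Size([10, 2, 256]) torch.Size([2, 10, 10])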
-        if isinstance(activation, str):
-            self.activation = _get_activation_fn(activation)
-        else:
-            self.activation = activation
-
-    def __setstate__(self, state: Any) -> None:
-        if "activation" not in state:
-            state["activation"] = F.relu
-        super(MyTransformerEncoderLayer, self).__setstate__(state)
-
-    def forward(
-        self,
-        src: Tensor,
-        src_mask: Optional[Tensor] = None,
-        src_key_padding_mask: Optional[Tensor] = None,
-    ) -> Tensor:
-        x = src
-        if self.norm_first:
-            sa, w = self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
-            x = x + sa
-            x = x + self._ff_block(self.norm2(x))
-        else:
-            sa, w = self._sa_block(x, src_mask, src_key_padding_mask)
-            x = self.norm1(x + sa)
-            x = self.norm2(x + self._ff_block(x))
-
-        return x, w
-
-    # self-attention block
-    def _sa_block(
-        self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]
-    ) -> Tensor:
-        x, w = self.self_attn(
-            x,
-            x,
-            x,
-            attn_mask=attn_mask,
-            key_padding_mask=key_padding_mask,
-            need_weights=True,
-        )
-        return self.dropout1(x), w
-
-    # feed forward block
-    def _ff_block(self, x: Tensor) -> Tensor:
-        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
-        return self.dropout2(x)
-
-
-def _conv3x3_bn_relu(
-    in_channels: int,
-    out_channels: int,
-    dilation: int = 1,
-    kernel_size: int = 3,
-    stride: int = 1,
-) -> nn.Sequential:
-    if dilation == 0:
-        dilation = 1
-        padding = 0
-    else:
-        padding = dilation
-    return nn.Sequential(
-        OrderedDict(
-            [
-                (
-                    "conv",
-                    nn.Conv2d(
-                        in_channels,
-                        out_channels,
-                        kernel_size=kernel_size,
-                        stride=stride,
-                        padding=padding,
-                        dilation=dilation,
-                        bias=False,
-                    ),
-                ),
-                ("bn", nn.BatchNorm2d(out_channels)),
-                ("relu", nn.ReLU()),
-            ]
-        )
-    )
-
-
-class LinearView(nn.Module):
-    def __init__(self, channel: int) -> None:
-        super(LinearView, self).__init__()
-        self.channel = channel
-
-    def forward(self, x: Tensor) -> Tensor:
-        if len(x.shape) == 3:
-            BN, CHARN, CN = x.shape
-            if CHARN == 1:
-                x = x.view(BN, self.channel)
-        return x
-
-
-def fn_ln_relu(fn: Any, out_channels: int, dp: float = 0.1) -> nn.Sequential:
-    return nn.Sequential(
-        OrderedDict(
-            [
-                ("fn", fn),
-                ("ln", nn.LayerNorm(out_channels)),
-                ("view", LinearView(out_channels)),
-                ("dp", nn.Dropout(dp)),
-            ]
-        )
-    )
-
-
-class Linearx2(nn.Module):
-    def __init__(self, in_ch: int, out_ch: int) -> None:
-        super().__init__()
-        self.fcn = nn.Sequential(
-            OrderedDict(
-                [
-                    ("l1", nn.Linear(in_ch, in_ch)),
-                    ("r1", nn.LeakyReLU(0.2)),
-                    ("l2", nn.Linear(in_ch, out_ch)),
-                ]
-            )
-        )
-
-    def forward(self, out: Tensor) -> Tensor:
-        logits = self.fcn(out)  # Shape [S, N, out_ch]
-        return logits
-
-
-class Linearx3(nn.Module):
-    def __init__(self, in_ch: int, out_ch: int) -> None:
-        super().__init__()
-        # Keys must be distinct: duplicate "l1"/"r1" entries would overwrite
-        # each other in the OrderedDict and silently drop one Linear/ReLU pair.
-        self.fcn = nn.Sequential(
-            OrderedDict(
-                [
-                    ("l1", nn.Linear(in_ch, in_ch)),
-                    ("r1", nn.LeakyReLU(0.2)),
-                    ("l2", nn.Linear(in_ch, in_ch)),
-                    ("r2", nn.LeakyReLU(0.2)),
-                    ("l3", nn.Linear(in_ch, out_ch)),
-                ]
-            )
-        )
-
-    def forward(self, out: Tensor) -> Tensor:
-        logits = self.fcn(out)  # Shape [S, N, out_ch]
-        return logits
-
-
-class ConstEmbedding(nn.Module):
-    def __init__(
-        self, d_model: int = 256, seq_len: int = 50, positional_encoding: bool = True
-    ):
-        super().__init__()
-        self.d_model = d_model
-        self.seq_len = seq_len
-        self.PE = PositionalEncodingLUT(
-            d_model, max_len=seq_len, positional_encoding=positional_encoding
-        )
-
-    def forward(self, z: Tensor) -> Tensor:
-        if len(z.shape) == 2:
-            N = z.size(0)
-        elif len(z.shape) == 3:
-            N = z.size(1)
-        else:
-            raise Exception
-        pos = 
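# ConstEmbedding above produces a position-only decoder input: the lookup-table
# encoding defined just below ignores the input values and returns a learned
# embedding per position. That lookup, stripped down:
import torch
from torch import nn

seq_len, batch, d_model = 50, 2, 256
pos_embed = nn.Embedding(seq_len, d_model)
position = torch.arange(seq_len, dtype=torch.long).unsqueeze(1)  # (seq_len, 1)
pos = pos_embed(position).repeat(1, batch, 1)  # (seq_len, batch, d_model)
print(pos.shape)  # torch.Size([50, 2, 256])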
self.PE(z.new_zeros(self.seq_len, N, self.d_model)) - return pos - - -class PositionalEncodingLUT(nn.Module): - def __init__( - self, - d_model: int = 256, - dropout: float = 0.1, - max_len: int = 50, - positional_encoding: bool = True, - ): - super(PositionalEncodingLUT, self).__init__() - self.PS = positional_encoding - self.dropout = nn.Dropout(p=dropout) - position = torch.arange(0, max_len, dtype=torch.long).unsqueeze(1) - self.register_buffer("position", position) - self.pos_embed = nn.Embedding(max_len, d_model) - self._init_embeddings() - - def _init_embeddings(self) -> None: - nn.init.kaiming_normal_(self.pos_embed.weight, mode="fan_in") - - def forward(self, x: Tensor, inp_ignore: bool = False) -> Tensor: - pos = self.position[: x.size(0)] - x = self.pos_embed(pos).repeat(1, x.size(1), 1) - return self.dropout(x) diff --git a/src/typography_generation/model/decoder.py b/src/typography_generation/model/decoder.py deleted file mode 100644 index a2b33ff..0000000 --- a/src/typography_generation/model/decoder.py +++ /dev/null @@ -1,210 +0,0 @@ -from typing import Any, List, Tuple - -import torch -from torch import Tensor, nn -from typography_generation.config.attribute_config import ( - TextElementContextEmbeddingAttributeConfig, - TextElementContextPredictionAttributeConfig, -) -from typography_generation.io.data_object import ModelInput -from typography_generation.model.common import ( - ConstEmbedding, - Linearx2, - MyTransformerDecoder, - MyTransformerDecoderLayer, - fn_ln_relu, -) - - -class Decoder(nn.Module): - def __init__( - self, - prefix_list_target: List, - embedding_config_element: TextElementContextEmbeddingAttributeConfig, - d_model: int = 256, - n_head: int = 8, - dropout: float = 0.1, - num_decoder_layers: int = 4, - seq_length: int = 50, - positional_encoding: bool = True, - autoregressive_scheme: bool = True, - ): - super().__init__() - - self.prefix_list_target = prefix_list_target - self.d_model = d_model - self.dropout_element = dropout - self.seq_length = seq_length - self.autoregressive_scheme = autoregressive_scheme - dim_feedforward = d_model * 2 - - # Positional encoding in transformer - position = torch.arange(0, 1, dtype=torch.long).unsqueeze(1) - self.register_buffer("position", position) - self.pos_embed = nn.Embedding(1, d_model) - self.embedding = ConstEmbedding(d_model, seq_length, positional_encoding) - - # Decoder layer - # decoder_layer = nn.TransformerDecoderLayer( - # d_model, n_head, dim_feedforward, dropout - # ) - decoder_norm = nn.LayerNorm(d_model) - # self.decoder = nn.TransformerDecoder( - # decoder_layer, num_decoder_layers, decoder_norm - # ) - decoder_layer = MyTransformerDecoderLayer( - d_model, n_head, dim_feedforward, dropout - ) - self.decoder = MyTransformerDecoder( - decoder_layer, num_decoder_layers, decoder_norm - ) - - # mask config in training - mask_tgt = torch.ones((seq_length, seq_length)) - mask_tgt = torch.triu(mask_tgt == 1).transpose(0, 1) - self.mask_tgt = mask_tgt.float().masked_fill(mask_tgt == 0, float("-inf")) - # Build bart embedding - self.build_bart_embedding(prefix_list_target, embedding_config_element) - - def build_bart_embedding( - self, - prefix_list_target: List, - embedding_config_element: TextElementContextEmbeddingAttributeConfig, - ) -> None: - for prefix in prefix_list_target: - target_embedding_config = getattr(embedding_config_element, prefix) - kwargs = target_embedding_config.emb_layer_kwargs # dict of args - emb_layer = nn.Embedding(**kwargs) - emb_layer = fn_ln_relu(emb_layer, self.d_model, 
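# build_bart_embedding below shifts each element's style embedding one slot to
# the right, so the decoder only sees already-decided elements when predicting
# slot t (BART-style teacher forcing). The shift in isolation (the original
# also respects per-canvas element counts):
import torch

seq_length, batch, d_model = 5, 1, 4
context = torch.arange(1, 6).float().view(seq_length, 1, 1).repeat(1, batch, d_model)
shifted = torch.zeros_like(context)
shifted[1:] = context[:-1]  # slot 0 stays zero: no previous element
print(shifted[:, 0, 0])     # tensor([0., 1., 2., 3., 4.])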
self.dropout_element) - setattr(self, f"{prefix}_emb", emb_layer) - setattr(self, f"{prefix}_flag", target_embedding_config.flag) - - def get_features_via_fn( - self, fn: Any, inputs: List, batch_num: int, text_num: Tensor - ) -> Tensor: - inputs_fn = [] - for b in range(batch_num): - tn = int(text_num[b].item()) - for t in range(tn): - inputs_fn.append(inputs[b][t].view(-1)) - outs = None - if len(inputs_fn) > 0: - inputs_fn = torch.stack(inputs_fn) - outs = fn(inputs_fn) - outs = outs.view(len(inputs_fn), self.d_model) - feat = torch.zeros(self.seq_length, batch_num, self.d_model) - feat = feat.to(text_num.device).float() - cnt = 0 - for b in range(batch_num): - tn = int(text_num[b].item()) - for t in range(tn): - if outs is not None: - feat[t, b] = outs[cnt] - cnt += 1 - return feat - - def get_style_context_embedding( - self, model_inputs: ModelInput, batch_num: int, text_num: Tensor - ) -> Tensor: - feat = torch.zeros(self.seq_length, batch_num, self.d_model) - feat = feat.to(text_num.device).float() - for prefix in self.prefix_list_target: - inp = getattr(model_inputs, prefix).long() - layer = getattr(self, f"{prefix}_emb") - f = self.get_features_via_fn(layer, inp, batch_num, text_num) - feat = feat + f - feat = feat / len(self.prefix_list_target) - return feat - - def shift_context_feat( - self, context_feat: Tensor, batch_num: int, text_num: Tensor - ) -> Tensor: - shifted_context_feat = torch.zeros(self.seq_length, batch_num, self.d_model) - shifted_context_feat = shifted_context_feat.to(text_num.device).float() - for b in range(batch_num): - tn = int(text_num[b].item()) - for t in range(tn - 1): - shifted_context_feat[t + 1, b] = context_feat[t, b] - return shifted_context_feat - - def get_bart_embedding(self, src: Tensor, model_inputs: ModelInput) -> Tensor: - batch_num, text_num = model_inputs.batch_num, model_inputs.canvas_text_num - context_feat = self.get_style_context_embedding( - model_inputs, batch_num, text_num - ) - shifted_context_feat = self.shift_context_feat( - context_feat, batch_num, text_num - ) - position_feat = self.embedding(src) - return shifted_context_feat + position_feat - - def forward( - self, - src: Tensor, - z: Tensor, - model_inputs: ModelInput, - ) -> Tensor: - if self.autoregressive_scheme is True: - src = self.get_bart_embedding(src, model_inputs) - else: - src = self.embedding(src) - mask_tgt = self.mask_tgt.to(src.device) - out = self.decoder(src, z, mask_tgt, tgt_key_padding_mask=None) - return out - - def get_transformer_weight( - self, - src: Tensor, - z: Tensor, - model_inputs: ModelInput, - ) -> Tuple: - if self.autoregressive_scheme is True: - src = self.get_bart_embedding(src, model_inputs) - else: - src = self.embedding(src) - mask_tgt = self.mask_tgt.to(src.device) - out, weights = self.decoder( - src, z, mask_tgt, tgt_key_padding_mask=None, get_weight=True - ) - return out, weights - - -class FCN(nn.Module): - def __init__(self, d_model: int, label_num: int, bypass: bool) -> None: - super().__init__() - self.bypass = bypass - if bypass is True: - self.fcn = Linearx2(d_model * 2, label_num) - else: - self.fcn = Linearx2(d_model, label_num) - - def forward(self, inp: Tensor, elm_agg_emb: Tensor = None) -> Tensor: - if self.bypass is True: - inp = torch.cat((inp, elm_agg_emb), 2) - logits = inp.permute(1, 0, 2) - logits = self.fcn(logits) # Shape [G, N, 2] - return logits - - -class MultiTask(nn.Module): - def __init__( - self, - prefix_list_target: List, - prediction_config: TextElementContextPredictionAttributeConfig, - d_model: int = 
256, - bypass: bool = True, - ): - super().__init__() - self.prefix_list_target = prefix_list_target - for prefix in self.prefix_list_target: - target_prediction_config = getattr(prediction_config, prefix) - layer = FCN(d_model, target_prediction_config.out_dim, bypass) - setattr(self, f"{prefix}_layer", layer) - - def forward(self, z: Tensor, elm_agg_emb: Tensor) -> dict: - outputs = {} - for prefix in self.prefix_list_target: - layer = getattr(self, f"{prefix}_layer") - out = layer(z, elm_agg_emb) - outputs[prefix] = out - return outputs diff --git a/src/typography_generation/model/embedding.py b/src/typography_generation/model/embedding.py deleted file mode 100644 index 725ddfa..0000000 --- a/src/typography_generation/model/embedding.py +++ /dev/null @@ -1,407 +0,0 @@ -from collections import OrderedDict -import time -from typing import Any, List, Tuple -from einops import rearrange, repeat -from logzero import logger -import torch -from torch import Tensor, nn -from torch.functional import F -from typography_generation.config.attribute_config import ( - CanvasContextEmbeddingAttributeConfig, - TextElementContextEmbeddingAttributeConfig, - EmbeddingConfig, -) -from typography_generation.io.data_object import ModelInput -from typography_generation.model.common import ( - _conv3x3_bn_relu, - fn_ln_relu, -) - - -def set_tensor_type(inp: Tensor, tensor_type: str) -> Tensor: - if tensor_type == "float": - inp = inp.float() - elif tensor_type == "long": - inp = inp.long() - else: - raise NotImplementedError() - return inp - - -def setup_emb_layer( - self: nn.Module, prefix: str, target_embedding_config: EmbeddingConfig -) -> None: - if target_embedding_config.emb_layer is not None: - kwargs = target_embedding_config.emb_layer_kwargs # dict of args - if target_embedding_config.emb_layer == "nn.Embedding": - emb_layer = nn.Embedding(**kwargs) - setattr(self, f"{prefix}_type", "long") - elif target_embedding_config.emb_layer == "nn.Linear": - emb_layer = nn.Linear(**kwargs) - setattr(self, f"{prefix}_type", "float") - else: - raise NotImplementedError() - emb_layer = fn_ln_relu(emb_layer, self.d_model, self.dropout) - setattr(self, f"{prefix}_emb", emb_layer) - else: - setattr(self, f"{prefix}_type", "float") - setattr(self, f"{prefix}_flag", target_embedding_config.flag) - setattr(self, f"{prefix}_specific", target_embedding_config.specific_func) - setattr(self, f"{prefix}_inp", target_embedding_config.input_prefix) - - -def get_output( - self: nn.Module, prefix: str, model_inputs: ModelInput, batch_num: int -) -> Tensor: - inp_prefix = getattr(self, f"{prefix}_inp") - specific_func = getattr(self, f"{prefix}_specific") - tensor_type = getattr(self, f"{prefix}_type") - inp = getattr(model_inputs, inp_prefix) - text_num = getattr(model_inputs, "canvas_text_num") - inp = set_tensor_type(inp, tensor_type) - if specific_func is not None: - fn = getattr(self, specific_func) - out = fn( - inputs=inp, - batch_num=batch_num, - text_num=text_num, - ) - else: - fn = getattr(self, f"{prefix}_emb") - out = self.get_features_via_fn( - fn=fn, inputs=inp, batch_num=batch_num, text_num=text_num - ) - return out - - -class Down(nn.Module): - def __init__(self, input_dim: int, output_dim: int) -> None: - super(Down, self).__init__() - self.conv1 = _conv3x3_bn_relu(input_dim, output_dim, kernel_size=3) - self.drop = nn.Dropout() - - def forward(self, feat: Tensor) -> Tensor: - feat = self.conv1(feat) - return self.drop(feat) - - -class Embedding(nn.Module): - def __init__( - self, - prefix_list_element: List, - 
prefix_list_canvas: List, - embedding_config_element: TextElementContextEmbeddingAttributeConfig, - embedding_config_canvas: CanvasContextEmbeddingAttributeConfig, - d_model: int = 256, - dropout: float = 0.1, - seq_length: int = 50, - ) -> None: - super(Embedding, self).__init__() - self.emb_element = EmbeddingElementContext( - prefix_list_element, embedding_config_element, d_model, dropout, seq_length - ) - self.emb_canvas = EmbeddingCanvasContext( - prefix_list_canvas, - embedding_config_canvas, - d_model, - dropout, - ) - self.prefix_list_element = prefix_list_element - self.prefix_list_canvas = prefix_list_canvas - - self.d_model = d_model - self.seq_length = seq_length - - mlp_dim = ( - self.compute_modality_num(embedding_config_canvas, embedding_config_element) - * d_model - ) - self.mlp = nn.Sequential( - OrderedDict( - [ - ("l1", nn.Linear(mlp_dim, d_model)), - ("r1", nn.LeakyReLU(0.2)), - ("l2", nn.Linear(d_model, d_model)), - ("r2", nn.LeakyReLU(0.2)), - ] - ) - ) - self.mlp_dim = mlp_dim - - def compute_modality_num( - self, - embedding_config_canvas: CanvasContextEmbeddingAttributeConfig, - embedding_config_text: TextElementContextEmbeddingAttributeConfig, - ) -> int: - modality_num = 0 - for prefix in self.prefix_list_element: - target_embedding_config = getattr(embedding_config_text, prefix) - flag = target_embedding_config.flag - if flag is False: - pass - else: - modality_num += 1 - for prefix in self.prefix_list_canvas: - target_embedding_config = getattr(embedding_config_canvas, prefix) - flag = target_embedding_config.flag - if flag is True: - modality_num += 1 - return modality_num - - def flatten_feature( - self, - feats: List, - ) -> Tensor: - feats_flatten = [] - modarity_num = len(feats) - for m in range(modarity_num): - feats_flatten.append(feats[m]) - return feats_flatten - - def reshape_tbc( - self, - feat_elements: List, # modality num(M) x [B, S, C] - feat_canvas: List, # canvas num(CV) x [B, C] - batch_num: int, - text_num: Tensor, - ) -> Tensor: - feat_elements = torch.stack(feat_elements) # M, B, S, C - feat_elements = rearrange(feat_elements, "m b s c -> s b (m c)") - feat_canvas = torch.stack(feat_canvas) # CV, B, C - feat_canvas = rearrange(feat_canvas, "cv b c -> 1 b (cv c)") - feat_canvas = repeat(feat_canvas, "1 b c -> s b c", s=self.seq_length) - feat = torch.cat((feat_canvas, feat_elements), dim=2) - feat = self.mlp(feat) - return feat - - def get_transformer_inputs( - self, - feat_elements: List[Tensor], - feat_canvas: Tensor, - batch_num: int, - text_num: Tensor, - ) -> Tuple[Tensor, Tensor]: - feat_element, text_mask_element = self.lineup_features( - feat_elements, batch_num, text_num - ) - feat_canvas = torch.stack(feat_canvas) - feat_canvas = feat_canvas.view( - feat_canvas.shape[0], feat_element.shape[1], feat_element.shape[2] - ) - src = torch.cat((feat_canvas, feat_element), dim=0) - canvas_mask = torch.zeros(src.shape[1], feat_canvas.shape[0]) + 1 - canvas_mask = canvas_mask.float().to(text_num.device) - text_mask_src = torch.cat((canvas_mask, text_mask_element), dim=1) - return src, text_mask_src - - def lineup_features( - self, feats: Tensor, batch_num: int, text_num: Tensor - ) -> Tuple[Tensor, Tensor]: - modality_num = len(feats) - device = text_num.device - feat = ( - torch.zeros(self.seq_length, modality_num, batch_num, self.d_model) - .float() - .to(device) - ) - for m in range(modality_num): - logger.debug(f"{feats[m].shape=}") - feat[:, m, :, :] = rearrange(feats[m], "b t c -> t b c") - feat = rearrange(feat, "t m b c -> (t m) b 
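# lineup_features here interleaves the per-modality features of each element:
# rearrange(feat, "t m b c -> (t m) b c") orders the sequence as
# [elem0-mod0, elem0-mod1, ..., elem1-mod0, ...], and the validity mask is
# repeated the same way. A small worked example:
import torch
from einops import rearrange, repeat

t, m, b, c = 2, 3, 1, 1
feat = torch.arange(t * m).float().view(t, m, b, c)
print(rearrange(feat, "t m b c -> (t m) b c")[:, 0, 0])  # tensor([0., 1., 2., 3., 4., 5.])

mask = torch.tensor([[1.0, 0.0]])           # one canvas: element 0 valid, element 1 padded
print(repeat(mask, "b s -> b (s m)", m=m))  # tensor([[1., 1., 1., 0., 0., 0.]])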
c") - indices = rearrange(torch.arange(self.seq_length), "s -> 1 s").to(device) - mask = ( - (indices < text_num).to(device).float() - ) # (B, S), indicating valid attribute locations - text_mask = repeat(mask, "b s -> b (s m)", m=modality_num) - return feat, text_mask - - def forward(self, model_inputs: ModelInput) -> Tuple[Tensor, Tensor, Tensor]: - feat_elements = self.emb_element(model_inputs) - feat_canvas = self.emb_canvas(model_inputs) - - src, text_mask_src = self.get_transformer_inputs( - feat_elements, - feat_canvas, - model_inputs.batch_num, - model_inputs.canvas_text_num, - ) - feat_cat = self.reshape_tbc( - feat_elements, - feat_canvas, - model_inputs.batch_num, - model_inputs.canvas_text_num, - ) - - return (src, text_mask_src, feat_cat) - - -class EmbeddingElementContext(nn.Module): - def __init__( - self, - prefix_list_element: List, - embedding_config_element: TextElementContextEmbeddingAttributeConfig, - d_model: int, - dropout: float = 0.1, - seq_length: int = 50, - ): - super(EmbeddingElementContext, self).__init__() - self.d_model = d_model - self.dropout = dropout - self.seq_length = seq_length - self.prefix_list_element = prefix_list_element - - for prefix in self.prefix_list_element: - target_embedding_config = getattr(embedding_config_element, prefix) - setup_emb_layer(self, prefix, target_embedding_config) - if target_embedding_config.specific_build is not None: - build_func = getattr(self, target_embedding_config.specific_build) - build_func() - - def get_local_feat( - self, inputs: Tensor, batch_num: int, text_num: Tensor, **kwargs: Any - ) -> Tensor: - inputs_fn = [] - for b in range(batch_num): - tn = int(text_num[b].item()) - for t in range(tn): - inputs_fn.append(inputs[b, t]) - outs = None - if len(inputs_fn) > 0: - inputs_fn = torch.stack(inputs_fn) - with torch.no_grad(): - outs = self.imgencoder(inputs_fn).detach() - outs = self.downdim(outs) - outs = outs.view(len(outs), -1) - outs = self.fcimg(outs) - return outs - - # def get_features_via_fn( - # self, fn: Any, inputs: List, batch_num: int, text_num: Tensor - # ) -> Tensor: - # inputs_layer = [] - # for b in range(batch_num): - # tn = int(text_num[b].item()) - # for t in range(tn): - # inputs_layer.append(inputs[b][t].view(-1)) - # outs = None - # if len(inputs_layer) > 0: - # inputs_layer = torch.stack(inputs_layer) - # outs = fn(inputs_layer) - # outs = outs.view(len(inputs_layer), self.d_model) - # return outs - - def get_features_via_fn( - self, fn: Any, inputs: Tensor, batch_num: int, text_num: Tensor - ) -> Tensor: - device = inputs.device - feat = torch.zeros(batch_num, self.seq_length, self.d_model).float().to(device) - indices = rearrange(torch.arange(self.seq_length), "s -> 1 s").to(device) - mask = (indices < text_num).to( - device - ) # (B, S), indicating valid attribute locations - feat[mask] = fn(inputs[mask]) # (B, S, C) - return feat - - def text_emb_layer( - self, inputs: Tensor, batch_num: int, text_num: Tensor, **kwargs: Any - ) -> Tensor: - inputs_fn = [] - for b in range(batch_num): - tn = int(text_num[b].item()) - for t in range(tn): - inputs_fn.append(inputs[b, t].view(-1)) - outs = None - if len(inputs_fn) > 0: - inputs_fn = torch.stack(inputs_fn) - outs = self.text_emb_emb(inputs_fn) - outs = outs.view(len(inputs_fn), self.d_model) - return outs - - def text_local_img_emb_layer( - self, inputs: Tensor, batch_num: int, text_num: Tensor, **kwargs: Any - ) -> Tensor: - inputs_fn = [] - for b in range(batch_num): - tn = int(text_num[b].item()) - for t in range(tn): - 
inputs_fn.append(inputs[b, t].view(-1)) - outs = None - if len(inputs_fn) > 0: - inputs_fn = torch.stack(inputs_fn) - outs = self.text_local_img_emb_emb(inputs_fn) - outs = outs.view(len(inputs_fn), self.d_model) - return outs - - def text_font_emb_layer( - self, inputs: Tensor, batch_num: int, text_num: Tensor - ) -> Tensor: - inputs_fn = [] - for b in range(batch_num): - tn = int(text_num[b].item()) - for t in range(tn): - inputs_fn.append(inputs[b, t].view(-1)) - outs = None - if len(inputs_fn) > 0: - inputs_fn = torch.stack(inputs_fn) - outs = self.text_font_emb_emb(inputs_fn) - outs = outs.view(len(inputs_fn), self.d_model) - return outs - - def forward(self, model_inputs: ModelInput) -> Tensor: - feats = [] - for prefix in self.prefix_list_element: - flag = getattr(self, f"{prefix}_flag") - if flag is True: - start = time.time() - logger.debug(f"{prefix=}") - out = get_output(self, prefix, model_inputs, model_inputs.batch_num) - logger.debug(f"{prefix=} {out.shape=} {time.time()-start}") - feats.append(out) - return feats - - -class EmbeddingCanvasContext(nn.Module): - def __init__( - self, - canvas_prefix_list: List, - embedding_config_canvas: CanvasContextEmbeddingAttributeConfig, - d_model: int, - dropout: float, - ): - super(EmbeddingCanvasContext, self).__init__() - self.d_model = d_model - self.dropout = dropout - self.prefix_list = canvas_prefix_list - for prefix in self.prefix_list: - target_embedding_config = getattr(embedding_config_canvas, prefix) - setup_emb_layer(self, prefix, target_embedding_config) - if target_embedding_config.specific_build is not None: - build_func = getattr(self, target_embedding_config.specific_build) - build_func() - - def get_feat(self, inputs: Tensor, **kwargs: Any) -> Tensor: - with torch.no_grad(): - feat = self.imgencoder(inputs.float()).detach() - feat = F.relu(self.avgpool(feat)) - feat = self.downdim(feat) - feat = feat.view(feat.shape[0], feat.shape[1]) - return feat - - def canvas_bg_img_emb_layer(self, inputs: Tensor, **kwargs: Any) -> Tensor: - feat = self.canvas_bg_img_emb_emb(inputs.float()) - feat = feat.view(feat.shape[0], feat.shape[1]) - return feat - - def get_features_via_fn(self, fn: Any, inputs: Tensor, **kwargs: Any) -> Tensor: - out = fn(inputs) - return out - - def forward(self, model_inputs: ModelInput) -> List: - feats = [] - for prefix in self.prefix_list: - flag = getattr(self, f"{prefix}_flag") - if flag is True: - logger.debug(f"canvas prefix {prefix}") - out = get_output(self, prefix, model_inputs, model_inputs.batch_num) - feats.append(out) - return feats diff --git a/src/typography_generation/model/encoder.py b/src/typography_generation/model/encoder.py deleted file mode 100644 index a8ce865..0000000 --- a/src/typography_generation/model/encoder.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import Tuple - -import torch -from torch import Tensor, nn - -from typography_generation.model.common import ( - MyTransformerEncoder, - MyTransformerEncoderLayer, -) - - -class Encoder(nn.Module): - def __init__( - self, - d_model: int = 256, - n_head: int = 8, - dropout: float = 0.1, - num_encoder_layers: int = 4, - ): - super(Encoder, self).__init__() - encoder_norm = nn.LayerNorm(d_model) - dim_feedforward = d_model * 2 - encoder_layer = MyTransformerEncoderLayer( - d_model, n_head, dim_feedforward, dropout - ) - self.encoder = MyTransformerEncoder( - encoder_layer, num_encoder_layers, encoder_norm - ) - - def forward(self, src: Tensor, text_mask: Tensor) -> Tuple: - text_mask_fill = text_mask.masked_fill( - text_mask == 
0, float("-inf") - ).masked_fill(text_mask == 1, float(0.0)) - z = self.encoder(src, mask=None, src_key_padding_mask=text_mask_fill.bool()) - if torch.sum(torch.isnan(z)) > 0: - z = z.masked_fill(torch.isnan(z), 0) - return z - - def get_transformer_weight(self, src: Tensor, text_mask: Tensor) -> Tuple: - text_mask_fill = text_mask.masked_fill( - text_mask == 0, float("-inf") - ).masked_fill(text_mask == 1, float(0.0)) - z, weights = self.encoder( - src, mask=None, src_key_padding_mask=text_mask_fill, get_weight=True - ) - if torch.sum(torch.isnan(z)) > 0: - z = z.masked_fill(torch.isnan(z), 0) - return z, weights diff --git a/src/typography_generation/model/mfc.py b/src/typography_generation/model/mfc.py deleted file mode 100644 index cc83c2e..0000000 --- a/src/typography_generation/model/mfc.py +++ /dev/null @@ -1,154 +0,0 @@ -from typing import Any, List, Dict -import numpy as np - -import torch -from torch import Tensor, nn -from typography_generation.config.attribute_config import ( - CanvasContextEmbeddingAttributeConfig, - TextElementContextEmbeddingAttributeConfig, - TextElementContextPredictionAttributeConfig, -) -from typography_generation.io.crello_util import CrelloProcessor -from typography_generation.io.data_object import ModelInput - -from logzero import logger - -from typography_generation.model.mlp import MLP - - -class MFC(MLP): - def __init__( - self, - prefix_list_element: List, - prefix_list_canvas: List, - prefix_list_target: List, - embedding_config_element: TextElementContextEmbeddingAttributeConfig, - embedding_config_canvas: CanvasContextEmbeddingAttributeConfig, - prediction_config_element: TextElementContextPredictionAttributeConfig, - d_model: int = 256, - dropout: float = 0.1, - seq_length: int = 50, - **kwargs: Any, - ) -> None: - logger.info(f"MFC model") - super().__init__( - prefix_list_element, - prefix_list_canvas, - prefix_list_target, - embedding_config_element, - embedding_config_canvas, - prediction_config_element, - d_model, - dropout, - seq_length, - ) - for prefix in prefix_list_target: - target_prediction_config = getattr(prediction_config_element, prefix) - setattr(self, f"{prefix}_loss_type", target_prediction_config.loss_type) - - def get_label( - self, - out: Tensor, - text_index: int, - batch_index: int = 0, - ) -> Tensor: - sorted_label = torch.sort(input=out[batch_index], dim=1, descending=True)[1] - target_label = 0 # top1 - out_label = sorted_label[text_index][target_label].item() - return out_label - - def get_out( - self, - out: Tensor, - text_index: int, - batch_index: int = 0, - ) -> Tensor: - _out = out[batch_index, text_index].data.cpu().numpy() - return _out - - def get_outs( - self, - model_outs: Dict, - target_prefix_list: List, - text_index: int, - batch_index: int = 0, - ) -> Dict: - outs = {} - for prefix in target_prefix_list: - out = model_outs[f"{prefix}"] - loss_type = getattr(self, f"{prefix}_loss_type") - if loss_type == "cre": - _out = self.get_label(out, text_index, batch_index) - else: - _out = self.get_out(out, text_index, batch_index) - outs[f"{prefix}"] = _out - return outs - - def store( - self, - out_all: Dict, - out_labels: Dict, - target_prefix_list: List, - text_index: int, - ) -> Dict: - for prefix in target_prefix_list: - out_all[f"{prefix}"][text_index, 0] = out_labels[f"{prefix}"] - return out_all - - def prediction( - self, - model_inputs: ModelInput, - dataset: CrelloProcessor, - target_prefix_list: List, - start_index: int = 0, - ) -> Tensor: - target_text_num = 
int(model_inputs.canvas_text_num[0].item()) - start_index = min(start_index, target_text_num) - for t in range(start_index, target_text_num): - model_inputs.zeroinitialize_th_style_attributes(target_prefix_list, t) - - out_all = {} - for prefix in target_prefix_list: - loss_type = getattr(self, f"{prefix}_loss_type") - if loss_type == "cre": - out_all[prefix] = np.zeros((target_text_num, 1)) - else: - out_all[prefix] = [] - for t in range(0, start_index): - tar = getattr(model_inputs, f"{prefix}")[0, t].item() - if loss_type == "cre": - out_all[prefix][t, 0] = tar - else: - out_all[prefix].append(tar) - for t in range(start_index, target_text_num): - _, _, feat_cat = self.emb(model_inputs) - model_outs = self.head(feat_cat) - out = self.get_outs(model_outs, target_prefix_list, t) - for prefix in target_prefix_list: - loss_type = getattr(self, f"{prefix}_loss_type") - if loss_type == "cre": - out_all[f"{prefix}"][t, 0] = out[f"{prefix}"] - else: - out_all[f"{prefix}"].append(out[f"{prefix}"]) - return out_all - - def initialize_weights(self) -> None: - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.xavier_normal_(m.weight) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - m.weight.data.normal_(mean=0.0, std=0.02) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.LayerNorm): - if m.weight is not None: - m.weight.data.fill_(1) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.Embedding): - m.weight.data.normal_(mean=0.0, std=0.02) diff --git a/src/typography_generation/model/mlp.py b/src/typography_generation/model/mlp.py deleted file mode 100644 index 81491f4..0000000 --- a/src/typography_generation/model/mlp.py +++ /dev/null @@ -1,175 +0,0 @@ -from typing import Any, List, Dict -import numpy as np - -import torch -from torch import Tensor, nn -from typography_generation.config.attribute_config import ( - CanvasContextEmbeddingAttributeConfig, - TextElementContextEmbeddingAttributeConfig, - TextElementContextPredictionAttributeConfig, -) -from typography_generation.io.crello_util import CrelloProcessor -from typography_generation.io.data_object import ModelInput -from typography_generation.model.common import Linearx2 -from typography_generation.model.embedding import Embedding - -from logzero import logger - - -class FCN(nn.Module): - def __init__(self, d_model: int, label_num: int) -> None: - super().__init__() - self.fcn = Linearx2(d_model, label_num) - - def forward(self, elm_agg_emb: Tensor = None) -> Tensor: - logits = elm_agg_emb.permute(1, 0, 2) - logits = self.fcn(logits) # Shape [G, N, 2] - return logits - - -class MultiTask(nn.Module): - def __init__( - self, - prefix_list_target: List, - prediction_config: TextElementContextPredictionAttributeConfig, - d_model: int = 256, - ): - super().__init__() - self.prefix_list_target = prefix_list_target - for prefix in self.prefix_list_target: - target_prediction_config = getattr(prediction_config, prefix) - layer = FCN(d_model, target_prediction_config.out_dim) - setattr(self, f"{prefix}_layer", layer) - - def forward(self, elm_agg_emb: Tensor) -> dict: - outputs = {} - for prefix in self.prefix_list_target: - layer = getattr(self, f"{prefix}_layer") - out = layer(elm_agg_emb) - outputs[prefix] = out - return outputs - - -class MLP(nn.Module): - def __init__( - self, - prefix_list_element: List, - prefix_list_canvas: List, - prefix_list_target: List, - 
embedding_config_element: TextElementContextEmbeddingAttributeConfig, - embedding_config_canvas: CanvasContextEmbeddingAttributeConfig, - prediction_config_element: TextElementContextPredictionAttributeConfig, - d_model: int = 256, - dropout: float = 0.1, - seq_length: int = 50, - **kwargs: Any, - ) -> None: - super().__init__() - logger.info(f"MLP settings") - logger.info(f"d_model: {d_model}") - logger.info(f"seq_length: {seq_length}") - self.embedding_config_element = embedding_config_element - - self.emb = Embedding( - prefix_list_element, - prefix_list_canvas, - embedding_config_element, - embedding_config_canvas, - d_model, - dropout, - seq_length, - ) - self.head = MultiTask( - prefix_list_target, - prediction_config_element, - d_model, - ) - self.initialize_weights() - - def forward(self, model_inputs: ModelInput) -> Tensor: - ( - _, - _, - feat_cat, - ) = self.emb(model_inputs) - outs = self.head(feat_cat) - return outs - - def get_labels( - self, - model_outs: Dict, - target_prefix_list: List, - text_index: int, - batch_index: int = 0, - ) -> Dict: - out_labels = {} - for prefix in target_prefix_list: - out = model_outs[f"{prefix}"] - sorted_label = torch.sort(input=out[batch_index], dim=1, descending=True)[1] - target_label = 0 # top1 - out_label = sorted_label[text_index][target_label] - out_labels[f"{prefix}"] = out_label - - return out_labels - - def store( - self, - out_labels_all: Dict, - out_labels: Dict, - target_prefix_list: List, - text_index: int, - ) -> Dict: - for prefix in target_prefix_list: - out_labels_all[f"{prefix}"][text_index, 0] = out_labels[f"{prefix}"].item() - return out_labels_all - - def prediction( - self, - model_inputs: ModelInput, - dataset: CrelloProcessor, - target_prefix_list: List, - start_index: int = 0, - ) -> Tensor: - target_text_num = int(model_inputs.canvas_text_num[0].item()) - start_index = min(start_index, target_text_num) - for t in range(start_index, target_text_num): - model_inputs.zeroinitialize_th_style_attributes(target_prefix_list, t) - - out_labels_all = {} - for prefix in target_prefix_list: - out_labels_all[prefix] = np.zeros((target_text_num, 1)) - for t in range(0, start_index): - tar = getattr(model_inputs, f"{prefix}")[0, t].item() - out_labels_all[prefix][t, 0] = tar - for t in range(start_index, target_text_num): - _, _, feat_cat = self.emb(model_inputs) - model_outs = self.head(feat_cat) - out_labels = self.get_labels(model_outs, target_prefix_list, t) - model_inputs.update_th_style_attributes( - self.embedding_config_element, target_prefix_list, out_labels, t - ) - out_labels_all = self.store( - out_labels_all, out_labels, target_prefix_list, t - ) - return out_labels_all - - def initialize_weights(self) -> None: - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.xavier_normal_(m.weight) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - m.weight.data.normal_(mean=0.0, std=0.02) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.LayerNorm): - if m.weight is not None: - m.weight.data.fill_(1) - if m.bias is not None: - m.bias.data.zero_() - elif isinstance(m, nn.Embedding): - m.weight.data.normal_(mean=0.0, std=0.02) diff --git a/src/typography_generation/model/model.py b/src/typography_generation/model/model.py deleted file mode 100644 index 6b0d633..0000000 --- a/src/typography_generation/model/model.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import Any, Dict 
-from torch import nn - -from typography_generation.model.bart import BART -from typography_generation.model.baseline import AllRandom, AllZero, Mode -from typography_generation.model.canvas_vae import CanvasVAE -from typography_generation.model.mfc import MFC -from typography_generation.model.mlp import MLP - -MODEL_REGISTRY: Dict[str, nn.Module] = { - "bart": BART, - "mlp": MLP, - "mfc": MFC, - "canvasvae": CanvasVAE, - "allzero": AllZero, - "allrandom": AllRandom, - "mode": Mode, -} - - -def create_model(model_name: str, **kwargs: Any) -> nn.Module: - """Factory function to create a model instance.""" - model = MODEL_REGISTRY[model_name](**kwargs) - model.model_name = model_name - return model diff --git a/src/typography_generation/preprocess/__init__.py b/src/typography_generation/preprocess/__init__.py deleted file mode 100644 index 8b13789..0000000 --- a/src/typography_generation/preprocess/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/typography_generation/preprocess/map_features.py b/src/typography_generation/preprocess/map_features.py deleted file mode 100644 index ece5c2d..0000000 --- a/src/typography_generation/preprocess/map_features.py +++ /dev/null @@ -1,280 +0,0 @@ -import pickle -from typing import Any, Dict, List, Tuple -import datasets -import os -import numpy as np -from PIL import Image -from logzero import logger -import torch -import skia -from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer - -from typography_generation.visualization.renderer_util import ( - get_skia_font, - get_text_actual_width, - get_texts, -) - - -def get_scaleinfo( - element_data: Dict, -) -> List: - svgid = element_data["id"] - scaleinfo = svgid2scaleinfo[svgid] - return list(scaleinfo) - - -def get_canvassize( - bg_img: Any, -) -> List: - img_size = (bg_img.size[1], bg_img.size[0]) - return list(img_size) - - -def get_canvas_bg_img_emb(bg_img: Any, **kwargs: Any) -> np.array: - inputs = processor(images=[bg_img], return_tensors="pt") - inputs["pixel_values"] = inputs["pixel_values"].to(device) - image_feature = model.get_image_features(**inputs) - return list(image_feature.data.cpu().numpy().flatten()) - - -def get_text_emb_list( - element_data: Dict, -) -> List: - text_emb_list: List[List] - text_emb_list = [] - for k in range(len(element_data["text"])): - text = element_data["text"][k] - inputs = text_tokenizer([text], padding=True, return_tensors="pt") - if inputs["input_ids"].shape[1] > 77: - inp = inputs["input_ids"][:, :77] - else: - inp = inputs["input_ids"] - inp = inp.to(device) - text_features = model.get_text_features(inp).data.cpu().numpy()[0] - text_emb_list.append(text_features) - return text_emb_list - - -def _get_text_actual_width( - element_data: Dict, - W: int, -) -> List: - svgid = element_data["id"] - scaleinfo = svgid2scaleinfo[svgid] - scale_h, scale_w = scaleinfo - text_actual_width_list = [] - element_num = len(element_data["text"]) - - for i in range(element_num): - if element_data["text"][i] == "": - text_actual_width_list.append(None) - else: - texts = get_texts(element_data, i) - - font_label = element_data["font"][i] - font_name = fontlabel2fontname(int(font_label)) - font_name = font_name.replace(" ", "_") - font_skia, _ = get_skia_font( - font2ttf, fontmgr, element_data, i, font_name, scale_h - ) - text_width = get_text_actual_width( - element_data, i, texts, font_skia, scale_w - ) - text_actual_width_list.append(text_width / float(W)) - return text_actual_width_list - - -def get_text_center_y(element_data: dict, text_index: int) -> 
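# Hypothetical usage of the factory above; the config objects are assumed to be
# built elsewhere in the project (not shown in this diff). Every registered
# constructor accepts the same keyword set, so one call site covers all models:
from typography_generation.model.model import create_model

model = create_model(
    "canvasvae",  # any key in MODEL_REGISTRY: "bart", "mlp", "mfc", ...
    prefix_list_element=prefix_list_element,
    prefix_list_canvas=prefix_list_canvas,
    prefix_list_target=prefix_list_target,
    embedding_config_element=embedding_config_element,
    embedding_config_canvas=embedding_config_canvas,
    prediction_config_element=prediction_config_element,
    d_model=256,
)
print(model.model_name)  # "canvasvae"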
float: - text_height = element_data["height"][text_index] - top = element_data["top"][text_index] - center_y = (top + top + text_height) / 2.0 - return float(center_y) - - -def get_text_center_y_list( - element_data: dict, -) -> List: - text_center_y_list = [] - element_num = len(element_data["text"]) - for text_id in range(element_num): - if element_data["text"][text_id] == "": - text_center_y_list.append(None) - else: - text_center_y_list.append(get_text_center_y(element_data, text_id)) - return text_center_y_list - - -def get_text_center_x( - element_data: dict, - text_index: int, - text_actual_width: float, -) -> float: - left = element_data["left"][text_index] - w = element_data["width"][text_index] - textAlign = element_data["text_align"][text_index] - right = left + w - if textAlign == 1: - center_x = (left + right) / 2.0 - elif textAlign == 3: - center_x = right - text_actual_width / 2.0 - elif textAlign == 2: - center_x = left + text_actual_width / 2.0 - else: - center_x = (left + right) / 2.0 # unknown alignment label: fall back to the box center - return float(center_x) - - -def get_text_center_x_list( - element_data: dict, - text_actual_width: List, -) -> List: - text_center_x_list = [] - element_num = len(element_data["text"]) - for text_id in range(element_num): - if element_data["text"][text_id] == "": - text_center_x_list.append(None) - else: - _text_actual_width = text_actual_width[text_id] - text_center_x_list.append( - get_text_center_x(element_data, text_id, _text_actual_width) - ) - return text_center_x_list - - -def get_text_local_img( - img: Any, - text_center_y: float, - text_center_x: float, - H: int, - W: int, -) -> np.array: - text_center_y = text_center_y * H - text_center_x = text_center_x * W - - text_center_y = min(max(text_center_y, 0), H) - text_center_x = min(max(text_center_x, 0), W) - img = img.resize((640, 640)) - img = np.array(img) - local_img_size = 64 * 5 - local_img_size_half = local_img_size // 2 - img_pad = np.zeros((640 + local_img_size, 640 + local_img_size, 3)) - img_pad[ - local_img_size_half : 640 + local_img_size_half, - local_img_size_half : 640 + local_img_size_half, - ] = img - h_rate = 640 / float(H) - w_rate = 640 / float(W) - text_center_y = int(np.round(text_center_y * h_rate + local_img_size_half)) - text_center_x = int(np.round(text_center_x * w_rate + local_img_size_half)) - local_img = img_pad[ - text_center_y - local_img_size_half : text_center_y + local_img_size_half, - text_center_x - local_img_size_half : text_center_x + local_img_size_half, - ] - return local_img - - -def get_text_local_img_emb_list( - element_data: Dict, - bg_img: Any, - text_center_y: List, - text_center_x: List, -) -> List: - text_local_img_emb_list: List[List] - text_local_img_emb_list = [] - for k in range(len(element_data["text"])): - if element_data["text"][k] == "": - text_local_img_emb_list.append([]) - else: - H, W = bg_img.size[1], bg_img.size[0] - local_img = get_text_local_img( - bg_img.copy(), text_center_y[k], text_center_x[k], H, W - ) - local_img = Image.fromarray(local_img.astype(np.uint8)).resize((224, 224)) - inputs = processor(images=[local_img], return_tensors="pt") - inputs["pixel_values"] = inputs["pixel_values"].to(device) - image_feature = model.get_image_features(**inputs) - text_local_img_emb_list.append(image_feature.data.cpu().numpy()) - - return text_local_img_emb_list - - -def get_orderlist( - center_y: List[float], - center_x: List[float], -) -> List: - """ - Sort elements based on the raster scan order. 
- """ - center_y = [10000 if y is None else y for y in center_y] - center_x = [10000 if x is None else x for x in center_x] - center_y = np.array(center_y) - center_x = np.array(center_x) - sortedid = np.argsort(center_y * 1000 + center_x) - return list(sortedid) - - -def add_features( - element_data: Dict, -) -> Dict: - svgid = element_data["id"] - fn = os.path.join(data_dir, "generate_bg_png", f"{svgid}.png") - bg_img = Image.open(fn).convert("RGB") # background image - element_data["scale_box"] = get_scaleinfo(element_data) - element_data["canvas_bg_size"] = get_canvassize(bg_img) - element_data["canvas_bg_img_emb"] = get_canvas_bg_img_emb(bg_img) - element_data["text_emb"] = get_text_emb_list(element_data) - text_actual_width = _get_text_actual_width(element_data, bg_img.size[0]) - text_center_y = get_text_center_y_list(element_data) - text_center_x = get_text_center_x_list(element_data, text_actual_width) - element_data["text_center_y"] = text_center_y - element_data["text_center_x"] = text_center_x - element_data["text_actual_width"] = text_actual_width - element_data["text_local_img_emb"] = get_text_local_img_emb_list( - element_data, bg_img, text_center_y, text_center_x - ) - element_data["order_list"] = get_orderlist(text_center_y, text_center_x) - return element_data - - -def map_features( - _data_dir: str, -): - fn = os.path.join(_data_dir, "svgid2scaleinfo.pkl") - global svgid2scaleinfo - global processor - global text_tokenizer - global model - global device - global data_dir - global font2ttf - global fontmgr - global fontlabel2fontname - - dataset = datasets.load_dataset("cyberagent/crello", revision="3.1") - - data_dir = _data_dir - svgid2scaleinfo = pickle.load(open(fn, "rb")) - processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") - text_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32") - model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model.to(device) - - fn = os.path.join(data_dir, "font2ttf.pkl") - _font2ttf = pickle.load(open(fn, "rb")) - font2ttf = {} - for key in _font2ttf.keys(): - tmp = _font2ttf[key].split("/data/dataset/crello/")[1] - fn = os.path.join(data_dir, tmp) - font2ttf[key] = fn - font2ttf = font2ttf - fontmgr = skia.FontMgr() - fontlabel2fontname = dataset["train"].features["font"].feature.int2str - - dataset_new = {} - for dataset_division in ["train", "validation", "test"]: - logger.info(f"{dataset_division=}") - _dataset = dataset[dataset_division] - dataset_new[dataset_division] = _dataset.map(add_features) - dataset_new = datasets.DatasetDict(dataset_new) - dataset_new.save_to_disk(f"{_data_dir}/crello_map_features") diff --git a/src/typography_generation/tools/__init__.py b/src/typography_generation/tools/__init__.py deleted file mode 100644 index 8b13789..0000000 --- a/src/typography_generation/tools/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/typography_generation/tools/color_func.py b/src/typography_generation/tools/color_func.py deleted file mode 100644 index 2675bbf..0000000 --- a/src/typography_generation/tools/color_func.py +++ /dev/null @@ -1,457 +0,0 @@ -from typing import Any, Tuple -import numpy as np - -_illuminants = { - "A": { - "2": (1.098466069456375, 1, 0.3558228003436005), - "10": (1.111420406956693, 1, 0.3519978321919493), - "R": (1.098466069456375, 1, 0.3558228003436005), - }, - "B": { - "2": (0.9909274480248003, 1, 0.8531327322886154), - "10": (0.9917777147717607, 1, 
0.8434930535866175), - "R": (0.9909274480248003, 1, 0.8531327322886154), - }, - "C": { - "2": (0.980705971659919, 1, 1.1822494939271255), - "10": (0.9728569189782166, 1, 1.1614480488951577), - "R": (0.980705971659919, 1, 1.1822494939271255), - }, - "D50": { - "2": (0.9642119944211994, 1, 0.8251882845188288), - "10": (0.9672062750333777, 1, 0.8142801513128616), - "R": (0.9639501491621826, 1, 0.8241280285499208), - }, - "D55": { - "2": (0.956797052643698, 1, 0.9214805860173273), - "10": (0.9579665682254781, 1, 0.9092525159847462), - "R": (0.9565317453467969, 1, 0.9202554587037198), - }, - "D65": { - "2": (0.95047, 1.0, 1.08883), # This was: `lab_ref_white` - "10": (0.94809667673716, 1, 1.0730513595166162), - "R": (0.9532057125493769, 1, 1.0853843816469158), - }, - "D75": { - "2": (0.9497220898840717, 1, 1.226393520724154), - "10": (0.9441713925645873, 1, 1.2064272211720228), - "R": (0.9497220898840717, 1, 1.226393520724154), - }, - "E": {"2": (1.0, 1.0, 1.0), "10": (1.0, 1.0, 1.0), "R": (1.0, 1.0, 1.0)}, -} -xyz_from_rgb = np.array( - [ - [0.412453, 0.357580, 0.180423], - [0.212671, 0.715160, 0.072169], - [0.019334, 0.119193, 0.950227], - ] -) - - -def xyz_tristimulus_values(illuminant: str, observer: str, dtype: Any) -> np.array: - """Get the CIE XYZ tristimulus values. - - Given an illuminant and observer, this function returns the CIE XYZ tristimulus - values [2]_ scaled such that :math:`Y = 1`. - - Parameters - ---------- - illuminant : {"A", "B", "C", "D50", "D55", "D65", "D75", "E"} - The name of the illuminant (the function is NOT case sensitive). - observer : {"2", "10", "R"} - One of: 2-degree observer, 10-degree observer, or 'R' observer as in - R function ``grDevices::convertColor`` [3]_. - dtype: dtype, optional - Output data type. - - Returns - ------- - values : array - Array with 3 elements :math:`X, Y, Z` containing the CIE XYZ tristimulus values - of the given illuminant. - - Raises - ------ - ValueError - If either the illuminant or the observer angle are not supported or - unknown. - - References - ---------- - .. [1] https://en.wikipedia.org/wiki/Standard_illuminant#White_points_of_standard_illuminants - .. [2] https://en.wikipedia.org/wiki/CIE_1931_color_space#Meaning_of_X,_Y_and_Z - .. [3] https://www.rdocumentation.org/packages/grDevices/versions/3.6.2/topics/convertColor - - Notes - ----- - The CIE XYZ tristimulus values are calculated from :math:`x, y` [1]_, using the - formula - - .. math:: X = x / y - - .. math:: Y = 1 - - .. math:: Z = (1 - x - y) / y - - The only exception is the illuminant "D65" with aperture angle 2° for - backward-compatibility reasons. - - Examples - -------- - Get the CIE XYZ tristimulus values for a "D65" illuminant for a 10 degree field of - view - - >>> xyz_tristimulus_values(illuminant="D65", observer="10") - array([0.94809668, 1. , 1.07305136]) - """ - illuminant = illuminant.upper() - observer = observer.upper() - try: - return np.asarray(_illuminants[illuminant][observer], dtype=dtype) - except KeyError: - raise ValueError( - f"Unknown illuminant/observer combination " - f"(`{illuminant}`, `{observer}`)" - ) - - -def xyz2lab( - xyz: np.array, illuminant: str = "D65", observer: str = "2", channel_axis: int = -1 -) -> np.array: - """XYZ to CIE-LAB color space conversion. - - Parameters - ---------- - xyz : (..., 3, ...) array_like - The image in XYZ format. By default, the final dimension denotes - channels. 
- illuminant : {"A", "B", "C", "D50", "D55", "D65", "D75", "E"}, optional - The name of the illuminant (the function is NOT case sensitive). - observer : {"2", "10", "R"}, optional - One of: 2-degree observer, 10-degree observer, or 'R' observer as in - R function grDevices::convertColor. - channel_axis : int, optional - This parameter indicates which axis of the array corresponds to - channels. - - .. versionadded:: 0.19 - ``channel_axis`` was added in 0.19. - - Returns - ------- - out : (..., 3, ...) ndarray - The image in CIE-LAB format. Same dimensions as input. - - Raises - ------ - ValueError - If `xyz` is not at least 2-D with shape (..., 3, ...). - ValueError - If either the illuminant or the observer angle is unsupported or - unknown. - - Notes - ----- - By default Observer="2", Illuminant="D65". CIE XYZ tristimulus values - x_ref=95.047, y_ref=100., z_ref=108.883. See function - :func:`~.xyz_tristimulus_values` for a list of supported illuminants. - - References - ---------- - .. [1] http://www.easyrgb.com/en/math.php - .. [2] https://en.wikipedia.org/wiki/CIELAB_color_space - - Examples - -------- - >>> from skimage import data - >>> from skimage.color import rgb2xyz, xyz2lab - >>> img = data.astronaut() - >>> img_xyz = rgb2xyz(img) - >>> img_lab = xyz2lab(img_xyz) - """ - arr = np.array(xyz) - - xyz_ref_white = xyz_tristimulus_values( - illuminant=illuminant, observer=observer, dtype=arr.dtype - ) - - # scale by CIE XYZ tristimulus values of the reference white point - arr = arr / xyz_ref_white - - # Nonlinear distortion and linear transformation - mask = arr > 0.008856 - arr[mask] = np.cbrt(arr[mask]) - arr[~mask] = 7.787 * arr[~mask] + 16.0 / 116.0 - - x, y, z = arr[..., 0], arr[..., 1], arr[..., 2] - - # Vector scaling - L = (116.0 * y) - 16.0 - a = 500.0 * (x - y) - b = 200.0 * (y - z) - - return np.concatenate([x[..., np.newaxis] for x in [L, a, b]], axis=-1) - - -def rgb2xyz(rgb: np.array, channel_axis: int = -1) -> np.array: - """RGB to XYZ color space conversion. - - Parameters - ---------- - rgb : (..., 3, ...) array_like - The image in RGB format. By default, the final dimension denotes - channels. - channel_axis : int, optional - This parameter indicates which axis of the array corresponds to - channels. - - .. versionadded:: 0.19 - ``channel_axis`` was added in 0.19. - - Returns - ------- - out : (..., 3, ...) ndarray - The image in XYZ format. Same dimensions as input. - - Raises - ------ - ValueError - If `rgb` is not at least 2-D with shape (..., 3, ...). - - Notes - ----- - The CIE XYZ color space is derived from the CIE RGB color space. Note - however that this function converts from sRGB. - - References - ---------- - .. [1] https://en.wikipedia.org/wiki/CIE_1931_color_space - - Examples - -------- - >>> from skimage import data - >>> img = data.astronaut() - >>> img_xyz = rgb2xyz(img) - """ - # Follow the algorithm from http://www.easyrgb.com/index.php - # except we don't multiply/divide by 100 in the conversion - arr = np.array(rgb) - mask = arr > 0.04045 - arr[mask] = np.power((arr[mask] + 0.055) / 1.055, 2.4) - arr[~mask] /= 12.92 - return arr @ xyz_from_rgb.T.astype(arr.dtype) - - -def rgb2lab( - rgb: np.array, - illuminant: str = "D65", - observer: str = "2", - channel_axis: int = -1, -) -> np.array: - """Conversion from the sRGB color space (IEC 61966-2-1:1999) - to the CIE Lab colorspace under the given illuminant and observer. - - Parameters - ---------- - rgb : (..., 3, ...) array_like - The image in RGB format. 
By default, the final dimension denotes - channels. - illuminant : {"A", "B", "C", "D50", "D55", "D65", "D75", "E"}, optional - The name of the illuminant (the function is NOT case sensitive). - observer : {"2", "10", "R"}, optional - The aperture angle of the observer. - channel_axis : int, optional - This parameter indicates which axis of the array corresponds to - channels. - - .. versionadded:: 0.19 - ``channel_axis`` was added in 0.19. - - Returns - ------- - out : (..., 3, ...) ndarray - The image in Lab format. Same dimensions as input. - - Raises - ------ - ValueError - If `rgb` is not at least 2-D with shape (..., 3, ...). - - Notes - ----- - RGB is a device-dependent color space so, if you use this function, be - sure that the image you are analyzing has been mapped to the sRGB color - space. - - This function uses rgb2xyz and xyz2lab. - By default Observer="2", Illuminant="D65". CIE XYZ tristimulus values - x_ref=95.047, y_ref=100., z_ref=108.883. See function - :func:`~.xyz_tristimulus_values` for a list of supported illuminants. - - References - ---------- - .. [1] https://en.wikipedia.org/wiki/Standard_illuminant - """ - return xyz2lab(rgb2xyz(rgb), illuminant, observer) - - -def _cart2polar_2pi(x: np.array, y: np.array) -> Tuple[np.array, np.array]: - """convert cartesian coordinates to polar (uses non-standard theta range!) - - NON-STANDARD RANGE! Maps to ``(0, 2*pi)`` rather than usual ``(-pi, +pi)`` - """ - r, t = np.hypot(x, y), np.arctan2(y, x) - t += np.where(t < 0.0, 2 * np.pi, 0) - return r, t - - -def _float_inputs( - lab1: np.array, lab2: np.array, allow_float32: bool = True -) -> Tuple[np.array, np.array]: - lab1 = np.asarray(lab1) - lab2 = np.asarray(lab2) - float_dtype = np.float64 - lab1 = lab1.astype(float_dtype, copy=False) - lab2 = lab2.astype(float_dtype, copy=False) - return lab1, lab2 - - -def deltaE_ciede2000( - lab1: np.array, - lab2: np.array, - kL: int = 1, - kC: int = 1, - kH: int = 1, - channel_axis: int = -1, -) -> np.array: - """Color difference as given by the CIEDE 2000 standard. - - CIEDE 2000 is a major revision of CIE94. The perceptual calibration is - largely based on experience with automotive paint on smooth surfaces. - - Parameters - ---------- - lab1 : array_like - reference color (Lab colorspace) - lab2 : array_like - comparison color (Lab colorspace) - kL : float (range), optional - lightness scale factor, 1 for "acceptably close"; 2 for "imperceptible" - see deltaE_cmc - kC : float (range), optional - chroma scale factor, usually 1 - kH : float (range), optional - hue scale factor, usually 1 - channel_axis : int, optional - This parameter indicates which axis of the arrays corresponds to - channels. - - .. versionadded:: 0.19 - ``channel_axis`` was added in 0.19. - - Returns - ------- - deltaE : array_like - The distance between `lab1` and `lab2` - - Notes - ----- - CIEDE 2000 assumes parametric weighting factors for the lightness, chroma, - and hue (`kL`, `kC`, `kH` respectively). These default to 1. - - References - ---------- - .. [1] https://en.wikipedia.org/wiki/Color_difference - .. [2] http://www.ece.rochester.edu/~gsharma/ciede2000/ciede2000noteCRNA.pdf - :DOI:`10.1364/AO.33.008069` - .. [3] M. Melgosa, J. Quesada, and E. Hita, "Uniformity of some recent - color metrics tested with an accurate color-difference tolerance - dataset," Appl. Opt. 33, 8069-8077 (1994). 
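A short Examples section for the `deltaE_ciede2000` docstring above; the check relies only on an identity guaranteed by the formula itself (two identical Lab triples are at distance zero, with 1-D inputs unrolled to a scalar as in the function body):

 - - Examples - -------- - >>> float(deltaE_ciede2000([50.0, 10.0, 10.0], [50.0, 10.0, 10.0])) - 0.0 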
- """ - lab1, lab2 = _float_inputs(lab1, lab2, allow_float32=True) - - channel_axis = channel_axis % lab1.ndim - unroll = False - if lab1.ndim == 1 and lab2.ndim == 1: - unroll = True - if lab1.ndim == 1: - lab1 = lab1[None, :] - if lab2.ndim == 1: - lab2 = lab2[None, :] - channel_axis += 1 - L1, a1, b1 = np.moveaxis(lab1, source=channel_axis, destination=0)[:3] - L2, a2, b2 = np.moveaxis(lab2, source=channel_axis, destination=0)[:3] - - # distort `a` based on average chroma - # then convert to lch coordinates from distorted `a` - # all subsequence calculations are in the new coordinates - # (often denoted "prime" in the literature) - Cbar = 0.5 * (np.hypot(a1, b1) + np.hypot(a2, b2)) - c7 = Cbar**7 - G = 0.5 * (1 - np.sqrt(c7 / (c7 + 25**7))) - scale = 1 + G - C1, h1 = _cart2polar_2pi(a1 * scale, b1) - C2, h2 = _cart2polar_2pi(a2 * scale, b2) - # recall that c, h are polar coordinates. c==r, h==theta - - # cide2000 has four terms to delta_e: - # 1) Luminance term - # 2) Hue term - # 3) Chroma term - # 4) hue Rotation term - - # lightness term - Lbar = 0.5 * (L1 + L2) - tmp = (Lbar - 50) ** 2 - SL = 1 + 0.015 * tmp / np.sqrt(20 + tmp) - L_term = (L2 - L1) / (kL * SL) - - # chroma term - Cbar = 0.5 * (C1 + C2) # new coordinates - SC = 1 + 0.045 * Cbar - C_term = (C2 - C1) / (kC * SC) - - # hue term - h_diff = h2 - h1 - h_sum = h1 + h2 - CC = C1 * C2 - - dH = h_diff.copy() - dH[h_diff > np.pi] -= 2 * np.pi - dH[h_diff < -np.pi] += 2 * np.pi - dH[CC == 0.0] = 0.0 # if r == 0, dtheta == 0 - dH_term = 2 * np.sqrt(CC) * np.sin(dH / 2) - - Hbar = h_sum.copy() - mask = np.logical_and(CC != 0.0, np.abs(h_diff) > np.pi) - Hbar[mask * (h_sum < 2 * np.pi)] += 2 * np.pi - Hbar[mask * (h_sum >= 2 * np.pi)] -= 2 * np.pi - Hbar[CC == 0.0] *= 2 - Hbar *= 0.5 - - T = ( - 1 - - 0.17 * np.cos(Hbar - np.deg2rad(30)) - + 0.24 * np.cos(2 * Hbar) - + 0.32 * np.cos(3 * Hbar + np.deg2rad(6)) - - 0.20 * np.cos(4 * Hbar - np.deg2rad(63)) - ) - SH = 1 + 0.015 * Cbar * T - - H_term = dH_term / (kH * SH) - - # hue rotation - c7 = Cbar**7 - Rc = 2 * np.sqrt(c7 / (c7 + 25**7)) - dtheta = np.deg2rad(30) * np.exp(-(((np.rad2deg(Hbar) - 275) / 25) ** 2)) - R_term = -np.sin(2 * dtheta) * Rc * C_term * H_term - - # put it all together - dE2 = L_term**2 - dE2 += C_term**2 - dE2 += H_term**2 - dE2 += R_term - ans = np.sqrt(np.maximum(dE2, 0)) - if unroll: - ans = ans[0] - return ans diff --git a/src/typography_generation/tools/denormalizer.py b/src/typography_generation/tools/denormalizer.py deleted file mode 100644 index d0d06d2..0000000 --- a/src/typography_generation/tools/denormalizer.py +++ /dev/null @@ -1,154 +0,0 @@ -from typing import Dict, List, Tuple, Union - -import numpy as np -from typography_generation.io.crello_util import CrelloProcessor -from typography_generation.io.data_object import DesignContext -from logzero import logger - - -class Denormalizer: - def __init__(self, dataset: CrelloProcessor): - self.dataset = dataset - - def denormalize( - self, - prefix: str, - text_num: int, - prediction: np.array, - design_context: DesignContext, - ) -> Tuple: - pred = prediction[prefix] - gt = design_context.get_data(prefix) - pred_token = [] - gt_token = [] - pred_denorm = [] - gt_denorm = [] - canvas_img_size_h, canvas_img_size_w = design_context.canvas_context.img_size - canvas_h_scale, canvas_w_scale = design_context.canvas_context.scale_box - for t in range(text_num): - g = gt[t] - if prefix in self.dataset.tokenizer.prefix_list: - p_token = pred[t] - p = self.dataset.tokenizer.detokenize(prefix, p_token) - 
g_token = self.dataset.tokenizer.tokenize(prefix, g) - elif prefix in self.dataset.tokenizer.rawdata_list: - p = pred[t] - p_token = getattr(self.dataset, f"raw2token_{prefix}")(p) - g_token = getattr(self.dataset, f"raw2token_{prefix}")(g) - if self.dataset.tokenizer.rawdata_out_format[prefix] == "token": - p = p_token - g = g_token - else: - p = pred[t] - else: - p_token = pred[t] - p = p_token - g_token = g - p = self.denormalize_elm( - prefix, - p, - canvas_img_size_h, - canvas_img_size_w, - canvas_h_scale, - canvas_w_scale, - ) - g = self.denormalize_elm( - prefix, - g, - canvas_img_size_h, - canvas_img_size_w, - canvas_h_scale, - canvas_w_scale, - ) - pred_token.append([p_token]) - gt_token.append(g_token) - pred_denorm.append([p]) - gt_denorm.append(g) - return pred_token, gt_token, pred_denorm, gt_denorm - - def denormalize_gt( - self, - prefix: str, - text_num: int, - design_context: DesignContext, - ) -> Tuple: - gt = design_context.get_data(prefix) - gt_token = [] - gt_denorm = [] - canvas_img_size_h, canvas_img_size_w = design_context.canvas_context.img_size - canvas_h_scale, canvas_w_scale = design_context.canvas_context.scale_box - logger.debug(f"{prefix} {gt}") - for t in range(text_num): - g = gt[t] - if prefix in self.dataset.tokenizer.prefix_list: - g_token = self.dataset.tokenizer.tokenize(prefix, g) - elif prefix in self.dataset.tokenizer.rawdata_list: - g_token = getattr(self.dataset, f"raw2token_{prefix}")(g) - if self.dataset.tokenizer.rawdata_out_format[prefix] == "token": - g = g_token - else: - g_token = g - g = self.denormalize_elm( - prefix, - g, - canvas_img_size_h, - canvas_img_size_w, - canvas_h_scale, - canvas_w_scale, - ) - gt_token.append(g_token) - gt_denorm.append(g) - return gt_token, gt_denorm - - def denormalize_elm( - self, - prefix: str, - val: Union[int, float, Tuple], - canvas_img_size_h: int, - canvas_img_size_w: int, - canvas_h_scale: float, - canvas_w_scale: float, - ) -> Union[int, float, Tuple]: - if hasattr(self.dataset, f"denorm_{prefix}"): - func = getattr(self.dataset, f"denorm_{prefix}") - data_info = { - "val": val, - "img_height": canvas_img_size_h, - "img_width": canvas_img_size_w, - "scale_h": canvas_h_scale, - "scale_w": canvas_w_scale, - } - val = func(**data_info) - return val - - def convert_attributes( - self, - prefix: str, - pred: List, - element_data: Dict, - text_ids: List, - scale_h: float, - ) -> Dict: - converter = getattr(self.dataset, f"convert_{prefix}") - converted_attributes = [] - for i, text_index in enumerate(text_ids): - inputs = { - "val": pred[i][0], - "element_data": element_data, - "text_index": text_index, - "scale_h": scale_h, - } - _converted_attributes = converter(**inputs) - converted_attributes.append(_converted_attributes) - if len(text_ids) > 0: - converted_attribute_prefixes = list(_converted_attributes.keys()) - converted_attribute_dict = {} - for prefix in converted_attribute_prefixes: - attribute_data = [] - for _converted_attributes in converted_attributes: - val = _converted_attributes[prefix] - attribute_data.append([val]) - converted_attribute_dict[prefix] = attribute_data - return converted_attribute_dict - else: - return None diff --git a/src/typography_generation/tools/evaluator.py b/src/typography_generation/tools/evaluator.py deleted file mode 100644 index 1099221..0000000 --- a/src/typography_generation/tools/evaluator.py +++ /dev/null @@ -1,225 +0,0 @@ -import os -import pickle -import time -from typing import Dict, List - -import torch -import torch.nn as nn -import torch.utils.data 
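For orientation before the evaluator below: `Denormalizer.denormalize` above round-trips attributes between raw values and cluster tokens. A hedged sketch of that round-trip with made-up bin centers (the real cluster files live under `data_dir/cluster/`; the `tokenize`/`detokenize` helpers here are illustrative stand-ins, not the project's API):

    import numpy as np

    bins = np.array([8.0, 12.0, 16.0, 24.0, 32.0])  # illustrative font-size cluster centers

    def tokenize(val: float) -> int:
        # nearest-cluster label for a raw value
        return int(np.argmin(np.abs(bins - val)))

    def detokenize(token: int) -> float:
        # the cluster center stands in for the raw value
        return float(bins[token])

    assert detokenize(tokenize(13.0)) == 12.0  # the round-trip snaps to the nearest bin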
-from logzero import logger -from typography_generation.io.data_loader import CrelloLoader -from typography_generation.io.data_object import ( - DesignContext, - ModelInput, - PrefixListObject, -) -from typography_generation.tools.denormalizer import Denormalizer -from typography_generation.tools.score_func import ( - EvalDataEntire, - EvalDataInstance, -) -from typography_generation.tools.train import collate_batch -from typography_generation.visualization.renderer import TextRenderer -from typography_generation.visualization.visualizer import ( - get_text_ids, -) - -show_classfication_score_att = [ - "text_font", - "text_font_emb", - "text_align_type", - "text_capitalize", -] -show_abs_erros_att = [ - "text_font_size", - "text_font_size_raw", - "text_center_y", - "text_center_x", - "text_letter_spacing", - "text_angle", - "text_line_height_scale", -] - -show_bigram_score_att = [ - "text_font", - "text_font_emb", - "text_font_color", - "text_align_type", - "text_capitalize", - "text_font_size", - "text_font_size_raw", - "text_center_y", - "text_center_x", - "text_angle", - "text_letter_spacing", - "text_line_height_scale", -] - - -############################################################ -# Evaluator -############################################################ -class Evaluator: - def __init__( - self, - model: nn.Module, - gpu: bool, - save_dir: str, - dataset: CrelloLoader, - prefix_list_object: PrefixListObject, - batch_size: int = 1, - num_worker: int = 2, - show_interval: int = 100, - dataset_division: str = "test", - save_file_prefix: str = "score", - debug: bool = False, - ) -> None: - self.gpu = gpu - self.save_dir = save_dir - self.batch_size = batch_size - self.num_worker = num_worker - self.show_interval = show_interval - self.debug = debug - self.prefix_list_target = prefix_list_object.target - self.dataset_division = dataset_division - - self.dataset = dataset - - self.model = model - if gpu is True: - self.model.cuda() - - self.entire_data = EvalDataEntire( - self.prefix_list_target, - save_dir, - save_file_prefix=save_file_prefix, - ) - self.denormalizer = Denormalizer(self.dataset.dataset) - - self.save_data: Dict[str, Dict[str, List]] - self.save_data = dict() - - fontlabel2fontname = dataset.dataset.dataset.features["font"].feature.int2str - - self.renderer = TextRenderer(dataset.data_dir, fontlabel2fontname) - - def eval_model(self) -> None: - # Data generators - dataloader = torch.utils.data.DataLoader( - self.dataset, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.num_worker, - pin_memory=True, - collate_fn=collate_batch, - ) - with torch.no_grad(): - self.steps = len(dataloader) - self.step = 0 - self.cnt = 0 - self.model.eval() - end = time.time() - for inputs in dataloader: - ( - design_context_list, - model_input_batchdata, - svg_id, - index, - ) = inputs - self.eval_step( - design_context_list, - model_input_batchdata, - svg_id, - index, - end, - ) - end = time.time() - self.step += 1 - - self.show_scores() - if self.dataset_division == "test": - self.save_prediction() - - def eval_step( - self, - design_context_list: List[DesignContext], - model_input_batchdata: List, - svg_id: List, - index: List, - end: float, - ) -> None: - start = time.time() - model_inputs = ModelInput(design_context_list, model_input_batchdata, self.gpu) - predictions = self.model.prediction( - model_inputs, self.dataset.dataset, self.prefix_list_target - ) - data_index = svg_id[0] - - self.instance_data = EvalDataInstance(self.prefix_list_target) - - text_num = 
design_context_list[0].canvas_context.canvas_text_num - - self.save_data[data_index] = dict() - for prefix in self.prefix_list_target: - pred_token, gt_token, pred, gt = self.denormalizer.denormalize( - prefix, text_num, predictions, design_context_list[0] - ) - self.instance_data.rigister_att( - text_num, - prefix, - pred_token, - gt_token, - pred, - gt, - ) - self.entire_data.update_prediction_data( - data_index, self.instance_data, f"{prefix}" - ) - self.save_data[data_index][prefix] = pred - if hasattr(self.denormalizer.dataset, f"convert_{prefix}"): - element_data = self.dataset.dataset.dataset[index[0]] - text_ids = get_text_ids(element_data) - canvas_h_scale = design_context_list[0].canvas_context.scale_box[0] - converted_attribute_dict = self.denormalizer.convert_attributes( - prefix, - pred, - element_data, - text_ids, - canvas_h_scale, - ) - if converted_attribute_dict is not None: - for prefix, v in converted_attribute_dict.items(): - self.save_data[data_index][prefix] = v - self.entire_data.text_num[data_index] = text_num - - forward_time = time.time() - if self.step % 200 == 0: - data_show = "{}/{}/{}, forward_time: {:.3f} data {:.3f}".format( - self.cnt, - self.step + 1, - self.steps, - forward_time - start, - (start - end), - ) - logger.info(data_show) - - def show_scores(self) -> None: - for prefix in show_classfication_score_att: - if prefix in self.prefix_list_target: - self.entire_data.show_classification_score( - prefix, topk=5, show_topk=[0, 2, 4] - ) - for prefix in show_abs_erros_att: - if prefix in self.prefix_list_target: - self.entire_data.show_abs_erros(prefix) - if "text_font_color" in self.prefix_list_target: - self.entire_data.show_font_color_scores() - for prefix in show_bigram_score_att: - if prefix in self.prefix_list_target: - self.entire_data.show_structure_score(prefix) - self.entire_data.show_alpha_overlap_score() - - def save_prediction(self) -> None: - file_name = os.path.join(self.save_dir, "prediction.pkl") - with open(file_name, mode="wb") as f: - pickle.dump(self.save_data, f) diff --git a/src/typography_generation/tools/loss.py b/src/typography_generation/tools/loss.py deleted file mode 100644 index 2960fa6..0000000 --- a/src/typography_generation/tools/loss.py +++ /dev/null @@ -1,194 +0,0 @@ -from typing import Dict, List, Tuple, Union - -import numpy as np -import torch -from logzero import logger -from torch import Tensor, nn -from torch.functional import F -from typography_generation.config.attribute_config import ( - TextElementContextPredictionAttributeConfig, -) -from typography_generation.io.data_object import ModelInput -from typography_generation.model.common import Linearx3 - - -class LossFunc(object): - def __init__( - self, - model_name: str, - prefix_list_target: List, - prediction_config_element: TextElementContextPredictionAttributeConfig, - gpu: bool, - topk: int = 5, - ) -> None: - super(LossFunc, self).__init__() - self.prefix_list_target = prefix_list_target - self.topk = topk - for prefix in self.prefix_list_target: - target_prediction_config = getattr(prediction_config_element, prefix) - setattr( - self, f"{prefix}_ignore_label", target_prediction_config.ignore_label - ) - setattr(self, f"{prefix}_loss_type", target_prediction_config.loss_type) - if model_name == "canvasvae": - self.vae_weight = 0 - self.fix_vae_weight = 0.002 - if model_name == "mfc": - self.d = Linearx3(40, 1) - if gpu is True: - self.d.cuda() - self.d_optimizer = torch.optim.AdamW( - self.d.parameters(), - lr=0.0002, - betas=(0.5, 0.999), - 
weight_decay=0.01, - ) - - def get_loss( - self, prefix: str, pred: Tensor, gt: Tensor, loss_type: str, training: bool - ) -> Tensor: - if loss_type == "cre": - loc = gt != getattr(self, f"{prefix}_ignore_label") - pred = pred[loc] - gt = gt[loc] - loss = F.cross_entropy(pred, gt.long()) - elif loss_type == "l1": - loc = gt != getattr(self, f"{prefix}_ignore_label") - pred = pred[loc] - gt = gt[loc] - logger.debug(f"get loss {prefix} {loss_type} {gt}") - loss = F.l1_loss(pred.reshape(gt.shape), gt.float()) - elif loss_type == "mfc_gan": - fake_output = self.d(pred.reshape(gt.shape).detach()) - real_output = self.d(gt) - loc = gt[:, :, 0:1] != getattr(self, f"{prefix}_ignore_label") - if training is True: - real_output = real_output[loc] - Ldadv = ( - fake_output.mean() - - real_output.mean() # + self.lambda_gp * gradient_penalty - ) - self.d_optimizer.zero_grad() - Ldadv.backward() - self.d_optimizer.step() - fake_output = self.d(pred.reshape(gt.shape)) - fake_output = fake_output[loc] - Lsadv = -fake_output.mean() - loc = gt != getattr(self, f"{prefix}_ignore_label") - pred = pred[loc] - gt = gt[loc] - loss = 10 * F.mse_loss(pred.reshape(gt.shape), gt.float()) + Lsadv - else: - raise NotImplementedError() - return loss - - def update_vae_weight(self, epoch: int, epochs: int, step: int, steps: int) -> None: - logger.debug("update vae weight") - self.vae_weight = (epoch + step / steps) / epochs - logger.debug(f"vae weight value {self.vae_weight}") - - def vae_loss(self, mu: Tensor, logsigma: Tensor) -> Tuple: - loss_kl = -0.5 * torch.mean( - 1 + logsigma - mu.pow(2) - torch.exp(logsigma) - ) # vae kl divergence - loss_kl_weighted = loss_kl * self.vae_weight * self.fix_vae_weight - return loss_kl_weighted, loss_kl - - def get_vae_loss(self, vae_items: Tuple) -> Tensor: - loss_kl_weighted = 0 - for mu, logsigma in vae_items: - loss, _ = self.vae_loss(mu, logsigma) - loss_kl_weighted += loss - return loss_kl_weighted # +self.config.VAE_L2_LOSS_WEIGHT*loss_l2 - - def __call__(self, model_inputs: ModelInput, preds: Dict, training: bool) -> Tuple: - total_loss = 0 - record_items = {} - - for prefix in self.prefix_list_target: - pred = preds[prefix] - gt = getattr(model_inputs, prefix) - loss_type = getattr(self, f"{prefix}_loss_type") - loss = self.get_loss( - prefix, - pred, - gt, - loss_type, - training, - ) - pred_label, gt_data = self.get_pred_gt_label(prefix, pred, gt, loss_type) - record_items[prefix] = (pred_label, gt_data, loss.item()) - total_loss += loss - - if "vae_data" in preds.keys(): - logger.debug("compute vae loss") - vae_loss = self.get_vae_loss(preds["vae_data"]) - total_loss = total_loss + vae_loss - logger.debug(f"vae loss {vae_loss}") - return total_loss, record_items - - def get_pred_gt_label( - self, prefix: str, pred: Tensor, gt: Tensor, loss_type: str - ) -> Tuple[List, Union[List, np.array]]: - loc = gt != getattr(self, f"{prefix}_ignore_label") - pred = pred[loc] - gt = gt[loc] - pred_label = self.get_pred_label(pred, loss_type) - gt_label = self.get_gt_label(gt, loss_type) - return pred_label, gt_label - - def get_pred_label( - self, - pred: Tensor, - loss_type: str = "cre", - ) -> List: - if loss_type == "cre": - pred = torch.sort(input=pred, dim=1, descending=True)[1].data.cpu().numpy() - preds = [] - for i in range(len(pred)): - p = [] - for k in range(min(self.topk, pred.shape[1])): - p.append(pred[i, k]) - preds.append(p) - elif loss_type == "l1": - preds = [] - for i in range(len(pred)): - p = [] - for k in range(self.topk): - p.append(0) # dummy data - 
preds.append(p) - elif loss_type == "mfc_gan": - preds = [] - for i in range(len(pred)): - p = [] - for k in range(self.topk): - p.append(0) # dummy data - preds.append(p) - else: - raise NotImplementedError() - return preds - - def get_gt_label( - self, - gt: Tensor, - loss_type: str = "cre", - ) -> Union[List, np.array]: - if loss_type == "cre": - gt_labels = gt.data.cpu().numpy() - elif loss_type == "l1": - gt_labels = [] - for i in range(len(gt)): - g = [] - for k in range(self.topk): - g.append(0) # dummy data - gt_labels.append(g) - elif loss_type == "mfc_gan": - gt_labels = [] - for i in range(len(gt)): - g = [] - for k in range(self.topk): - g.append(0) # dummy data - gt_labels.append(g) - else: - raise NotImplementedError() - return gt_labels diff --git a/src/typography_generation/tools/prediction_recorder.py b/src/typography_generation/tools/prediction_recorder.py deleted file mode 100644 index 9743990..0000000 --- a/src/typography_generation/tools/prediction_recorder.py +++ /dev/null @@ -1,141 +0,0 @@ -from typing import Dict, List -from logzero import logger - -############################################################ -# Prediction Recoder -############################################################ - - -class PredictionRecoder: - def __init__(self, prefix_list_target: List, topk: int = 5): - super(PredictionRecoder, self).__init__() - self.prefix_list_target = prefix_list_target - self.topk = topk - self.show_topk = list(range(self.topk)) - self.register(prefix_list_target) - self.epoch = 0 - - def register(self, prefix_list_target: List) -> None: - self.all_target_list = [] - self.cl_target_list = [] - self.loss_target_list = [] - for prefix in prefix_list_target: - self.all_target_list.append(f"{prefix}_cl") - self.all_target_list.append(f"{prefix}_loss") - self.cl_target_list.append(f"{prefix}_cl") - self.loss_target_list.append(f"{prefix}_loss") - self.regist_cl_recorder_set(self.cl_target_list) - self.regist_recorder_set(self.loss_target_list) - - def regist_recorder_set(self, target_list: List) -> None: - for tar in target_list: - self.regist_record_target(tar) - - def regist_cl_recorder_set(self, target_list: List) -> None: - for tar in target_list: - self.regist_cl_recorder(tar) - - def regist_cl_recorder(self, registration_name: str) -> None: - self.regist_record_target(f"{registration_name}_pred") - self.regist_record_target(f"{registration_name}_gt") - - def regist_record_target(self, registration_name: str) -> None: - setattr(self, registration_name, []) - setattr(self, f"{registration_name}_history", []) - for k in range(len(self.show_topk)): - setattr(self, f"{registration_name}_history_{self.show_topk[k]}", []) - - def __call__(self, recoder_items: Dict) -> None: - for name, (pred, gt, loss) in recoder_items.items(): - clname = f"{name}_cl" - lossname = f"{name}_loss" - if clname in self.cl_target_list: - for p, g in zip(pred, gt): - getattr(self, f"{clname}_pred").append(p) - getattr(self, f"{clname}_gt").append(g) - if lossname in self.loss_target_list: - getattr(self, lossname).append(loss) - - def reset(self) -> None: - for name in self.all_target_list: - if name in self.cl_target_list: - setattr(self, f"{name}_pred", []) - setattr(self, f"{name}_gt", []) - if name in self.loss_target_list: - setattr(self, name, []) - - def store_score(self) -> None: - for name in self.all_target_list: - if name in self.cl_target_list: - for k in range(len(self.show_topk)): - getattr(self, f"{name}_pred_history_{self.show_topk[k]}").append( - getattr(self, 
f"{name}{self.show_topk[k]}_acc") - ) - if name in (self.loss_target_list): - getattr(self, f"{name}_history").append(getattr(self, f"{name}_mean")) - - def compute_score(self) -> None: - for name in self.all_target_list: - if name in self.cl_target_list: - topkacc = {} - for k in range(self.topk): - topkacc[k] = 0 - for p, g in zip( - getattr(self, f"{name}_pred"), getattr(self, f"{name}_gt") - ): - flag = 0 - for k in range(min(self.topk, len(p))): - if p[k] == g: - flag = 1 - topkacc[k] += flag - for k in range(min(self.topk, len(topkacc))): - setattr( - self, - f"{name}{k}_acc", - topkacc[k] / max(len(getattr(self, f"{name}_pred")), 1), - ) - if name in (self.loss_target_list): - mean_loss = 0 - for l in getattr(self, f"{name}"): - mean_loss += l - setattr( - self, - f"{name}_mean", - mean_loss / max(len(getattr(self, f"{name}")), 1), - ) - score_dict = self.show_scores() - return score_dict - - def show_scores(self) -> None: - score_dict = {} - for name in self.all_target_list: - if name in self.cl_target_list: - for k in range(len(self.show_topk)): - logger.info( - f"{name}{self.show_topk[k]}_acc:{getattr(self, f'{name}{self.show_topk[k]}_acc')}" - ) - score_dict[name] = getattr(self, f"{name}{self.show_topk[0]}_acc") - if name in self.loss_target_list: - logger.info(f"{name}_mean:{getattr(self, f'{name}_mean')}") - score_dict[name] = getattr(self, f"{name}_mean") - return score_dict - - def show_history_scores(self) -> None: - for i in range(self.epoch): - logger.info(f"epoch {i}") - for name in self.all_target_list: - if name in self.cl_target_list: - for k in range(len(self.show_topk)): - logger.info( - f"{name}{self.show_topk[k]}acc:{getattr(self, f'{name}_pred_history_{self.show_topk[k]}')[i]}" - ) - if name in self.loss_target_list: - logger.info(f"{name}_mean:{getattr(self, f'{name}_history')[i]}") - - def step_epoch(self) -> None: - self.compute_score() - self.store_score() - self.update_epoch() - - def update_epoch(self) -> None: - self.epoch += 1 diff --git a/src/typography_generation/tools/sampler.py b/src/typography_generation/tools/sampler.py deleted file mode 100644 index 1808e32..0000000 --- a/src/typography_generation/tools/sampler.py +++ /dev/null @@ -1,150 +0,0 @@ -import time -from typing import List - -import torch -import torch.nn as nn -import torch.utils.data -from typography_generation.io.data_loader import CrelloLoader -from typography_generation.tools.score_func import EvalDataInstance -from typography_generation.io.data_object import ( - DataPreprocessConfig, - DesignContext, - FontConfig, - ModelInput, - PrefixListObject, - SamplingConfig, -) -from logzero import logger - -from typography_generation.tools.train import collate_batch - -from typography_generation.tools.evaluator import ( - Evaluator, -) - - -############################################################ -# Sampler -############################################################ -class Sampler(Evaluator): - def __init__( - self, - model: nn.Module, - gpu: bool, - save_dir: str, - dataset: CrelloLoader, - prefix_list_object: PrefixListObject, - sampling_config: SamplingConfig, - batch_size: int = 1, - num_worker: int = 2, - show_interval: int = 100, - dataset_division: str = "test", - debug: bool = False, - ) -> None: - super().__init__( - model, - gpu, - save_dir, - dataset, - prefix_list_object, - batch_size=batch_size, - num_worker=num_worker, - show_interval=show_interval, - dataset_division=dataset_division, - debug=debug, - ) - self.sampling_config = sampling_config - - def sample(self) -> 
None: - # Data generators - dataloader = torch.utils.data.DataLoader( - self.dataset, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.num_worker, - pin_memory=True, - collate_fn=collate_batch, - ) - with torch.no_grad(): - self.steps = len(dataloader) - self.step = 0 - self.cnt = 0 - self.model.eval() - end = time.time() - for inputs in dataloader: - design_context_list, model_input_batchdata, svg_id, _ = inputs - self.sample_step( - design_context_list, model_input_batchdata, svg_id, end - ) - end = time.time() - self.step += 1 - - self.show_scores() - self.entire_data.show_diversity_scores(self.prefix_list_target) - self.save_prediction() - - def sample_step( - self, - design_context_list: List[DesignContext], - model_input_batchdata: List, - svg_id: List, - end: float, - ) -> None: - _data_index = svg_id[0] - self.entire_data.data_index_list.append(_data_index) - for iter in range(self.sampling_config.sampling_num): - data_index = f"{_data_index}_{iter}" - self.sample_iter( - design_context_list, model_input_batchdata, data_index, end - ) - - def sample_iter( - self, - design_context_list: List[DesignContext], - model_input_batchdata: List, - data_index: str, - end: float, - ) -> None: - start = time.time() - model_inputs = ModelInput(design_context_list, model_input_batchdata, self.gpu) - sampler_input = { - "model_inputs": model_inputs, - "target_prefix_list": self.prefix_list_target, - "sampling_param_geometry": self.sampling_config.sampling_param_geometry, - "sampling_param_semantic": self.sampling_config.sampling_param_semantic, - } - predictions = self.model.sample(**sampler_input) - - self.instance_data = EvalDataInstance(self.prefix_list_target) - - text_num = design_context_list[0].canvas_context.canvas_text_num - - self.save_data[data_index] = dict() - for prefix in self.prefix_list_target: - pred_token, gt_token, pred, gt = self.denormalizer.denormalize( - prefix, text_num, predictions, design_context_list[0] - ) - self.instance_data.rigister_att( - text_num, - prefix, - pred_token, - gt_token, - pred, - gt, - ) - self.entire_data.update_prediction_data( - data_index, self.instance_data, f"{prefix}" - ) - self.save_data[data_index][prefix] = pred - self.entire_data.text_num[data_index] = text_num - - forward_time = time.time() - if self.step % 200 == 0: - data_show = "{}/{}/{}, forward_time: {:.3f} data {:.3f}".format( - self.cnt, - self.step + 1, - self.steps, - forward_time - start, - (start - end), - ) - logger.info(data_show) diff --git a/src/typography_generation/tools/score_func.py b/src/typography_generation/tools/score_func.py deleted file mode 100644 index f8a1709..0000000 --- a/src/typography_generation/tools/score_func.py +++ /dev/null @@ -1,657 +0,0 @@ -import itertools -from typing import Dict, List, Tuple - -import numpy as np -from _io import TextIOWrapper -from logzero import logger as log -from typography_generation.tools.color_func import deltaE_ciede2000, rgb2lab - - -class EvalDataInstance: - def __init__( - self, - attribute_list: List, - ) -> None: - self.attribute_list = attribute_list - self.target_list = ["pred_token", "gt_token", "pred", "gt"] - self.reset() - - def reset(self) -> None: - for att in self.attribute_list: - for tar in self.target_list: - registration_name = f"{att}_{tar}" - setattr(self, registration_name, []) - - def rigister_att( - self, - text_num: int, - prefix: str, - target_pred_token: np.array, - target_gt_token: np.array, - target_pred: np.array, - target_gt: np.array, - start_index: int = 0, - ) -> None: - 
for i in range(start_index, text_num): - registration_name = f"{prefix}_pred_token" - getattr(self, registration_name).append(target_pred_token[i]) - registration_name = f"{prefix}_gt_token" - getattr(self, registration_name).append(target_gt_token[i]) - registration_name = f"{prefix}_pred" - getattr(self, registration_name).append(target_pred[i]) - registration_name = f"{prefix}_gt" - getattr(self, registration_name).append(target_gt[i]) - - -class EvalDataEntire: - def __init__( - self, attribute_list: List, save_dir: str, save_file_prefix: str = "score" - ) -> None: - self.attribute_list = attribute_list - self.text_num = {} - self.overlap_scores = {} - self.data_index_list = [] - self.target_list = ["pred_token", "gt_token", "pred", "gt"] - for att in self.attribute_list: - for tar in self.target_list: - registration_name = f"{att}_{tar}" - setattr(self, registration_name, {}) - - save_file = f"{save_dir}/{save_file_prefix}.txt" - self.f = open(save_file, "w") - - def update_prediction_data( - self, - index: str, - instance_obj: EvalDataInstance, - prefix: str, - ) -> None: - getattr(self, f"{prefix}_pred_token")[index] = getattr( - instance_obj, f"{prefix}_pred_token" - ) - getattr(self, f"{prefix}_gt_token")[index] = getattr( - instance_obj, f"{prefix}_gt_token" - ) - getattr(self, f"{prefix}_pred")[index] = getattr(instance_obj, f"{prefix}_pred") - getattr(self, f"{prefix}_gt")[index] = getattr(instance_obj, f"{prefix}_gt") - - def update_sampling_data( - self, - index: str, - instance_obj: EvalDataInstance, - prefix: str, - ) -> None: - primary_index, sub_index = index.split("_") - if int(sub_index) > 0: - getattr(self, prefix)[primary_index].append(getattr(instance_obj, prefix)) - else: - getattr(self, prefix)[primary_index] = [] - getattr(self, prefix)[primary_index].append(getattr(instance_obj, prefix)) - - def show_classification_score( - self, att: str, topk: int = 10, show_topk: List = [0, 5] - ) -> None: - log.info(f"{att}") - self.f.write(f"{att} \n") - compute_score( - getattr(self, f"{att}_pred"), - getattr(self, f"{att}_gt"), - topk, - show_topk, - self.f, - ) - - def show_abs_erros( - self, prefix: str, blanktype: str = "", topk: int = 10, show_topk: List = [0, 5] - ) -> None: - log.info(f"{prefix}") - self.f.write(f"{prefix} \n") - compute_abs_error_score( - getattr(self, f"{prefix}_pred"), - getattr(self, f"{prefix}_gt"), - self.f, - ) - - def show_font_color_scores( - self, blanktype: str = "", topk: int = 10, show_topk: List = [0, 5] - ) -> None: - log.info(f"font_color") - self.f.write(f"font_color \n") - compute_color_score( - getattr(self, f"text_font_color_pred"), - getattr(self, f"text_font_color_gt"), - self.f, - ) - - def show_structure_score(self, att: str) -> None: - log.info(f"{att}") - self.f.write(f"{att} \n") - compute_bigram_score( - getattr(self, f"text_num"), - getattr(self, f"{att}_pred_token"), - getattr(self, f"{att}_gt_token"), - self.f, - ) - - def show_visual_similarity_scores(self) -> None: - registration_name = "l2error" - dict_average_score(registration_name, getattr(self, registration_name)) - registration_name = "psnr" - dict_average_score(registration_name, getattr(self, registration_name)) - - def show_time_score(self) -> None: - registration_name = "time" - dict_average_score(registration_name, getattr(self, registration_name)) - - def show_diversity_scores(self, attribute_list: List) -> None: - for att in attribute_list: - log.info(f"{att}") - self.f.write(f"{att} \n") - compute_label_diversity_score( - getattr(self, 
"data_index_list"), - getattr(self, f"text_num"), - getattr(self, f"{att}_pred_token"), - self.f, - ) - - def show_alpha_overlap_score(self) -> None: - compute_alpha_overlap( - getattr(self, "overlap_scores"), - self.f, - ) - - -def compute_abs_error_score( - eval_list: dict, - gt_list: dict, - f: TextIOWrapper, -) -> None: - cnt_img = 0 - l1_distance = 0.0 - r_l1_distance = 0.0 - for index in eval_list.keys(): - g = gt_list[index] - if len(g) > 0: - e = eval_list[index] - d = 0.0 - rd = 0.0 - for i, (pred, gt) in enumerate(zip(e, g)): - pred = pred[0] - _d = abs(gt - pred) - _rd = abs(gt - pred) / max(abs(float(gt)), 1e-5) - d += _d - rd += _rd - l1_distance += d / len(g) - r_l1_distance += rd / len(g) - cnt_img += 1 - l1_distance /= cnt_img - log.info(f"l1_distance {l1_distance}") - f.write(f"l1_distance {l1_distance} \n") - r_l1_distance /= cnt_img - log.info(f"r_l1_distance {r_l1_distance}") - f.write(f"r_l1_distance {r_l1_distance} \n") - - -def compute_color_score( - font_color_eval_list: dict, - font_color_gt_list: dict, - f: TextIOWrapper, -) -> None: - cnt_img = 0 - color_distance = 0.0 - for index in font_color_eval_list.keys(): - e = font_color_eval_list[index] - g = font_color_gt_list[index] - if len(g) > 0: - d = 0.0 - for i, (pred, gt) in enumerate(zip(e, g)): - pred = pred[0] - lab_p = rgb2lab(np.array(pred).reshape(1, 1, 3).astype(np.float32)) - lab_g = rgb2lab(np.array(gt).reshape(1, 1, 3).astype(np.float32)) - _d = deltaE_ciede2000(lab_p, lab_g) - d += _d[0][0] - color_distance += d / len(g) - cnt_img += 1 - color_distance /= cnt_img - log.info(f"color_distance {color_distance}") - f.write(f"color_distance {color_distance} \n") - - -def compute_score( - eval_list: dict, - gt_list: dict, - topk: int, - show_topk: List, - f: TextIOWrapper, -) -> Tuple: - cnt_elm = 0 - cnt_img = 0 - topk_acc_elm = {} - topk_acc_img = {} - for k in range(topk): - topk_acc_elm[k] = 0.0 - topk_acc_img[k] = 0.0 - for index in eval_list.keys(): - e = eval_list[index] - g = gt_list[index] - topk_acc_tmp = {} - for k in range(topk): - topk_acc_tmp[k] = 0.0 - if len(g) > 0: - cnt_gt = 0 - for i, gt in enumerate(g): - flag = 0 - for k in range(min(topk, len(e[i]))): - if int(gt) == int(e[i][k]): - flag = 1.0 - topk_acc_elm[k] += flag - topk_acc_tmp[k] += flag - cnt_gt += 1 - cnt_elm += 1 - for k in range(topk): - if cnt_gt > 0: - topk_acc_img[k] += topk_acc_tmp[k] / cnt_gt - cnt_img += 1 - for k in range(topk): - topk_acc_elm[k] /= cnt_elm - topk_acc_img[k] /= cnt_img - if k in show_topk: - log.info(f"top{k} img_level_acc {topk_acc_img[k]}") - f.write(f"top{k} img_level_acc {topk_acc_img[k]} \n") - return topk_acc_elm, topk_acc_img - - -def dict_average_score(prefix: str, score_dict: dict) -> None: - score_mean = 0 - cnt = 0 - for index in score_dict.keys(): - score_mean += score_dict[index] - cnt += 1 - log.info("{} {}".format(prefix, score_mean / cnt)) - - -def compute_unigram_label_score( - text_num: int, - text_target_mask: np.array, - font_pred: np.array, - font_gt: np.array, - ignore_labels: List[int], -) -> float: - cnt = 0 - correct_cnt = 0 - for i in range(text_num): - fi = font_pred[i] - if fi in ignore_labels or text_target_mask[i] == 0: - continue - else: - cnt += 1 - for j in range(text_num): - fj = int(font_gt[j]) - if fi == fj: - correct_cnt += 1 - break - if cnt > 0: - score = float(correct_cnt) / cnt - else: - score = 0 - return score - - -def compute_bigram_label_score( - pred: np.array, - gt: np.array, -) -> float: - text_num = len(gt) - text_cmb = 
list(itertools.combinations(list(range(text_num)), 2)) - cnt = 0 - correct_cnt = 0 - for pi, pj in text_cmb: - fpi = pred[pi][0] - fpj = pred[pj][0] - cnt += 1 - for gi, gj in text_cmb: - fgi = int(gt[gi]) - fgj = int(gt[gj]) - if (fpi == fgi) and (fpj == fgj): - correct_cnt += 1 - break - elif (fpi == fgj) and (fpj == fgi): - correct_cnt += 1 - break - if cnt > 0: - score = float(correct_cnt) / cnt - else: - score = 0 - return score - - -def get_binary_classification_scores( - l11cnt: int, l00cnt: int, l10cnt: int, l01cnt: int -) -> Tuple: - if l11cnt + l10cnt > 0: - precision = float(l11cnt) / (l11cnt + l10cnt) - else: - precision = 0 - - if l11cnt + l01cnt > 0: - recall = float(l11cnt) / (l11cnt + l01cnt) - else: - recall = 0 - - if l00cnt + l01cnt > 0: - precision_inv = float(l00cnt) / (l00cnt + l01cnt) - else: - precision_inv = 0 - - if l00cnt + l10cnt > 0: - recall_inv = float(l00cnt) / (l00cnt + l10cnt) - else: - recall_inv = 0 - - if precision + recall > 0: - fvalue = 2 * precision * recall / (precision + recall) - else: - fvalue = 0 - - if 2 * l11cnt + l01cnt + l10cnt == 0: - _fvalue = np.nan - else: - _fvalue = 2 * l11cnt / (2 * l11cnt + l01cnt + l10cnt) - - if precision_inv + recall_inv > 0: - fvalue_inv = 2 * precision_inv * recall_inv / (precision_inv + recall_inv) - else: - fvalue_inv = 0 - - if 2 * l00cnt + l01cnt + l10cnt == 0: - _fvalue_inv = np.nan - else: - _fvalue_inv = 2 * l00cnt / (2 * l00cnt + l01cnt + l10cnt) - - if l11cnt + l00cnt + l01cnt + l10cnt > 0: - accuracy = float(l11cnt + l00cnt) / (l11cnt + l00cnt + l01cnt + l10cnt) - else: - accuracy = 0 - if l00cnt + l01cnt > 0: - spcecificity = float(l00cnt) / (l00cnt + l01cnt) - else: - spcecificity = 0 - return ( - accuracy, - spcecificity, - precision, - recall, - fvalue, - precision_inv, - recall_inv, - fvalue_inv, - _fvalue, - _fvalue_inv, - ) - - -def compute_bigram_structure_score( - pred: np.array, - gt: np.array, -) -> Tuple: - text_num = len(gt) - text_cmb = list(itertools.combinations(list(range(text_num)), 2)) - l11cnt = 0 - l00cnt = 0 - l10cnt = 0 - l01cnt = 0 - for pi, pj in text_cmb: - fpi = pred[pi][0] - fpj = pred[pj][0] - fgi = int(gt[pi]) - fgj = int(gt[pj]) - if (fpi == fpj) and (fgi == fgj): - l11cnt += 1 - # l00cnt += 1 - if (fpi != fpj) and (fgi != fgj): - l00cnt += 1 - # l11cnt += 1 - if (fpi != fpj) and (fgi == fgj): - l10cnt += 1 - # l01cnt += 1 - if (fpi == fpj) and (fgi != fgj): - l01cnt += 1 - # l10cnt += 1 - scores = get_binary_classification_scores(l11cnt, l00cnt, l10cnt, l01cnt) - return scores, (l11cnt, l00cnt, l10cnt, l01cnt) - - -def get_structure_type( - gt: np.array, -) -> float: - text_num = len(gt) - text_cmb = list(itertools.combinations(list(range(text_num)), 2)) - consistency_num = 0 - contrast_num = 0 - for pi, pj in text_cmb: - fgi = int(gt[pi]) - fgj = int(gt[pj]) - if fgi == fgj: - consistency_num += 1 - else: - contrast_num += 1 - if text_num <= 1: - flag = 0 # uncount - elif consistency_num == 0: - flag = 1 # no consistency - elif contrast_num == 0: - flag = 2 # no contrast - else: - flag = 3 # others - return flag - - -def compute_bigram_score( - text_num_list: dict, - pred_list: dict, - gt_list: dict, - f: TextIOWrapper, -) -> Tuple: - cnt = 0 - structure_accuracy_mean = 0.0 - structure_precision_mean = 0.0 - structure_recall_mean = 0.0 - structure_fvalue_mean = 0.0 - structure_spcecificity_mean = 0.0 - structure_precision_inv_mean = 0.0 - structure_recall_inv_mean = 0.0 - structure_fvalue_inv_mean = 0.0 - label_score_mean = 0.0 - l11cnt_all = 0 - l00cnt_all = 0 - 
l10cnt_all = 0 - l01cnt_all = 0 - diff_case_scores = {} - diff_case_counts = {} - diff_case_scores[1] = 0.0 # no consistency - diff_case_scores[2] = 0.0 # no contrast - diff_case_scores[3] = 0.0 # others - diff_case_counts[1] = 0 - diff_case_counts[2] = 0 - diff_case_counts[3] = 0 - structure_nanmean = [] - for index in pred_list.keys(): - text_num = text_num_list[index] - if text_num == 0: - continue - pred = pred_list[index] - gt = gt_list[index] - flag = get_structure_type(gt) - scores, counts = compute_bigram_structure_score(pred, gt) - ( - structure_accuracy, - structure_spcecificity, - structure_precision, - structure_recall, - structure_fvalue, - structure_precision_inv, - structure_recall_inv, - structure_fvalue_inv, - _structure_fvalue, - _structure_fvalue_inv, - ) = scores - l11cnt, l00cnt, l10cnt, l01cnt = counts - l11cnt_all += l11cnt - l00cnt_all += l00cnt - l10cnt_all += l10cnt - l01cnt_all += l01cnt - label_score = compute_bigram_label_score(pred, gt) - structure_accuracy_mean += structure_accuracy - structure_spcecificity_mean += structure_spcecificity - structure_precision_mean += structure_precision - structure_recall_mean += structure_recall - structure_fvalue_mean += structure_fvalue - structure_precision_inv_mean += structure_precision_inv - structure_recall_inv_mean += structure_recall_inv - structure_fvalue_inv_mean += structure_fvalue_inv - label_score_mean += label_score - - if flag == 0: - pass - elif flag == 1: # no consistency - diff_case_scores[1] += structure_fvalue_inv - diff_case_counts[1] += 1 - elif flag == 2: # no contrast - diff_case_scores[2] += structure_fvalue - diff_case_counts[2] += 1 - elif flag == 3: # others - diff_case_scores[3] += (structure_fvalue + structure_fvalue_inv) / 2.0 - diff_case_counts[3] += 1 - structure_nanmean.append(_structure_fvalue) - structure_nanmean.append(_structure_fvalue_inv) - - cnt += 1 - structure_accuracy_mean /= cnt - structure_spcecificity_mean /= cnt - structure_precision_mean /= cnt - structure_recall_mean /= cnt - structure_fvalue_mean /= cnt - structure_precision_inv_mean /= cnt - structure_recall_inv_mean /= cnt - structure_fvalue_inv_mean /= cnt - label_score_mean /= cnt - log.info("structure_accuracy {:.3f}".format(structure_accuracy_mean)) - f.write("structure_accuracy {:.3f} \n".format(structure_accuracy_mean)) - - # log.info("label_score {:.3f}".format(label_score_mean)) - # f.write("label_score {:.3f} \n".format(label_score_mean)) - log.info("structure nanmean {:.3f}".format(np.nanmean(structure_nanmean))) - f.write("structure nanmean {:.3f} \n".format(np.nanmean(structure_nanmean))) - for i in range(1, 4): - if diff_case_counts[i] > 0: - log.info( - "structure_case_score{} count:{} {:.3f}".format( - i, diff_case_counts[i], diff_case_scores[i] / diff_case_counts[i] - ) - ) - f.write( - "structure_case_score{} count:{} {:.3f} \n".format( - i, diff_case_counts[i], diff_case_scores[i] / diff_case_counts[i] - ) - ) - else: - log.info("structure_case_score{} count:{} -".format(i, diff_case_counts[i])) - f.write( - "structure_case_score{} count:{} - \n".format(i, diff_case_counts[i]) - ) - - scores = get_binary_classification_scores( - l11cnt_all, l00cnt_all, l10cnt_all, l01cnt_all - ) - ( - structure_accuracy, - structure_spcecificity, - structure_precision, - structure_recall, - structure_fvalue, - structure_precision_inv, - structure_recall_inv, - structure_fvalue_inv, - _structure_fvalue, - _structure_fvalue_inv, - ) = scores - return structure_accuracy_mean, label_score_mean - - -def 
-
-
-def compute_label_diversity_score(
-    data_index_list: List,
-    text_num_list: Dict,
-    pred_list: Dict,
-    f: TextIOWrapper,
-    sampling_num: int = 10,
-) -> None:
-    def compute_label_diversity(pred_labels: List, text_num: int) -> float:
-        # Average, over text elements, of the fraction of distinct labels
-        # across the sampled predictions.
-        unique_num_rate_avg = 0.0
-        for k in range(text_num):
-            labels = []
-            for j in range(len(pred_labels)):
-                labels.append(int(pred_labels[j][k][0]))
-            unique_num_rate = len(set(labels)) / float(len(pred_labels))
-            unique_num_rate_avg += unique_num_rate
-        unique_num_rate_avg /= text_num
-        return unique_num_rate_avg
-
-    label_diversity_avg = 0.0
-    cnt = 0
-    for index in data_index_list:
-        text_num = text_num_list[f"{index}_0"]
-        pred_lists = []
-        for n in range(sampling_num):
-            preds = pred_list[f"{index}_{n}"]
-            pred_lists.append(preds)
-        if text_num > 0:
-            label_diversity = compute_label_diversity(pred_lists, text_num)
-            label_diversity_avg += label_diversity
-            cnt += 1
-    label_diversity_avg /= cnt
-    log.info("diversity score {:.1f}".format(label_diversity_avg * 100))
-    f.write("diversity score {:.1f}\n".format(label_diversity_avg * 100))
-
-
-def compute_alpha_overlap(
-    overlap_scores: Dict,
-    f: TextIOWrapper,
-) -> None:
-    overlap_score_all = 0
-    cnt_all = 0
-    for index in overlap_scores.keys():
-        overlap_score = overlap_scores[index]
-        if overlap_score is not None:
-            overlap_score_all += overlap_score
-            cnt_all += 1
-    if cnt_all > 0:
-        overlap_score_all = overlap_score_all / cnt_all
-
-    log.info("alpha overlap score {:.2f}".format(overlap_score_all))
-    f.write("alpha overlap score {:.2f}\n".format(overlap_score_all))
-
-
-def _compute_alpha_overlap(
-    alpha_map_list: List,
-):
-    # Returns the mean pairwise overlap recall, or None when no pair
-    # contributed.
-    overlap_score = 0
-    cnt = 0
-    for i in range(len(alpha_map_list)):
-        alpha_i = alpha_map_list[i]
-        for j in range(len(alpha_map_list)):
-            if i == j:
-                continue
-            alpha_j = alpha_map_list[j]
-            overlap = np.sum(alpha_i * alpha_j)
-            if np.sum(alpha_i) > 0:
-                recall = overlap / np.sum(alpha_i)
-                overlap_score += recall
-                cnt += 1
-    if cnt > 0:
-        overlap_score = overlap_score / cnt
-        return overlap_score
-    else:
-        return None
diff --git a/src/typography_generation/tools/structure_preserved_sampler.py b/src/typography_generation/tools/structure_preserved_sampler.py
deleted file mode 100644
index d8e5c96..0000000
--- a/src/typography_generation/tools/structure_preserved_sampler.py
+++ /dev/null
@@ -1,103 +0,0 @@
-import time
-from typing import List
-
-import torch
-import torch.nn as nn
-import torch.utils.data
-from typography_generation.io.data_loader import CrelloLoader
-from typography_generation.tools.sampler import Sampler
-from typography_generation.tools.score_func import EvalDataInstance
-from typography_generation.io.data_object import (
-    DataPreprocessConfig,
-    DesignContext,
-    FontConfig,
-    ModelInput,
-    PrefixListObject,
-    SamplingConfig,
-)
-from logzero import logger
-
-
-############################################################
-# Structure Preserved Sampler
-############################################################
-class StructurePreservedSampler(Sampler):
-    def __init__(
-        self,
-        model: nn.Module,
-        gpu: bool,
-        save_dir: str,
-        dataset: CrelloLoader,
-        prefix_list_object: PrefixListObject,
-        sampling_config: SamplingConfig,
-        batch_size: int = 1,
-        num_worker: int = 2,
-        show_interval: int = 100,
-        dataset_division: str = "test",
-        debug: bool = False,
-    ) -> None:
-        super().__init__(
-            model,
-            gpu,
-            save_dir,
-            dataset,
-            prefix_list_object,
-            sampling_config,
-            batch_size=batch_size,
- 
num_worker=num_worker, - show_interval=show_interval, - dataset_division=dataset_division, - debug=debug, - ) - - def sample_iter( - self, - design_context_list: List[DesignContext], - model_input_batchdata: List, - data_index: str, - end: float, - ) -> None: - start = time.time() - model_inputs = ModelInput(design_context_list, model_input_batchdata, self.gpu) - sampler_input = { - "model_inputs": model_inputs, - "dataset": self.dataset.dataset, - "target_prefix_list": self.prefix_list_target, - "sampling_param_geometry": self.sampling_config.sampling_param_geometry, - "sampling_param_semantic": self.sampling_config.sampling_param_semantic, - } - predictions = self.model.structure_preserved_sample(**sampler_input) - - self.instance_data = EvalDataInstance(self.prefix_list_target) - - text_num = design_context_list[0].canvas_context.canvas_text_num - - self.save_data[data_index] = dict() - for prefix in self.prefix_list_target: - pred_token, gt_token, pred, gt = self.denormalizer.denormalize( - prefix, text_num, predictions, design_context_list[0] - ) - self.instance_data.rigister_att( - text_num, - prefix, - pred_token, - gt_token, - pred, - gt, - ) - self.entire_data.update_prediction_data( - data_index, self.instance_data, f"{prefix}" - ) - self.save_data[data_index][prefix] = pred - self.entire_data.text_num[data_index] = text_num - - forward_time = time.time() - # if self.step % 200 == 0: - data_show = "{}/{}/{}, forward_time: {:.3f} data {:.3f}".format( - self.cnt, - self.step + 1, - self.steps, - forward_time - start, - (start - end), - ) - logger.info(data_show) diff --git a/src/typography_generation/tools/tokenizer.py b/src/typography_generation/tools/tokenizer.py deleted file mode 100644 index 178c85f..0000000 --- a/src/typography_generation/tools/tokenizer.py +++ /dev/null @@ -1,95 +0,0 @@ -import pickle -from typing import Dict, Tuple, Union -from einops import repeat -import numpy as np -from torch import Tensor -import torch - -default_cluster_num_dict = { - "text_font_size": 16, - "text_font_color": 64, - "text_height": 16, - "text_width": 16, - "text_top": 64, - "text_left": 64, - "text_center_y": 64, - "text_center_x": 64, - "text_angle": 16, - "text_letter_spacing": 16, - "text_line_height_scale": 16, - "canvas_aspect_ratio": 16, -} - - -class Tokenizer: - def __init__( - self, - data_dir: str, - cluster_num_dict: Union[Dict, None] = None, - load_cluster: bool = True, - ) -> None: - self.prefix_list = [ - "text_font_size", - "text_font_color", - "text_height", - "text_width", - "text_top", - "text_left", - "text_center_y", - "text_center_x", - "text_angle", - "text_letter_spacing", - "text_line_height_scale", - "canvas_aspect_ratio", - ] - self.rawdata2token = { - "text_font_emb": "text_font", - "text_font_size_raw": "text_font_size", - } - self.rawdata_list = list(self.rawdata2token.keys()) - self.rawdata_out_format = { - "text_font_emb": "token", - "text_font_size_raw": "raw", - } - - self.prediction_token_list = [ - "text_font_emb", - ] - if cluster_num_dict is None: - cluster_num_dict = default_cluster_num_dict - if load_cluster is True: - for prefix in self.prefix_list: - cluster_num = cluster_num_dict[prefix] - fn = f"{data_dir}/cluster/{prefix}_{cluster_num}.pkl" - if prefix == "text_font_color": - cluster = np.array(pickle.load(open(fn, "rb"))) - else: - cluster = np.array(pickle.load(open(fn, "rb"))).flatten() - setattr(self, f"{prefix}_cluster", cluster) - - def assign_label(self, val: Union[float, int], bins: np.array) -> int: - label = 
int(np.argsort(np.square(bins - val))[0]) - return label - - def assign_color_label(self, val: Union[float, int], bins: np.array) -> int: - val = np.tile(np.array(val)[np.newaxis, :], (len(bins), 1)) - d = np.square(bins - val).sum(1) - label = int(np.argsort(d, axis=0)[0]) - return label - - def tokenize(self, prefix: str, val: Union[float, int]) -> int: - if prefix == "text_font_color": - label = self.assign_color_label(val, getattr(self, f"{prefix}_cluster")) - else: - label = self.assign_label(val, getattr(self, f"{prefix}_cluster")) - return label - - def detokenize(self, prefix: str, label: Union[int, float]) -> Union[Tuple, float]: - if prefix == "text_font_color": - b, g, r = getattr(self, f"{prefix}_cluster")[int(label)] - return (r, g, b) - elif prefix == "text_font_size_raw": - val = float(getattr(self, f"text_font_size_cluster")[int(label)]) - else: - val = float(getattr(self, f"{prefix}_cluster")[int(label)]) - return val diff --git a/src/typography_generation/tools/train.py b/src/typography_generation/tools/train.py deleted file mode 100644 index c7e18b9..0000000 --- a/src/typography_generation/tools/train.py +++ /dev/null @@ -1,293 +0,0 @@ -import gc -import os -import re -import time -from typing import Any, List, Tuple, Union - -import datasets -import torch -import torch.nn as nn -import torch.utils.data -from logzero import logger -from torch.utils.data._utils.collate import default_collate -from torch.utils.tensorboard import SummaryWriter - -from typography_generation.config.attribute_config import ( - TextElementContextPredictionAttributeConfig, -) -from typography_generation.io.data_loader import CrelloLoader -from typography_generation.io.data_object import ( - ModelInput, - PrefixListObject, -) -from typography_generation.tools.loss import LossFunc -from typography_generation.tools.prediction_recorder import PredictionRecoder -from typography_generation.tools.tokenizer import Tokenizer - - -############################################################ -# DataParallel_withLoss -############################################################ -class FullModel(nn.Module): - def __init__(self, model: nn.Module, loss: LossFunc, gpu: bool) -> None: - super(FullModel, self).__init__() - self.model = model - self.loss = loss - self.gpu = gpu - - def forward( - self, - design_context_list: List, - model_input_batchdata: List, - ) -> Tuple: - model_inputs = ModelInput(design_context_list, model_input_batchdata, self.gpu) - outputs = self.model(model_inputs) - total_loss, record_items = self.loss(model_inputs, outputs, self.training) - return (outputs, torch.unsqueeze(total_loss, 0), record_items) - - def update(self, epoch: int, epochs: int, step: int, steps: int) -> None: - if self.model.model_name == "canvasvae": - self.loss.update_vae_weight(epoch, epochs, step, steps) - - -def collate_batch(batch: Tuple[Any, List, str]) -> Tuple: - design_contexts_list = [] - input_batch = [] - svg_id_list = [] - index_list = [] - for design_contexts, model_input_list, svg_id, index in batch: - design_contexts_list.append(design_contexts) - input_batch.append(model_input_list) - svg_id_list.append(svg_id) - index_list.append(index) - input_batch = default_collate(input_batch) - return design_contexts_list, input_batch, svg_id_list, index_list - - -OPTIMIZER_DICT = { - "adam": (torch.optim.AdamW, {"betas": (0.5, 0.999)}), - "sgd": (torch.optim.SGD, {}), -} - - -############################################################ -# Trainer -############################################################ 
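-# Optimizer wiring sketch (illustrative; mirrors the defaults visible in
-# Trainer.__init__ below, where params is the regex-filtered parameter list):
-#   optimizer_func, optimizer_kwarg = OPTIMIZER_DICT["adam"]
-#   optimizer_kwarg["lr"] = 0.0002
-#   optimizer_kwarg["weight_decay"] = 0.01
-#   optimizer = optimizer_func(params, **optimizer_kwarg)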
-class Trainer:
-    def __init__(
-        self,
-        model: nn.Module,
-        gpu: bool,
-        save_dir: str,
-        dataset: CrelloLoader,
-        dataset_val: CrelloLoader,
-        prefix_list_object: PrefixListObject,
-        prediction_config_element: TextElementContextPredictionAttributeConfig,
-        epochs: int = 31,
-        save_epoch: int = 5,
-        batch_size: int = 32,
-        num_worker: int = 2,
-        learning_rate: float = 0.0002,
-        weight_decay: float = 0.01,
-        optimizer_option: str = "adam",
-        show_interval: int = 100,
-        train_only: bool = False,
-        debug: bool = False,
-    ) -> None:
-        self.gpu = gpu
-        self.save_dir = save_dir
-        self.epochs = epochs
-        self.save_epoch = save_epoch
-        self.batch_size = batch_size
-        self.num_worker = num_worker
-        self.show_interval = show_interval
-        self.train_only = train_only
-        self.debug = debug
-        self.dataset = dataset
-        self.dataset_val = dataset_val
-        self.prefix_list_target = prefix_list_object.target
-
-        # Only parameters whose names match this pattern are optimized.
-        layer_regex = {
-            "lr1": r"(emb.*)|(enc.*)|(lf.*)|(dec.*)|(head.*)",
-        }
-        self.epoch = 0
-        param = [
-            p
-            for name, p in model.named_parameters()
-            if bool(re.fullmatch(layer_regex["lr1"], name))
-        ]
-        lossfunc = LossFunc(
-            model.model_name,
-            self.prefix_list_target,
-            prediction_config_element,
-            gpu,
-            topk=1,
-        )
-        logger.info(optimizer_option)
-        optimizer_func, optimizer_kwarg = OPTIMIZER_DICT[optimizer_option]
-        optimizer_kwarg["lr"] = learning_rate
-        optimizer_kwarg["weight_decay"] = weight_decay
-        self.optimizer = optimizer_func(param, **optimizer_kwarg)
-        self.fullmodel = FullModel(model, lossfunc, gpu)
-        self.writer = SummaryWriter(os.path.join(save_dir, "tensorboard"))
-        if gpu:
-            logger.info("use gpu")
-            logger.info(f"torch.cuda.is_available() {torch.cuda.is_available()}")
-            self.fullmodel.cuda()
-            logger.info("model to cuda")
-
-    def train_model(self) -> None:
-        # Data generators
-        dataloader = torch.utils.data.DataLoader(
-            self.dataset,
-            batch_size=self.batch_size,
-            shuffle=True,
-            num_workers=self.num_worker,
-            pin_memory=True,
-            collate_fn=collate_batch,
-        )
-        self.fullmodel.train()
-        self.pr_train = PredictionRecoder(self.prefix_list_target)
-        self.pr_val = PredictionRecoder(self.prefix_list_target)
-        self.epoch = 0
-        self.iter_count_train = 0
-        self.iter_count_val = 0
-        for epoch in range(0, self.epochs):
-            logger.info("Epoch {}/{}.".format(epoch, self.epochs))
-            self.epoch = epoch
-            # Training
-            self.pr_train.reset()
-            logger.info("training")
-            self.train_epoch(dataloader)
-            self.pr_train.step_epoch()
-            torch.cuda.empty_cache()
-            gc.collect()
-            logger.info("validation")
-            self.pr_val.reset()
-            if not self.train_only:
-                self.val_model()
-            self.pr_val.step_epoch()
-            if epoch % self.save_epoch == 0:
-                # Overwrites the same checkpoint file at every save epoch.
-                torch.save(
-                    self.fullmodel.model.state_dict(),
-                    os.path.join(self.save_dir, "model.pth"),
-                )
-            torch.cuda.empty_cache()
-            gc.collect()
-        logger.info("training finished")
-        logger.info("show data")
-        self.pr_train.show_history_scores()
-        self.pr_val.show_history_scores()
-
-    def train_epoch(self, dataloader: Any) -> None:
-        self.steps = len(dataloader)
-        self.fullmodel.train()
-        self.step = 0
-        self.cnt = 0
-        end = time.time()
-        for inputs in dataloader:
-            logger.debug("load data")
-            design_context_list, model_input_batchdata, _, _ = inputs
-            logger.debug("train step")
-            self.train_step(design_context_list, model_input_batchdata,
end) - end = time.time() - self.step += 1 - self.fullmodel.update(self.epoch, self.epochs, self.step, self.steps) - if self.debug is True: - break - - def train_step( - self, - design_context_list: List, - model_input_batchdata: List, - end: float, - ) -> None: - start = time.time() - logger.debug("model apply") - _, total_loss, recoder_items = self.fullmodel( - design_context_list, model_input_batchdata - ) - logger.debug(f"model apply {time.time()-start}") - logger.debug("record prediction and gt") - self.pr_train(recoder_items) - logger.debug(f"record {time.time()-start}") - logger.debug("update parameters") - total_loss = torch.mean(total_loss) - self.optimizer.zero_grad() - total_loss.backward() - self.optimizer.step() - logger.debug(f"optimize {time.time()-start}") - forward_time = time.time() - if self.step % self.show_interval == 0: - if self.gpu is True: - torch.cuda.empty_cache() - data_show = "{}/{}/{}/{}, forward_time: {:.3f} data {:.3f}".format( - self.epoch, - self.cnt, - self.step + 1, - self.steps, - forward_time - start, - (start - end), - ) - logger.info(data_show) - data_show = "total_loss: {:.3f}".format(total_loss.item()) - logger.info(data_show) - score_dict = self.pr_train.compute_score() - for k, v in score_dict.items(): - self.writer.add_scalar(f"train/{k}", v, self.iter_count_train) - self.iter_count_train += 1 - - def val_model(self) -> None: - # Data generators - dataloader = torch.utils.data.DataLoader( - self.dataset_val, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.num_worker, - pin_memory=True, - collate_fn=collate_batch, - ) - with torch.no_grad(): - self.steps = len(dataloader) - self.step = 0 - self.cnt = 0 - self.fullmodel.eval() - end = time.time() - for inputs in dataloader: - design_context_list, model_input_batchdata, _, _ = inputs - self.val_step(design_context_list, model_input_batchdata, end) - end = time.time() - self.step += 1 - - def val_step( - self, - design_context_list: List, - model_input_batchdata: List, - end: float, - ) -> None: - start = time.time() - _, _, recoder_items = self.fullmodel(design_context_list, model_input_batchdata) - self.pr_val(recoder_items) - forward_time = time.time() - if self.step % 40 == 0: - data_show = "{}/{}/{}, forward_time: {:.3f} data {:.3f}".format( - self.cnt, - self.step + 1, - self.steps, - forward_time - start, - (start - end), - ) - logger.info(data_show) - score_dict = self.pr_val.compute_score() - for k, v in score_dict.items(): - self.writer.add_scalar(f"val/{k}", v, self.iter_count_val) - self.iter_count_val += 1 - torch.cuda.empty_cache() - gc.collect() diff --git a/src/typography_generation/visualization/__init__.py b/src/typography_generation/visualization/__init__.py deleted file mode 100644 index 8b13789..0000000 --- a/src/typography_generation/visualization/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/typography_generation/visualization/renderer.py b/src/typography_generation/visualization/renderer.py deleted file mode 100644 index 5a7d8ee..0000000 --- a/src/typography_generation/visualization/renderer.py +++ /dev/null @@ -1,183 +0,0 @@ -from typing import Any, List, Tuple -import skia -import os -import pickle -import numpy as np - -from typography_generation.visualization.renderer_util import ( - get_color_map, - get_skia_font, - get_text_actual_height, - get_text_actual_width, - get_text_alpha, - get_texts, -) - - -class TextRenderer: - def __init__( - self, - data_dir: str, - fontlabel2fontname: Any, - ) -> None: - self.fontmgr = skia.FontMgr() - fn = 
os.path.join(data_dir, "font2ttf.pkl") - _font2ttf = pickle.load(open(fn, "rb")) - font2ttf = {} - for key in _font2ttf.keys(): - tmp = _font2ttf[key].split("/data/dataset/crello/")[1] - fn = os.path.join(data_dir, tmp) - font2ttf[key] = fn - self.font2ttf = font2ttf - fn = os.path.join(data_dir, "fonttype2fontid_fix.pkl") - fonttype2fontid = pickle.load(open(fn, "rb")) - - self.fontid2fonttype = {} - for k, v in fonttype2fontid.items(): - self.fontid2fonttype[v] = k - self.fontlabel2fontname = fontlabel2fontname - - def draw_texts( - self, - element_data: dict, - text_ids: List, - bg: np.array, - scaleinfo: Tuple, - ) -> np.array: - H, W = bg.shape[0], bg.shape[1] - h_rate, w_rate = scaleinfo - canvas = bg.copy() - for text_id in text_ids: - font_label = element_data["font"][text_id] - font_name = self.fontlabel2fontname(int(font_label)) - font_name = font_name.replace(" ", "_") - texts = get_texts(element_data, text_id) - font, _ = get_skia_font( - self.font2ttf, - self.fontmgr, - element_data, - text_id, - font_name, - h_rate, - ) - text_alpha = get_text_alpha( - element_data, - text_id, - texts, - font, - H, - W, - w_rate, - ) - text_rgb_map = get_color_map(element_data, text_id, H, W) - canvas = canvas * (1 - text_alpha) + text_alpha * text_rgb_map - return canvas - - def get_text_alpha_list( - self, - element_data: dict, - text_ids: List, - image_size: Tuple[int, int], - scaleinfo: Tuple[float, float], - ) -> List: - H, W = image_size - h_rate, w_rate = scaleinfo - text_alpha_list = [] - for text_id in text_ids: - font_label = element_data["font"][text_id] - font_name = self.fontlabel2fontname(int(font_label)) - font_name = font_name.replace(" ", "_") - texts = get_texts(element_data, text_id) - font, _ = get_skia_font( - self.font2ttf, - self.fontmgr, - element_data, - text_id, - font_name, - h_rate, - ) - text_alpha = get_text_alpha( - element_data, - text_id, - texts, - font, - H, - W, - w_rate, - ) - text_alpha_list.append(text_alpha) - return text_alpha_list - - def get_text_actual_height_list( - self, element_data: dict, text_ids: List, scaleinfo: Tuple[float, float] - ) -> List: - text_actual_height_list = [] - h_rate, _ = scaleinfo - for text_id in text_ids: - font_label = element_data["font"][text_id] - font_name = self.fontlabel2fontname(int(font_label)) - font_name = font_name.replace(" ", "_") - font, _ = get_skia_font( - self.font2ttf, - self.fontmgr, - element_data, - text_id, - font_name, - h_rate, - ) - text_actual_height = get_text_actual_height(font) - text_actual_height_list.append(text_actual_height) - return text_actual_height_list - - def compute_and_set_text_actual_width( - self, - element_data: dict, - text_ids: List, - scaleinfo: Tuple, - ) -> None: - h_rate, w_rate = scaleinfo - text_actual_width = {} - for text_id in text_ids: - font_label = element_data["font"][text_id] - font_name = self.fontlabel2fontname(int(font_label)) - font_name = font_name.replace(" ", "_") - texts = get_texts(element_data, text_id) - font, _ = get_skia_font( - self.font2ttf, - self.fontmgr, - element_data, - text_id, - font_name, - h_rate, - ) - _text_actual_width = get_text_actual_width( - element_data, text_id, texts, font, w_rate - ) - text_actual_width[text_id] = _text_actual_width - element_data["text_actual_width"] = text_actual_width - - def compute_and_set_text_center( - self, element_data: dict, text_ids: List, image_size: Tuple[int, int] - ) -> None: - H, W = image_size - text_center_x = {} - text_center_y = {} - for text_id in text_ids: - left = 
element_data["left"][text_id] * W - w = element_data["width"][text_id] * W - textAlign = element_data["text_align"][text_id] - right = left + w - actual_w = element_data["text_actual_width"][text_id] - if textAlign == 1: - _text_center_x = (left + right) / 2.0 - elif textAlign == 3: - _text_center_x = right - actual_w / 2.0 - elif textAlign == 2: - _text_center_x = left + actual_w / 2.0 - text_height = element_data["height"][text_id] * H - top = element_data["top"][text_id] * H - _text_center_y = (top + top + text_height) / 2.0 - text_center_y[text_id] = _text_center_y - text_center_x[text_id] = _text_center_x - element_data["text_center_y"] = text_center_y - element_data["text_center_x"] = text_center_x diff --git a/src/typography_generation/visualization/renderer_util.py b/src/typography_generation/visualization/renderer_util.py deleted file mode 100644 index a4e6504..0000000 --- a/src/typography_generation/visualization/renderer_util.py +++ /dev/null @@ -1,225 +0,0 @@ -import html -import math -import os -from typing import Any, Dict, List, Tuple, Union - -import numpy as np -import skia - - -def capitalize_text(text: str, capitalize: int) -> str: - text_tmp = "" - for character in text: - if capitalize == 1: - character = character.upper() - text_tmp += character - return text_tmp - - -def text_align( - textAlign: int, left: float, center: float, right: float, text_alpha_width: float -) -> float: - if textAlign == 1: - x = center - text_alpha_width / 2.0 - elif textAlign == 3: - x = right - text_alpha_width - elif textAlign == 2: - x = left - return x - - -def get_text_location_info(font: skia.Font, text_tmp: str) -> Tuple: - glyphs = font.textToGlyphs(text_tmp) - positions = font.getPos(glyphs) - rects = font.getBounds(glyphs) - return glyphs, positions, rects - - -def compute_text_alpha_width( - positions: List, rects: List, letterSpacing: float -) -> Union[float, Any]: - twidth = positions[-1].x() + rects[-1].right() - if letterSpacing is not None: - twidth += letterSpacing * (len(rects) - 1) - return twidth - - -def add_letter_margin(x: float, letterSpacing: float) -> float: - if letterSpacing is not None: - x = x + letterSpacing - return x - - -def get_text_actual_height( - font: skia.Font, -): - ascender = -1 * font.getMetrics().fAscent - descender = font.getMetrics().fDescent - leading = font.getMetrics().fLeading - text_height = ascender + descender + leading - return text_height - - -def get_text_alpha( - element_data: Any, - text_index: int, - texts: List, - font: skia.Font, - H: int, - W: int, - w_rate: float, -) -> np.array: - center_y = element_data["text_center_y"][text_index] * H - center_x = element_data["text_center_x"][text_index] * W - text_width = get_text_actual_width(element_data, text_index, texts, font, w_rate) - left = center_x - text_width / 2.0 - right = center_x + text_width / 2.0 - ascender = -1 * font.getMetrics().fAscent - descender = font.getMetrics().fDescent - leading = font.getMetrics().fLeading - line_height_scale = element_data["line_height"][text_index] - line_height = (ascender + descender + leading) * line_height_scale - surface = skia.Surface(W, H) - canvas = surface.getCanvas() - fill_paint = skia.Paint( - AntiAlias=True, - Color=skia.ColorSetRGB(255, 0, 0), - Style=skia.Paint.kFill_Style, - ) - fill_paint.setBlendMode(skia.BlendMode.kSrcOver) - y = center_y - line_height * len(texts) / 2.0 - for text in texts: - text = html.unescape(text) - text = capitalize_text(text, element_data["capitalize"][text_index]) - _, positions, rects = 
get_text_location_info(font, text)
-        if len(positions) == 0:
-            continue
-        text_alpha_width = compute_text_alpha_width(
-            positions, rects, element_data["letter_spacing"][text_index] * w_rate
-        )
-        angle = float(element_data["angle"][text_index]) * 180 / math.pi
-        x = text_align(
-            element_data["text_align"][text_index],
-            left,
-            center_x,
-            right,
-            text_alpha_width,
-        )
-        canvas.rotate(angle, center_x, center_y)
-        for i, character in enumerate(text):
-            ydp = np.round(y + positions[i].y() + ascender)
-            xdp = np.round(x + positions[i].x())
-            textblob = skia.TextBlob(character, font)
-            canvas.drawTextBlob(textblob, xdp, ydp, fill_paint)
-            x = add_letter_margin(
-                x, element_data["letter_spacing"][text_index] * w_rate
-            )
-        # Undo the rotation about the same pivot before drawing the next line.
-        canvas.rotate(-1 * angle, center_x, center_y)
-        y += line_height
-    text_alpha = surface.makeImageSnapshot().toarray()[:, :, 0]
-    text_alpha = text_alpha / 255.0
-    text_alpha = np.tile(text_alpha[:, :, np.newaxis], (1, 1, 3))
-    return np.minimum(text_alpha, 1.0)
-
-
-def get_text_actual_width(
-    element_data: Any,
-    text_index: int,
-    texts: List,
-    font: skia.Font,
-    w_rate: float,
-) -> float:
-    # Width of the widest line, in pixels, including letter spacing.
-    text_alpha_width = 0.0
-    for text in texts:
-        text = html.unescape(text)
-        text = capitalize_text(text, element_data["capitalize"][text_index])
-        _, positions, rects = get_text_location_info(font, text)
-        if len(positions) == 0:
-            continue
-        _text_alpha_width = compute_text_alpha_width(
-            positions, rects, element_data["letter_spacing"][text_index] * w_rate
-        )
-        text_alpha_width = max(text_alpha_width, _text_alpha_width)
-    return text_alpha_width
-
-
-def font_name_fix(font_name: str) -> str:
-    # Normalize font names to the keys used by the ttf lookup table.
-    fix_map = {
-        "Exo_2": r"Exo\_2",
-        "Press_Start_2P": r"Press_Start\_2P",
-        "quattrocento": "Quattrocento",
-        "yellowtail": "Yellowtail",
-        "sunday": "Sunday",
-        "bebas_neue": "Bebas_Neue",
-        "Brusher": "Brusher_Regular",
-        "Amatic_Sc": "Amatic_SC",
-        "Pt_Sans": "PT_Sans",
-        "Old_Standard_Tt": "Old_Standard_TT",
-        "Eb_Garamond": "EB_Garamond",
-        "Gfs_Didot": "GFS_Didot",
-        "Im_Fell": "IM_Fell",
-        "Im_Fell_Dw_Pica_Sc": "IM_Fell_DW_Pica_SC",
-        "Marcellus_Sc": "Marcellus_SC",
-    }
-    return fix_map.get(font_name, font_name)
-
-
-def get_skia_font(
-    font2ttf: dict,
-    fontmgr: skia.FontMgr,
-    element_data: Dict,
-    targetid: int,
-    font_name: str,
-    scale_h: float,
-    font_scale: float = 1.0,
-) -> Tuple:
-    font_name = font_name_fix(font_name)
-    if font_name in font2ttf:
-        ttf = font2ttf[font_name]
-        ft = fontmgr.makeFromFile(ttf, 0)
-        font = skia.Font(
-            ft, element_data["font_size"][targetid] * scale_h, font_scale, 1e-20
-        )
-        return font, font_name
-    else:
-        # Fall back to the default typeface when no ttf file is registered.
-        ft = fontmgr.makeFromFile("", 0)
-        font = skia.Font(
-            ft, element_data["font_size"][targetid] * scale_h, font_scale, 1e-20
-        )
-        return font, font_name
-
-
-def get_color_map(element_data: Any, targetid: int, H: int, W: int) -> np.ndarray:
-    # Solid canvas-sized map of the text color, kept in the stored BGR order.
-    B, G, R = element_data["color"][targetid]
-    rgb_map = np.zeros((H, W, 3), dtype=np.uint8)
-    rgb_map[:, :, 0] = B
-    rgb_map[:, :, 1] = G
-    rgb_map[:, :, 2] = R
-    return rgb_map
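-
-
-# Worked example for text_align (illustrative numbers): with left=100,
-# center=200, right=300 and a measured line width of 80, centered text
-# (textAlign == 1) starts at 200 - 80 / 2 = 160, right-aligned text
-# (textAlign == 3) at 300 - 80 = 220, and left-aligned text
-# (textAlign == 2) stays at 100.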
-
-
-def get_texts(element_data: Dict, target_id: int) -> List:
-    # Split a text element into its non-empty lines.
-    text = html.unescape(element_data["text"][target_id])
-    _texts = text.split(os.linesep)
-    texts = []
-    for t in _texts:
-        if t != "":
-            texts.append(t)
-    return texts
diff --git a/src/typography_generation/visualization/visualizer.py b/src/typography_generation/visualization/visualizer.py
deleted file mode 100644
index 1ef6d6c..0000000
--- a/src/typography_generation/visualization/visualizer.py
+++ /dev/null
@@ -1,190 +0,0 @@
-import copy
-from typing import Dict, List, Tuple
-
-import numpy as np
-from typography_generation.tools.denormalizer import Denormalizer
-from typography_generation.tools.tokenizer import Tokenizer
-from typography_generation.visualization.renderer import TextRenderer
-
-crelloattstr2pkgattstr = {
-    "text_font": "font",
-    "text_font_color": "color",
-    "text_align_type": "text_align",
-    "text_capitalize": "capitalize",
-    "text_font_size": "font_size",
-    "text_font_size_raw": "font_size",
-    "text_angle": "angle",
-    "text_letter_spacing": "letter_spacing",
-    "text_line_height_scale": "line_height",
-    "text_center_y": "text_center_y",
-    "text_center_x": "text_center_x",
-}
-
-
-def get_text_ids(element_data: Dict) -> List:
-    text_ids = []
-    for k in range(len(element_data["text"])):
-        if element_data["text"][k] != "":
-            text_ids.append(k)
-    return text_ids
-
-
-def replace_style_data_by_prediction(
-    prediction: Dict, element_data: Dict, text_ids: List
-) -> Dict:
-    element_data = copy.deepcopy(element_data)
-    for prefix_pred, prefix_vec in crelloattstr2pkgattstr.items():
-        if prefix_pred in prediction.keys():
-            for i, t in enumerate(text_ids):
-                element_data[prefix_vec][t] = prediction[prefix_pred][i][0]
-    return element_data
-
-
-def get_ordered_text_ids(element_data: Dict, order_list: List) -> List:
-    text_ids = []
-    for i in order_list:
-        if element_data["text"][i] != "":
-            text_ids.append(i)
-    return text_ids
-
-
-def ordering_text_ids(order_list: List, text_ids: List) -> List:
-    # Restrict order_list to the given text ids, keeping its order.
-    return [i for i in order_list if i in text_ids]
-
-
-def visualize_prediction(
-    renderer: TextRenderer,
-    element_data: Dict,
-    prediction: Dict,
-    bg_img: np.array,
-) -> np.array:
-    order_list = element_data["order_list"]
-    scaleinfo = element_data["scale_box"]
-    text_ids = get_ordered_text_ids(element_data, order_list)
-    element_data = replace_style_data_by_prediction(prediction, element_data, text_ids)
-    img = renderer.draw_texts(element_data, text_ids, np.array(bg_img), scaleinfo)
-    return img
-
-
-def get_predicted_alphamaps(
-    renderer: TextRenderer,
-    element_data: Dict,
-    prediction: Dict,
-    image_size: Tuple[int, int],
-    order_list: List = None,
-) -> List:
-    text_ids = get_text_ids(element_data)
-    if order_list is None:
-        order_list = element_data["order_list"]
-    scaleinfo = element_data["scale_box"]
-    text_ids = ordering_text_ids(order_list, text_ids)
-    element_data = replace_style_data_by_prediction(prediction, element_data, text_ids)
-    alpha_list = renderer.get_text_alpha_list(
-        element_data, text_ids, image_size, scaleinfo
-    )
-    return alpha_list
-
-
-def get_element_alphamaps(
-    renderer: TextRenderer,
-    element_data: Dict,
-) -> List:
-    text_ids = get_text_ids(element_data)
-    order_list = element_data["order_list"]
-    scaleinfo = element_data["scale_box"]
-    image_size = element_data["canvas_bg_size"]
-    text_ids = ordering_text_ids(order_list, text_ids)
-    alpha_list = renderer.get_text_alpha_list(
-        element_data, text_ids, image_size, scaleinfo
-    )
-    return alpha_list
-
-
-def visualize_data(
-    renderer: TextRenderer,
-    element_data: Dict,
-    bg_img: np.array,
-)
-> np.array: - text_ids = get_text_ids(element_data) - scaleinfo = element_data["scale_box"] - img = renderer.draw_texts(element_data, text_ids, np.array(bg_img), scaleinfo) - return img - - -def tokenize( - _element_data: Dict, - tokenizer: Tokenizer, - denormalizer: Denormalizer, - text_ids: List, - bg_img: np.array, - scaleinfo: Tuple, -) -> Dict: - element_data = copy.deepcopy(_element_data) - h, w = bg_img.size[1], bg_img.size[0] - for prefix_pred, prefix_vec in crelloattstr2pkgattstr.items(): - for i, t in enumerate(text_ids): - data_info = { - "element_data": _element_data, - "text_index": t, - "img_size": (h, w), - "scaleinfo": scaleinfo, - "text_actual_width": _element_data["text_actual_width"][t], - "text": None, - } - data = getattr(denormalizer.dataset, f"get_{prefix_pred}")(**data_info) - if prefix_pred in denormalizer.dataset.tokenizer.prefix_list: - data = tokenizer.tokenize(prefix_pred, data) - data = tokenizer.detokenize(prefix_pred, data) - data = denormalizer.denormalize_elm( - prefix_pred, data, h, w, scaleinfo[0], scaleinfo[1] - ) - element_data[prefix_vec][t] = data - return element_data - - -def visualize_tokenization( - renderer: TextRenderer, - tokenizer: Tokenizer, - denormalizer: Denormalizer, - element_data: Dict, - bg_img: np.array, -) -> np.array: - text_ids = get_text_ids(element_data) - scaleinfo = element_data["scale_box"] - element_data = tokenize( - element_data, tokenizer, denormalizer, text_ids, bg_img, scaleinfo - ) - img = renderer.draw_texts(element_data, text_ids, np.array(bg_img), scaleinfo) - return img - - -def get_text_coords(element_data: Dict, text_index: int, img_size: Tuple) -> Tuple: - h, w = img_size - top = int(element_data["top"][text_index] * h) - left = int(element_data["left"][text_index] * w) - height = int(element_data["height"][text_index] * h) - width = int(element_data["width"][text_index] * w) - return top, left, top + height, left + width - - -def colorize_text( - element_data: Dict, - canvas: np.array, - text_index: int, - color: Tuple = (255, 0, 0), - w: float = 0.5, -) -> np.array: - text_ids = get_text_ids(element_data) - order_list = element_data["order_list"] - scaleinfo = element_data["scale_box"] - text_ids = ordering_text_ids(order_list, text_ids) - y0, x0, y1, x1 = get_text_coords( - element_data, text_ids[text_index], canvas.shape[:2] - ) - tmp = canvas.copy() - tmp[y0:y1, x0:x1, :] = np.array(color) - canvas[y0:y1, x0:x1, :] = ( - w * canvas[y0:y1, x0:x1, :] + (1 - w) * tmp[y0:y1, x0:x1, :] - ) - return canvas diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index 8b13789..0000000 --- a/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 4254d4d..0000000 --- a/tests/conftest.py +++ /dev/null @@ -1,92 +0,0 @@ -import pytest -import torch - -from typography_generation.config.config_args_util import ( - get_global_config, - get_model_config_input, - get_prefix_lists, -) -from typography_generation.config.default import get_font_config -from typography_generation.io.build_dataset import ( - build_test_dataset, - build_train_dataset, -) -from typography_generation.model.model import create_model -from typography_generation.tools.train import collate_batch - - -@pytest.fixture -def bartconfig(): - data_dir = "data" - config_name = "bart" - bartconfig = get_global_config(data_dir, config_name) - return bartconfig - - -@pytest.fixture -def bartconfigdataset_test(bartconfig): - data_dir = "data" - prefix_list_object = 
get_prefix_lists(bartconfig) - font_config = get_font_config(bartconfig) - - bartconfigdataset_test = build_test_dataset( - data_dir, - prefix_list_object, - font_config, - debug=True, - ) - return bartconfigdataset_test - - -@pytest.fixture -def bartconfigdataset(bartconfig): - data_dir = "data" - prefix_list_object = get_prefix_lists(bartconfig) - font_config = get_font_config(bartconfig) - - bartconfigdataset, _ = build_train_dataset( - data_dir, - prefix_list_object, - font_config, - debug=True, - ) - return bartconfigdataset - - -@pytest.fixture -def bartconfigdataset_val(bartconfig): - data_dir = "data" - prefix_list_object = get_prefix_lists(bartconfig) - font_config = get_font_config(bartconfig) - - _, bartconfigdataset_val = build_train_dataset( - data_dir, - prefix_list_object, - font_config, - debug=True, - ) - return bartconfigdataset_val - - -@pytest.fixture -def bartconfigdataloader(bartconfigdataset): - bartconfigdataloader = torch.utils.data.DataLoader( - bartconfigdataset, - batch_size=2, - shuffle=False, - num_workers=1, - pin_memory=True, - collate_fn=collate_batch, - ) - return bartconfigdataloader - - -@pytest.fixture -def bartmodel(bartconfig): - model_name, model_kwargs = get_model_config_input(bartconfig) - - bertmodel = create_model( - model_name, - **model_kwargs, - ) - return bertmodel diff --git a/tests/io/__init__.py b/tests/io/__init__.py deleted file mode 100644 index 8b13789..0000000 --- a/tests/io/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/tests/io/test_data_loader.py b/tests/io/test_data_loader.py deleted file mode 100644 index 3967126..0000000 --- a/tests/io/test_data_loader.py +++ /dev/null @@ -1,59 +0,0 @@ -import pytest -import datasets -import torch -from typography_generation.__main__ import ( - get_global_config, - get_prefix_lists, -) -from typography_generation.config.default import ( - get_datapreprocess_config, - get_font_config, -) -from typography_generation.io.build_dataset import build_test_dataset -from typography_generation.io.data_loader import CrelloLoader -from typography_generation.tools.tokenizer import Tokenizer - -from typography_generation.tools.train import collate_batch - -params = {"normal 1": ("data", "bart", 0), "normal 2": ("data", "bart", 1)} - - -@pytest.mark.parametrize("data_dir, config_name, index", list(params.values())) -def test_get_item(data_dir: str, config_name: str, index: int) -> None: - config = get_global_config(data_dir, config_name) - prefix_list_object = get_prefix_lists(config) - font_config = get_font_config(config) - - dataset = build_test_dataset( - data_dir, - prefix_list_object, - font_config, - debug=True, - ) - dataset.__getitem__(index) - - -@pytest.mark.parametrize("data_dir, config_name", [["data", "bart"]]) -def test_dataloader_iteration(data_dir: str, config_name: str) -> None: - config = get_global_config(data_dir, config_name) - prefix_list_object = get_prefix_lists(config) - font_config = get_font_config(config) - - dataset = build_test_dataset( - data_dir, - prefix_list_object, - font_config, - debug=True, - ) - - dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=1, - shuffle=False, - num_workers=1, - pin_memory=True, - collate_fn=collate_batch, - ) - tmp = dataloader.__iter__() - next(tmp) - next(tmp) diff --git a/tests/model/__init__.py b/tests/model/__init__.py deleted file mode 100644 index 8b13789..0000000 --- a/tests/model/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/tests/model/test_model.py b/tests/model/test_model.py deleted file mode 100644 
index 70d3298..0000000 --- a/tests/model/test_model.py +++ /dev/null @@ -1,49 +0,0 @@ -from typing import Any -import pytest -import datasets -import torch -from typography_generation.__main__ import ( - get_global_config, - get_model_config_input, - get_prefix_lists, -) -from typography_generation.config.default import ( - get_datapreprocess_config, - get_font_config, -) -from typography_generation.io.build_dataset import build_test_dataset -from typography_generation.io.data_object import ModelInput -from typography_generation.model.model import create_model -from typography_generation.tools.train import collate_batch - - -@pytest.mark.parametrize( - "data_dir, config_name", - [ - ["data", "bart"], - # ["crello", "mlp"], - # ["crello", "canvasvae"], - # ["crello", "mfc"], - ], -) -def test_model(dataset: Any, data_dir: str, config_name: str) -> None: - config = get_global_config(data_dir, config_name) - model_name, model_kwargs = get_model_config_input(config) - - model = create_model( - model_name, - **model_kwargs, - ) - - dataloader = torch.utils.data.DataLoader( - dataset, - batch_size=2, - shuffle=False, - num_workers=1, - pin_memory=True, - collate_fn=collate_batch, - ) - tmp = dataloader.__iter__() - design_context_list, model_input_batchdata, _, _ = next(tmp) - model_inputs = ModelInput(design_context_list, model_input_batchdata, gpu=False) - model(model_inputs) diff --git a/tests/test_main.py b/tests/test_main.py deleted file mode 100644 index 47fda27..0000000 --- a/tests/test_main.py +++ /dev/null @@ -1,106 +0,0 @@ -import datasets -import pytest -from logzero import logger -import os -from typography_generation.__main__ import ( - get_global_config, - get_model_config_input, - get_prefix_lists, -) -from typography_generation.config.default import ( - get_datapreprocess_config, - get_font_config, - get_model_input_prefix_list, -) -from typography_generation.io.build_dataset import build_train_dataset -from typography_generation.io.data_loader import CrelloLoader -from typography_generation.model.model import create_model -from typography_generation.tools.tokenizer import Tokenizer -from typography_generation.tools.train import Trainer - - -@pytest.mark.parametrize("data_dir, config_name", [["crello", "bart"]]) -def test_model_creation(data_dir: str, config_name: str) -> None: - config = get_global_config(data_dir, config_name) - model_name, model_kwargs = get_model_config_input(config) - - create_model( - model_name, - **model_kwargs, - ) - - -@pytest.mark.parametrize("data_dir, config_name", [["crello", "bart"]]) -def test_loader_creation(data_dir: str, config_name: str) -> None: - config = get_global_config(data_dir, config_name) - prefix_list_model_input = get_model_input_prefix_list(config) - font_config = get_font_config(config) - datapreprocess_config = get_datapreprocess_config(config) - hugging_dataset = datasets.load_from_disk( - os.path.join(data_dir, "extended_dataset", "map_features.hf") - ) - tokenizer = Tokenizer(data_dir) - CrelloLoader( - data_dir, - tokenizer, - hugging_dataset, - prefix_list_model_input, - font_config, - datapreprocess_config, - train=True, - dataset_division="train", - ) - - -@pytest.mark.parametrize( - "data_dir, save_dir, config_name", - [["crello", "job", "bart"]], -) -def test_trainer_creation(data_dir: str, save_dir: str, config_name: str) -> None: - logger.info(config_name) - config = get_global_config(data_dir, config_name) - model_name, model_kwargs = get_model_config_input(config) - - model = create_model( - model_name, - 
**model_kwargs, - ) - prefix_list_object = get_prefix_lists(config) - prediction_config_element = config.text_element_prediction_attribute_config - font_config = get_font_config(config) - datapreprocess_config = get_datapreprocess_config(config) - optimizer_option = config.train_config.optimizer - - dataset, dataset_val = build_train_dataset( - data_dir, - prefix_list_object, - font_config, - datapreprocess_config, - debug=True, - ) - - gpu = False - Trainer( - model, - gpu, - save_dir, - dataset, - dataset_val, - prefix_list_object, - prediction_config_element, - optimizer_option=optimizer_option, - ) - - -@pytest.mark.parametrize( - "data_dir, config_name", - [["crello", "bart"]], -) -def test_model_config(data_dir: str, config_name: str) -> None: - logger.info(f"config_name {config_name}") - config = get_global_config(data_dir, config_name) - model_name, model_kwargs = get_model_config_input(config) - create_model( - model_name, - **model_kwargs, - ) diff --git a/tests/tool/__init__.py b/tests/tool/__init__.py deleted file mode 100644 index 8b13789..0000000 --- a/tests/tool/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/tests/tool/test_eval.py b/tests/tool/test_eval.py deleted file mode 100644 index fcf9252..0000000 --- a/tests/tool/test_eval.py +++ /dev/null @@ -1,95 +0,0 @@ -import pytest -from typography_generation.__main__ import ( - get_global_config, - get_model_config_input, - get_prefix_lists, -) -from typography_generation.config.default import ( - get_font_config, -) -from typography_generation.io.build_dataset import build_test_dataset -from typography_generation.model.model import create_model -from typography_generation.tools.evaluator import Evaluator - - -def test_eval_iter(bartconfig, bartconfigdataset_test, bartmodel) -> None: - prefix_list_object = get_prefix_lists(bartconfig) - - gpu = False - save_dir = "job" - evaluator = Evaluator( - bartmodel, - gpu, - save_dir, - bartconfigdataset_test, - prefix_list_object, - debug=True, - ) - evaluator.eval_model() - - -@pytest.mark.parametrize( - "data_dir, save_dir, config_name, elementembeddingflagconfigname, canvasembeddingflagconfigname, elementpredictionflagconfigname", - [ - [ - "data", - "job", - "bart", - "text_element_embedding_flag_config/wofontsize", - "canvas_embedding_flag_config/canvas_detail_given", - "text_element_prediction_flag_config/rawfontsize", - ], - ], -) -def test_flag_config_eval( - data_dir: str, - save_dir: str, - config_name: str, - elementembeddingflagconfigname: str, - canvasembeddingflagconfigname: str, - elementpredictionflagconfigname, -) -> None: - global_config_input = {} - global_config_input["data_dir"] = data_dir - global_config_input["model_name"] = config_name - global_config_input["test_config_name"] = "test_config" - global_config_input["model_config_name"] = "model_config" - global_config_input[ - "elementembeddingflag_config_name" - ] = elementembeddingflagconfigname - global_config_input[ - "canvasembeddingflag_config_name" - ] = canvasembeddingflagconfigname - global_config_input[ - "elementpredictionflag_config_name" - ] = elementpredictionflagconfigname - - config = get_global_config(**global_config_input) - model_name, model_kwargs = get_model_config_input(config) - - model = create_model( - model_name, - **model_kwargs, - ) - # import torch - # model.load_state_dict(torch.load("job/weight.pth", map_location="cpu")) - prefix_list_object = get_prefix_lists(config) - font_config = get_font_config(config) - - dataset = build_test_dataset( - data_dir, - prefix_list_object, 
- font_config, - debug=True, - ) - - gpu = False - evaluator = Evaluator( - model, - gpu, - save_dir, - dataset, - prefix_list_object, - debug=True, - ) - evaluator.eval_model() diff --git a/tests/tool/test_sample.py b/tests/tool/test_sample.py deleted file mode 100644 index 75240c6..0000000 --- a/tests/tool/test_sample.py +++ /dev/null @@ -1,90 +0,0 @@ -from typing import Dict, List -import pytest -from typography_generation.__main__ import ( - get_global_config, - get_model_config_input, - get_prefix_lists, - get_sampling_config, -) -from typography_generation.config.default import ( - get_font_config, -) -from typography_generation.io.build_dataset import build_test_dataset -from typography_generation.model.model import create_model -from typography_generation.tools.sampler import Sampler - - -@pytest.mark.parametrize( - "data_dir, save_dir, config_name, test_config, model_config", - [ - ["data", "job", "bart", "test_config/topp07", "model_config"], - ], -) -def test_sample_iter( - bartconfigdataset_test, - bartmodel, - data_dir: str, - save_dir: str, - config_name: str, - test_config: str, - model_config: str, -) -> None: - config = get_global_config( - data_dir, - config_name, - test_config_name=test_config, - model_config_name=model_config, - ) - prefix_list_object = get_prefix_lists(config) - sampling_config = get_sampling_config(config) - - gpu = False - sampler = Sampler( - bartmodel, - gpu, - save_dir, - bartconfigdataset_test, - prefix_list_object, - sampling_config, - debug=True, - ) - sampler.sample() - - -@pytest.mark.parametrize( - "data_dir,save_dir,config_name", - [ - ["data", "job", "canvasvae"], - ], -) -def test_sample_canvasvae(data_dir: str, save_dir: str, config_name: str) -> None: - config = get_global_config(data_dir, config_name) - model_name, model_kwargs = get_model_config_input(config) - - model = create_model( - model_name, - **model_kwargs, - ) - prefix_list_object = get_prefix_lists(config) - font_config = get_font_config(config) - sampling_config = get_sampling_config(config) - - gpu = False - dataset = build_test_dataset( - data_dir, - prefix_list_object, - font_config, - debug=True, - ) - - gpu = False - sampler = Sampler( - model, - gpu, - save_dir, - dataset, - prefix_list_object, - sampling_config, - debug=True, - ) - sampler.sample() diff --git a/tests/tool/test_structure_preserved_sampler.py b/tests/tool/test_structure_preserved_sampler.py deleted file mode 100644 index 0f4a042..0000000 --- a/tests/tool/test_structure_preserved_sampler.py +++ /dev/null @@ -1,49 +0,0 @@ -from typing import Dict, List -import pytest -import os -import pickle -import torch -from typography_generation.__main__ import ( - get_global_config, - get_model_config_input, - get_prefix_lists, - get_sampling_config, -) -from typography_generation.config.default import ( - get_datapreprocess_config, - get_font_config, -) -from typography_generation.io.build_dataset import build_test_dataset -from typography_generation.io.data_loader import CrelloLoader -from typography_generation.io.data_object import ModelInput -from typography_generation.model.model import create_model -from typography_generation.tools.denormalizer import Denormalizer -from typography_generation.tools.structure_preserved_sampler import ( - StructurePreservedSampler, -) -from typography_generation.tools.tokenizer import Tokenizer -import datasets - -from typography_generation.tools.train import collate_batch - -import logzero -from logzero import logger -import logging - - -def test_sample_iter(bartconfig, 
bartconfigdataset_test, bartmodel) -> None: - prefix_list_object = get_prefix_lists(bartconfig) - sampling_config = get_sampling_config(bartconfig) - - gpu = False - save_dir = "job" - sampler = StructurePreservedSampler( - bartmodel, - gpu, - save_dir, - bartconfigdataset_test, - prefix_list_object, - sampling_config, - debug=True, - ) - sampler.sample() diff --git a/tests/tool/test_train.py b/tests/tool/test_train.py deleted file mode 100644 index 5734b28..0000000 --- a/tests/tool/test_train.py +++ /dev/null @@ -1,167 +0,0 @@ -import pytest -from typography_generation.__main__ import ( - get_global_config, - get_model_config_input, - get_prefix_lists, -) -from typography_generation.config.default import ( - get_font_config, -) -from typography_generation.io.build_dataset import ( - build_test_dataset, - build_train_dataset, -) -from typography_generation.io.data_object import ModelInput -from typography_generation.model.model import create_model -from typography_generation.tools.loss import LossFunc -from typography_generation.tools.train import Trainer - - -def test_compute_loss(bartconfig, bartconfigdataloader, bartmodel) -> None: - prefix_list_object = get_prefix_lists(bartconfig) - - tmp = bartconfigdataloader.__iter__() - design_context_list, model_input_batchdata, _, _ = next(tmp) - model_inputs = ModelInput(design_context_list, model_input_batchdata, gpu=False) - outputs = bartmodel(model_inputs) - prediction_config_element = bartconfig.text_element_prediction_attribute_config - loss = LossFunc( - bartmodel.model_name, - prefix_list_object.target, - prediction_config_element, - gpu=False, - ) - loss(model_inputs, outputs, training=True) - - -def test_train_iter( - bartconfig, bartconfigdataset, bartconfigdataset_val, bartmodel -) -> None: - save_dir = "job" - optimizer_option = bartconfig.train_config.optimizer - prefix_list_object = get_prefix_lists(bartconfig) - prediction_config_element = bartconfig.text_element_prediction_attribute_config - trainer = Trainer( - bartmodel, - False, - save_dir, - bartconfigdataset, - bartconfigdataset_val, - prefix_list_object, - prediction_config_element, - optimizer_option=optimizer_option, - debug=True, - epochs=1, - ) - trainer.train_model() - - -@pytest.mark.parametrize( - "data_dir, config_name", - [["data", "mfc"], ["data", "canvasvae"]], -) -def test_train_config(data_dir: str, config_name: str) -> None: - config = get_global_config(data_dir, config_name) - model_name, model_kwargs = get_model_config_input(config) - - model = create_model( - model_name, - **model_kwargs, - ) - font_config = get_font_config(config) - optimizer_option = config.train_config.optimizer - prefix_list_object = get_prefix_lists(config) - prediction_config_element = config.text_element_prediction_attribute_config - - dataset, dataset_val = build_train_dataset( - data_dir, - prefix_list_object, - font_config, - debug=True, - ) - - save_dir = "job" - trainer = Trainer( - model, - False, - save_dir, - dataset, - dataset_val, - prefix_list_object, - prediction_config_element, - optimizer_option=optimizer_option, - debug=True, - epochs=1, - ) - trainer.train_model() - - -@pytest.mark.parametrize( - "data_dir, save_dir, config_name, elementembeddingflagconfigname, canvasembeddingflagconfigname, elementpredictionflagconfigname", - [ - [ - "data", - "job", - "bart", - "text_element_embedding_flag_config/wofontsize", - "canvas_embedding_flag_config/canvas_detail_given", - "text_element_prediction_flag_config/rawfontsize", - ], - ], -) -def test_flag_config_train( - 
data_dir: str, - save_dir: str, - config_name: str, - elementembeddingflagconfigname: str, - canvasembeddingflagconfigname: str, - elementpredictionflagconfigname, -) -> None: - global_config_input = {} - global_config_input["data_dir"] = data_dir - global_config_input["model_name"] = config_name - global_config_input["test_config_name"] = "test_config" - global_config_input["model_config_name"] = "model_config" - global_config_input[ - "elementembeddingflag_config_name" - ] = elementembeddingflagconfigname - global_config_input[ - "canvasembeddingflag_config_name" - ] = canvasembeddingflagconfigname - global_config_input[ - "elementpredictionflag_config_name" - ] = elementpredictionflagconfigname - - config = get_global_config(**global_config_input) - model_name, model_kwargs = get_model_config_input(config) - - model = create_model( - model_name, - **model_kwargs, - ) - prefix_list_object = get_prefix_lists(config) - font_config = get_font_config(config) - prediction_config_element = config.text_element_prediction_attribute_config - optimizer_option = config.train_config.optimizer - - dataset, dataset_val = build_train_dataset( - data_dir, - prefix_list_object, - font_config, - debug=True, - ) - - gpu = False - trainer = Trainer( - model, - False, - save_dir, - dataset, - dataset_val, - prefix_list_object, - prediction_config_element, - optimizer_option=optimizer_option, - debug=True, - epochs=1, - ) - trainer.train_model()
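-
-
-# (Illustrative note, not part of the original suite) Each of these tests can
-# be exercised on its own with pytest's keyword filter, e.g.:
-#   pytest tests/tool/test_train.py -k flag_config
-# The "data" and "job" paths used above are directories the tests assume
-# to exist.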