diff --git a/.gitignore b/.gitignore
deleted file mode 100644
index dfaefa2..0000000
--- a/.gitignore
+++ /dev/null
@@ -1,17 +0,0 @@
-# Byte-compiled / optimized / DLL files
-__pycache__/
-
-# Jupyter Notebook
-.ipynb_checkpoints
-
-# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
-__pypackages__/
-
-# Environments
-.env
-.venv
-
-# mypy
-.mypy_cache/
-.dmypy.json
-dmypy.json
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
deleted file mode 100644
index 261eeb9..0000000
--- a/LICENSE
+++ /dev/null
@@ -1,201 +0,0 @@
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [yyyy] [name of copyright owner]
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
diff --git a/README.md b/README.md
deleted file mode 100644
index cbde2ad..0000000
--- a/README.md
+++ /dev/null
@@ -1,121 +0,0 @@
-## Paper: Towards Diverse and Consistent Typography Generation
-Wataru Shimoda<sup>1</sup>, Daichi Haraguchi<sup>2</sup>, Seiichi Uchida<sup>2</sup>, Kota Yamaguchi<sup>1</sup>
-<sup>1</sup>CyberAgent Inc., <sup>2</sup>Kyushu University
-Accepted to WACV 2024.
-[[Arxiv](https://arxiv.org/abs/2309.02099)]
-[[project-page]()]
-
-## Introduction
-This repository contains the code for ["Towards Diverse and Consistent Typography Generation"](https://arxiv.org/abs/2309.02099).
-
-
-
-## Requirements
-We have verified reproducibility under the following environment:
-- Ubuntu (>=20.04)
-- Python3 (>=3.9)
-
-
-## Install
-Clone this repository and navigate to the folder
-
-``` sh
-git clone https://github.com/CyberAgentAILab/tdc-typography-generation.git
-cd tdc-typography-generation
-```
-
-We manage the Python dependencies with [pyproject.toml](https://github.com/CyberAgentAILab/tdc-typography-generation/blob/main/pyproject.toml).
-Please install the dependencies via pip or Poetry.
-
-
-If your environment has `setuptools>=61.0.0`, the following command installs the dependencies via pip:
-``` sh
-pip install .
-```
-Alternatively, we recommend installing the dependencies with Poetry (see the [official docs](https://python-poetry.org/docs/)):
-``` sh
-poetry install
-```
-Note that, for brevity, the rest of this guide omits the `poetry run` prefix from commands.
-
-## Dataset
-Our model is trained and tested on the [Crello dataset](https://huggingface.co/datasets/cyberagent/crello), which is publicly available on [Hugging Face](https://huggingface.co/).
-The dataset can be downloaded through the [Hugging Face Dataset API](https://huggingface.co/docs/datasets/index), but it does not contain high-resolution background images.
-
-We provide the background images via Google Cloud Storage ([link](https://storage.googleapis.com/ailab-public/tdc_typography_generation/generate_bg_png.tar.gz), 3.6GB).
-Please download the background images and place them in `data/generate_bg_png/`.
-``` sh
-tar zxvf generate_bg_png.tar.gz
-mv generate_bg_png data/
-rm generate_bg_png.tar.gz
-```
-
-We also provide font files, used for rendering designs and computing appropriate text sizes, via Google Cloud Storage ([link](https://storage.googleapis.com/ailab-public/tdc_typography_generation/font.tar.gz), 43MB).
-Please download the font files and place them in `data/font/`.
-``` sh
-tar zxvf font.tar.gz
-mv font data/
-rm font.tar.gz
-```
-
-## Usage
-We provide scripts for the experiments, as described below.
-
-### Preprocessing
-We recommend adding features to the dataset in advance; this makes training and testing faster.
-We provide a script for preprocessing:
-``` sh
-python -m typography_generation map_features --datadir data
-```
-This script extends the dataset via the [map function](https://huggingface.co/docs/datasets/v2.15.0/en/package_reference/main_classes#datasets.Dataset.map).
-The extended dataset is saved in `data/map_features`, and the `--use_extended_dataset` option enables its use.
-
-### Training
-The following command trains a model; it takes about half a day with the preprocessed dataset on an NVIDIA T4 machine.
-The details of training are controlled via the configuration files in `data/config/*.yaml`.
-The default configurations are defined in `src/typography_generation/config/*.py`.
-
-``` sh
-python -m typography_generation train_eval \
- --configname bart \
- --jobdir ${OUTPUT_DIR} \
- --datadir data \
- --use_extended_dataset \
-    --gpu
-```
-The outputs are saved in `${OUTPUT_DIR}`.
-
-### Sampling
-The following command samples typographic attributes.
-This command requires the `--weight` option, which specifies the path to the weights of a trained model.
-The weight file produced by the above training command is located at `${OUTPUT_DIR}/weight.pth`.
-Please assign the path of a weight file to `${WEIGHT_FILE}`.
-``` sh
-python -m typography_generation structure_preserved_sample \
- --configname bart \
- --jobdir ${OUTPUT_DIR} \
- --datadir data \
- --weight=${WEIGHT_FILE} \
- --use_extended_dataset \
-    --gpu
-```
-
-
-## Visualization
-We provide notebooks for inspecting results.
-- `notebooks/score.ipynb` shows scores of the results saved in `${OUTPUT_DIR}`.
-- `notebooks/vis.ipynb` shows generated graphic designs in `${OUTPUT_DIR}`.
-
-## Reference
-```bibtex
-@misc{shimoda_2024_tdctg,
- author = {Shimoda, Wataru and Haraguchi, Daichi and Uchida, Seiichi and Yamaguchi, Kota},
- title = {Towards Diverse and Consistent Typography Generation},
- publisher = {arXiv:2309.02099},
- year = {2024},
-}
-```
-
-## Contact
-This repository is maintained by Wataru Shimoda (wataru_shimoda[at]cyberagent.co.jp).
\ No newline at end of file
diff --git a/data/cluster/canvas_aspect_ratio_16.pkl b/data/cluster/canvas_aspect_ratio_16.pkl
deleted file mode 100644
index 189ec37..0000000
Binary files a/data/cluster/canvas_aspect_ratio_16.pkl and /dev/null differ
diff --git a/data/cluster/text_angle_16.pkl b/data/cluster/text_angle_16.pkl
deleted file mode 100644
index b1882c8..0000000
Binary files a/data/cluster/text_angle_16.pkl and /dev/null differ
diff --git a/data/cluster/text_center_x_64.pkl b/data/cluster/text_center_x_64.pkl
deleted file mode 100644
index ed9e334..0000000
Binary files a/data/cluster/text_center_x_64.pkl and /dev/null differ
diff --git a/data/cluster/text_center_y_64.pkl b/data/cluster/text_center_y_64.pkl
deleted file mode 100644
index 0eec0ea..0000000
Binary files a/data/cluster/text_center_y_64.pkl and /dev/null differ
diff --git a/data/cluster/text_font_color_64.pkl b/data/cluster/text_font_color_64.pkl
deleted file mode 100644
index 5ed7216..0000000
Binary files a/data/cluster/text_font_color_64.pkl and /dev/null differ
diff --git a/data/cluster/text_font_size_16.pkl b/data/cluster/text_font_size_16.pkl
deleted file mode 100644
index 3cea5bf..0000000
Binary files a/data/cluster/text_font_size_16.pkl and /dev/null differ
diff --git a/data/cluster/text_height_16.pkl b/data/cluster/text_height_16.pkl
deleted file mode 100644
index d00a399..0000000
Binary files a/data/cluster/text_height_16.pkl and /dev/null differ
diff --git a/data/cluster/text_left_64.pkl b/data/cluster/text_left_64.pkl
deleted file mode 100644
index 5e4dc1c..0000000
Binary files a/data/cluster/text_left_64.pkl and /dev/null differ
diff --git a/data/cluster/text_letter_spacing_16.pkl b/data/cluster/text_letter_spacing_16.pkl
deleted file mode 100644
index d0f2e6e..0000000
Binary files a/data/cluster/text_letter_spacing_16.pkl and /dev/null differ
diff --git a/data/cluster/text_line_height_scale_16.pkl b/data/cluster/text_line_height_scale_16.pkl
deleted file mode 100644
index 983bde4..0000000
Binary files a/data/cluster/text_line_height_scale_16.pkl and /dev/null differ
diff --git a/data/cluster/text_top_64.pkl b/data/cluster/text_top_64.pkl
deleted file mode 100644
index e097126..0000000
Binary files a/data/cluster/text_top_64.pkl and /dev/null differ
diff --git a/data/cluster/text_width_16.pkl b/data/cluster/text_width_16.pkl
deleted file mode 100644
index bedb577..0000000
Binary files a/data/cluster/text_width_16.pkl and /dev/null differ
diff --git a/data/config/bart/canvas_embedding_attribute_config.yaml b/data/config/bart/canvas_embedding_attribute_config.yaml
deleted file mode 100644
index e69de29..0000000
diff --git a/data/config/bart/canvas_embedding_flag_config.yaml b/data/config/bart/canvas_embedding_flag_config.yaml
deleted file mode 100644
index 867fe77..0000000
--- a/data/config/bart/canvas_embedding_flag_config.yaml
+++ /dev/null
@@ -1,5 +0,0 @@
-canvas_embedding_flag:
- canvas_bg_img: False
- canvas_bg_img_emb: True
- canvas_aspect_ratio: True
- canvas_text_num: True
diff --git a/data/config/bart/data_config.yaml b/data/config/bart/data_config.yaml
deleted file mode 100644
index 097add5..0000000
--- a/data/config/bart/data_config.yaml
+++ /dev/null
@@ -1,8 +0,0 @@
-data_config:
- font_num: 261
- large_spatio_bin: 64
- small_spatio_bin: 16
- color_bin: 64
- max_text_line_count: 50
- font_emb_type: "label"
- order_type: "rasterscan_asc"
diff --git a/data/config/bart/metainfo.yaml b/data/config/bart/metainfo.yaml
deleted file mode 100644
index 91da1d4..0000000
--- a/data/config/bart/metainfo.yaml
+++ /dev/null
@@ -1,3 +0,0 @@
-meta_info:
- model_name: "bart"
- dataset: "crello"
diff --git a/data/config/bart/model_config.yaml b/data/config/bart/model_config.yaml
deleted file mode 100644
index 6f5a014..0000000
--- a/data/config/bart/model_config.yaml
+++ /dev/null
@@ -1,9 +0,0 @@
-model_config:
- d_model: 256
- num_encoder_layers: 8
- num_decoder_layers: 8
- n_head: 8
- dropout: 0.1
- bert_dim: 768
- clip_dim: 512
- mlp_dim: 1792
diff --git a/data/config/bart/test_config.yaml b/data/config/bart/test_config.yaml
deleted file mode 100644
index 1831047..0000000
--- a/data/config/bart/test_config.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-test_config:
- sampling_param: 0.9
- sampling_param_geometry: 0.1
- sampling_param_semantic: 0.9
diff --git a/data/config/bart/text_element_embedding_attribute_config.yaml b/data/config/bart/text_element_embedding_attribute_config.yaml
deleted file mode 100644
index e69de29..0000000
diff --git a/data/config/bart/text_element_embedding_flag_config.yaml b/data/config/bart/text_element_embedding_flag_config.yaml
deleted file mode 100644
index 119d966..0000000
--- a/data/config/bart/text_element_embedding_flag_config.yaml
+++ /dev/null
@@ -1,19 +0,0 @@
-text_element_embedding_flag:
- text_font: False
- text_font_size: False
- text_font_color: False
- text_angle: False
- text_letter_spacing: False
- text_line_height_scale: False
- text_capitalize: False
- text_line_count: True
- text_char_count: True
- text_height: False
- text_width: False
- text_top: False
- text_left: False
- text_center_y: True
- text_center_x: True
- text_align_type: False
- text_emb: True
- text_local_img_emb: True
diff --git a/data/config/bart/text_element_prediction_attribute_config.yaml b/data/config/bart/text_element_prediction_attribute_config.yaml
deleted file mode 100644
index e69de29..0000000
diff --git a/data/config/bart/text_element_prediction_flag_config.yaml b/data/config/bart/text_element_prediction_flag_config.yaml
deleted file mode 100644
index dda973f..0000000
--- a/data/config/bart/text_element_prediction_flag_config.yaml
+++ /dev/null
@@ -1,9 +0,0 @@
-text_element_prediction_flag:
- text_font: True
- text_font_size: True
- text_font_color: True
- text_angle: True
- text_letter_spacing: True
- text_line_height_scale: True
- text_capitalize: True
- text_align_type: True
diff --git a/data/config/bart/train_config.yaml b/data/config/bart/train_config.yaml
deleted file mode 100644
index 12780fd..0000000
--- a/data/config/bart/train_config.yaml
+++ /dev/null
@@ -1,7 +0,0 @@
-train_config:
- epochs: 31
- save_epoch: 5
- batch_size: 32
- num_worker: 2
- train_only: False
- show_interval: 100
diff --git a/data/font2ttf.pkl b/data/font2ttf.pkl
deleted file mode 100644
index 9903ce6..0000000
Binary files a/data/font2ttf.pkl and /dev/null differ
diff --git a/data/fonttype2fontid_fix.pkl b/data/fonttype2fontid_fix.pkl
deleted file mode 100644
index f2ea1a4..0000000
Binary files a/data/fonttype2fontid_fix.pkl and /dev/null differ
diff --git a/data/svgid2scaleinfo.pkl b/data/svgid2scaleinfo.pkl
deleted file mode 100644
index 68e0d4e..0000000
Binary files a/data/svgid2scaleinfo.pkl and /dev/null differ
diff --git a/notebooks/score.ipynb b/notebooks/score.ipynb
deleted file mode 100644
index 8363651..0000000
--- a/notebooks/score.ipynb
+++ /dev/null
@@ -1,114 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "18.186068109361116\n",
- "59.10895250864158\n",
- "89.43196472690855\n",
- "99.98559492941517\n",
- "29.720623040124977\n",
- "0.381246451385764\n",
- "2.807381701892659\n",
- "0.07559504299962953\n"
- ]
- }
- ],
- "source": [
- "import os\n",
- "from matplotlib import pylab as plt\n",
- "import numpy as np\n",
- "import statistics\n",
- "import math\n",
- "\n",
- "def get_score(prefix, score_index, scalar):\n",
- " tmppath = f\"{prefix}/logs/score.txt\"\n",
- " with open(tmppath) as f:\n",
- " scores = f.readlines()\n",
- " score = scores[score_index].replace('\\n', '').split(\" \")[-2]\n",
- " score = float(score)*scalar\n",
- " return score\n",
- "\n",
-    "def get_structure_score(prefix, score_index, scalar, type=None):\n",
- " gcspath = f\"gs://ailab-ephemeral/s09275/job_gcp_eval/{prefix}/logs/score.txt\"\n",
- " _prefix = prefix.split(\"/\")[-1]\n",
- " tmppath = f\"tmp_data/dl/{_prefix}.txt\"\n",
- " com = f\"gsutil cp {gcspath} {tmppath}\"\n",
- " if os.path.isfile(tmppath) is False:\n",
- " os.system(com)\n",
- " with open(tmppath) as f:\n",
- " scores = f.readlines()\n",
- " score0 = scores[score_index].replace('\\n', '').split(\" \")[-2]\n",
- " score0 = float(score0)*scalar\n",
- " count0 = int(scores[score_index].replace('\\n', '').split(\" \")[-3].split(\":\")[-1])\n",
- " score1 = scores[score_index+1].replace('\\n', '').split(\" \")[-2]\n",
- " score1 = float(score1)*scalar\n",
- " count1 = int(scores[score_index+1].replace('\\n', '').split(\" \")[-3].split(\":\")[-1])\n",
- " score2 = scores[score_index+2].replace('\\n', '').split(\" \")[-2]\n",
- " score2 = float(score2)*scalar\n",
- " count2 = int(scores[score_index+2].replace('\\n', '').split(\" \")[-3].split(\":\")[-1])\n",
- " if type is None:\n",
- " count = float(count0+count1+count2)\n",
- " score = score0*count0/count + score1*count1/count + score2*count2/count\n",
- " #print(count0,count1,count2)\n",
- " elif type==0:\n",
- " score = score0\n",
- " elif type==1:\n",
- " score = score1\n",
- " elif type==2:\n",
- " score = score2\n",
- " return score\n",
- "\n",
- "output_dir = \"tmp\" # please set your output directory here\n",
-    "tarscore2index1 = {\n",
-    "    \"font_score\": (1, 100, \"a\"),\n",
-    "    \"font_color_score\": (25, 1, \"a\"),\n",
-    "    \"font_capitalize_score\": (9, 100, \"a\"),\n",
-    "    \"font_align_score\": (5, 100, \"a\"),\n",
-    "    \"font_size_score\": (13, 1, \"a\"),\n",
-    "    \"font_angle_score\": (19, 1, \"b\"),\n",
-    "    \"font_letter_score\": (16, 1, \"b\"),\n",
-    "    \"font_line_height_score\": (22, 1, \"c\"),\n",
-    "}\n",
-    "for att, (index, k, type) in tarscore2index1.items():\n",
-    "\n",
-    "    font_score = get_score(output_dir, index, k)\n",
-    "    print(font_score)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": ".venv",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.1"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/pyproject.toml b/pyproject.toml
deleted file mode 100644
index 30a21d2..0000000
--- a/pyproject.toml
+++ /dev/null
@@ -1,50 +0,0 @@
-[tool.pysen]
-version = "0.9"
-
-[tool.pysen.lint]
-enable_black = true
-enable_flake8 = true
-enable_isort = true
-enable_mypy = true
-mypy_preset = "strict"
-line_length = 88
-[tool.pysen.lint.source]
-  excludes = [".venv/", ".git/", ".pytest_cache/", ".python-version/", "data", "tmp_data/"]
-[[tool.pysen.lint.mypy_targets]]
- paths = ["."]
-
-[tool.poetry]
-name = "typography-generation"
-version = "0.1.0"
-description = ""
-authors = ["shimoda-uec "]
-readme = "README.md"
-packages = [{include = "typography_generation", from = "src"}]
-
-
-[tool.poetry.dependencies]
-python = "^3.9"
-skia-python = "^87.5"
-einops = "^0.6.1"
-hydra-core = "^1.3.2"
-logzero = "^1.7.0"
-datasets = "^2.12.0"
-torch = "^1.13"
-scikit-learn = "^1.0"
-pytest = "^7.3.1"
-pillow = "9.0.1"
-matplotlib = "3.5"
-transformers = "4.30.2"
-openpyxl = "^3.1.2"
-tensorboard = "^2.14.1"
-gcsfs = "^2023.9.2"
-seam-carving = "^1.1.0"
-
-
-[tool.poetry.group.dev.dependencies]
-jupyter = "^1.0.0"
-notebook = "^6.5.4"
-
-[build-system]
-requires = ["poetry-core"]
-build-backend = "poetry.core.masonry.api"
diff --git a/src/typography_generation/__init__.py b/src/typography_generation/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/src/typography_generation/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/src/typography_generation/__main__.py b/src/typography_generation/__main__.py
deleted file mode 100644
index 6303516..0000000
--- a/src/typography_generation/__main__.py
+++ /dev/null
@@ -1,415 +0,0 @@
-import argparse
-import logging
-import os
-from typing import Any, Dict, Tuple
-
-import logzero
-import torch
-from logzero import logger
-from typography_generation.config.config_args_util import (
- args2add_data_inputs,
- get_global_config,
- get_global_config_input,
- get_model_config_input,
- get_prefix_lists,
- get_sampling_config,
- get_train_config_input,
-)
-from typography_generation.config.default import (
- get_datapreprocess_config,
- get_font_config,
-)
-from typography_generation.io.build_dataset import (
- build_test_dataset,
- build_train_dataset,
-)
-from typography_generation.model.model import create_model
-
-from typography_generation.preprocess.map_features import map_features
-
-from typography_generation.tools.evaluator import Evaluator
-
-from typography_generation.tools.sampler import Sampler
-from typography_generation.tools.structure_preserved_sampler import (
- StructurePreservedSampler,
-)
-from typography_generation.tools.train import Trainer
-
-
-def get_save_dir(job_dir: str) -> str:
- save_dir = os.path.join(job_dir, "logs")
- os.makedirs(save_dir, exist_ok=True)
- return save_dir
-
-
-def make_logfile(job_dir: str, debug: bool = False) -> None:
- if debug is True:
- logzero.loglevel(logging.DEBUG)
- else:
- logzero.loglevel(logging.INFO)
-
- os.makedirs(job_dir, exist_ok=True)
- file_name = f"{job_dir}/log.log"
- logzero.logfile(file_name)
-
-
-def train(args: Any) -> None:
- logger.info("training")
- data_dir = args.datadir
- global_config_input = get_global_config_input(data_dir, args)
- config = get_global_config(**global_config_input)
- gpu = args.gpu
- model_name, model_kwargs = get_model_config_input(config)
-
- logger.info("model creation")
- model = create_model(
- model_name,
- **model_kwargs,
- )
-
- logger.info(f"log file location {args.jobdir}/log.log")
- make_logfile(args.jobdir, args.debug)
- save_dir = get_save_dir(args.jobdir)
- logger.info(f"save_dir {save_dir}")
- prefix_list_object = get_prefix_lists(config)
- font_config = get_font_config(config)
- prediction_config_element = config.text_element_prediction_attribute_config
-
- train_kwargs = get_train_config_input(config, args.debug)
-    logger.info("build trainer")
- dataset, dataset_val = build_train_dataset(
- data_dir,
- prefix_list_object,
- font_config,
- use_extended_dataset=args.use_extended_dataset,
- debug=args.debug,
- )
-
- model_trainer = Trainer(
- model,
- gpu,
- save_dir,
- dataset,
- dataset_val,
- prefix_list_object,
- prediction_config_element,
- **train_kwargs,
- )
- logger.info("start training")
- model_trainer.train_model()
-
-
-def train_eval(args: Any) -> None:
- logger.info("training")
- data_dir = args.datadir
- global_config_input = get_global_config_input(data_dir, args)
- config = get_global_config(**global_config_input)
- gpu = args.gpu
- model_name, model_kwargs = get_model_config_input(config)
-
- logger.info("model creation")
- model = create_model(
- model_name,
- **model_kwargs,
- )
-
- logger.info(f"log file location {args.jobdir}/log.log")
- make_logfile(args.jobdir, args.debug)
- save_dir = get_save_dir(args.jobdir)
- logger.info(f"save_dir {save_dir}")
- prefix_list_object = get_prefix_lists(config)
- font_config = get_font_config(config)
- prediction_config_element = config.text_element_prediction_attribute_config
-
- train_kwargs = get_train_config_input(config, args.debug)
-    logger.info("build trainer")
-
- dataset, dataset_val = build_train_dataset(
- data_dir,
- prefix_list_object,
- font_config,
- use_extended_dataset=args.use_extended_dataset,
- debug=args.debug,
- )
-
- model_trainer = Trainer(
- model,
- gpu,
- save_dir,
- dataset,
- dataset_val,
- prefix_list_object,
- prediction_config_element,
- **train_kwargs,
- )
- logger.info("start training")
- model_trainer.train_model()
-
- dataset = build_test_dataset(
- data_dir,
- prefix_list_object,
- font_config,
- use_extended_dataset=args.use_extended_dataset,
- debug=args.debug,
- )
-
- evaluator = Evaluator(
- model,
- gpu,
- save_dir,
- dataset,
- prefix_list_object,
- debug=args.debug,
- )
- logger.info("start evaluation")
- evaluator.eval_model()
-
-
-# def _train_font_embedding(args: Any) -> None:
-# train_font_embedding(args.datadir, args.jobdir, args.gpu)
-
-
-def loadweight(weight_file: str, gpu: bool, model: Any) -> Any:
-    if weight_file != "":
-        if gpu is False:
-            state_dict = torch.load(weight_file, map_location=torch.device("cpu"))
-        else:
-            state_dict = torch.load(weight_file)
-        model.load_state_dict(state_dict)
-    return model
-
-
-def sample(args: Any) -> None:
- data_dir = args.datadir
- global_config_input = get_global_config_input(data_dir, args)
- config = get_global_config(**global_config_input)
- gpu = args.gpu
- model_name, model_kwargs = get_model_config_input(config)
-
- logger.info("model creation")
- model = create_model(
- model_name,
- **model_kwargs,
- )
- weight = args.weight
- model = loadweight(weight, gpu, model)
-
- logger.info(f"log file location {args.jobdir}/log.log")
- make_logfile(args.jobdir, args.debug)
- save_dir = get_save_dir(args.jobdir)
- logger.info(f"save_dir {save_dir}")
- prefix_list_object = get_prefix_lists(config)
- font_config = get_font_config(config)
- sampling_config = get_sampling_config(config)
-
- dataset = build_test_dataset(
- data_dir,
- prefix_list_object,
- font_config,
- use_extended_dataset=args.use_extended_dataset,
- debug=args.debug,
- )
-
- sampler = Sampler(
- model,
- gpu,
- save_dir,
- dataset,
- prefix_list_object,
- sampling_config,
- debug=args.debug,
- )
- logger.info("start sampling")
- sampler.sample()
-
-
-def structure_preserved_sample(args: Any) -> None:
- data_dir = args.datadir
- global_config_input = get_global_config_input(data_dir, args)
- config = get_global_config(**global_config_input)
- gpu = args.gpu
- model_name, model_kwargs = get_model_config_input(config)
-
- logger.info("model creation")
- model = create_model(
- model_name,
- **model_kwargs,
- )
- weight = args.weight
- model = loadweight(weight, gpu, model)
-
- logger.info(f"log file location {args.jobdir}/log.log")
- make_logfile(args.jobdir, args.debug)
- save_dir = get_save_dir(args.jobdir)
- logger.info(f"save_dir {save_dir}")
- prefix_list_object = get_prefix_lists(config)
- font_config = get_font_config(config)
- sampling_config = get_sampling_config(config)
-
- dataset = build_test_dataset(
- data_dir,
- prefix_list_object,
- font_config,
- use_extended_dataset=args.use_extended_dataset,
- debug=args.debug,
- )
-
- sampler = StructurePreservedSampler(
- model,
- gpu,
- save_dir,
- dataset,
- prefix_list_object,
- sampling_config,
- debug=args.debug,
- )
- logger.info("start sampling")
- sampler.sample()
-
-
-def evaluation_pattern(args: Any, prefix: str, evaluation_class: Any) -> None:
- logger.info(f"{prefix}")
- data_dir = args.datadir
- global_config_input = get_global_config_input(data_dir, args)
- config = get_global_config(**global_config_input)
- gpu = args.gpu
- model_name, model_kwargs = get_model_config_input(config)
-
- logger.info("model creation")
- model = create_model(
- model_name,
- **model_kwargs,
- )
- weight = args.weight
- model = loadweight(weight, gpu, model)
-
- logger.info(f"log file location {args.jobdir}/log.log")
- make_logfile(args.jobdir, args.debug)
- save_dir = get_save_dir(args.jobdir)
- logger.info(f"save_dir {save_dir}")
- prefix_list_object = get_prefix_lists(config)
- font_config = get_font_config(config)
-
- dataset = build_test_dataset(
- data_dir,
- prefix_list_object,
- font_config,
- use_extended_dataset=args.use_extended_dataset,
- debug=args.debug,
- )
- evaluator = evaluation_class(
- model,
- gpu,
- save_dir,
- dataset,
- prefix_list_object,
- dataset_division="test",
- debug=args.debug,
- )
- logger.info("start evaluation")
- evaluator.eval_model()
-
-
-def evaluation(args: Any) -> None:
- evaluation_pattern(args, "evaluation", Evaluator)
-
-
-def _map_features(args: Any) -> None:
-    logger.info("map_features")
- map_features(args.datadir)
-
-
-COMMANDS = {
- "train": train,
-    "train_eval": train_eval,
- "sample": sample,
- "structure_preserved_sample": structure_preserved_sample,
- "evaluation": evaluation,
- "map_features": _map_features,
-}
-
-
-if __name__ == "__main__":
- logger.info("start")
- parser = argparse.ArgumentParser()
- parser.add_argument("command", help="job option")
- parser.add_argument(
- "--configname",
- type=str,
- default="bart",
- help="config option",
- )
- parser.add_argument(
- "--testconfigname",
- type=str,
- default="test_config",
- help="test config option",
- )
- parser.add_argument(
- "--modelconfigname",
- type=str,
- default="model_config",
-        help="model config option",
- )
- parser.add_argument(
- "--canvasembeddingflagconfigname",
- type=str,
- default="canvas_embedding_flag_config",
- help="canvas embedding flag config option",
- )
- parser.add_argument(
- "--elementembeddingflagconfigname",
- type=str,
- default="text_element_embedding_flag_config",
- help="element embedding flag config option",
- )
- parser.add_argument(
- "--elementpredictionflagconfigname",
- type=str,
- default="text_element_prediction_flag_config",
- help="element prediction flag config option",
- )
- parser.add_argument(
- "--datadir",
- type=str,
- default="data",
- help="data location",
- )
- parser.add_argument(
- "--jobdir",
- type=str,
- default=".",
- help="results location",
- )
- parser.add_argument(
- "--job-dir",
- type=str,
- default=".",
- help="dummy",
- )
- parser.add_argument(
- "--weight",
- type=str,
- default="",
- help="weight file location",
- )
- parser.add_argument(
- "--gpu",
- action="store_true",
- help="gpu option",
- )
- parser.add_argument(
- "--use_extended_dataset",
- action="store_true",
- help="dataset option",
- )
- parser.add_argument(
- "--debug",
- action="store_true",
- help="debug option",
- )
- args = parser.parse_args()
- module = COMMANDS[args.command]
- module(args)
diff --git a/src/typography_generation/config/__init__.py b/src/typography_generation/config/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/src/typography_generation/config/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/src/typography_generation/config/attribute_config.py b/src/typography_generation/config/attribute_config.py
deleted file mode 100644
index 8e15c20..0000000
--- a/src/typography_generation/config/attribute_config.py
+++ /dev/null
@@ -1,533 +0,0 @@
-from dataclasses import dataclass
-from typing import Any, Union
-
-
-@dataclass
-class EmbeddingConfig:
- flag: bool = True
- inp_space: int = 256
- input_prefix: str = "prefix"
- emb_layer: Union[str, None] = "nn.Embedding"
- emb_layer_kwargs: Any = None
- specific_build: Union[str, None] = None
- specific_func: Union[str, None] = None
-
-
-@dataclass
-class PredictionConfig:
- flag: bool = True
- out_dim: int = 256
- layer: Union[str, None] = "nn.Linear"
- loss_type: str = "cre"
- loss_weight: float = 1.0
- ignore_label: int = -1
- decode_format: str = "cl"
- att_type: str = "semantic"
-
-
-@dataclass
-class TextElementContextPredictionAttributeConfig:
- text_font: PredictionConfig = PredictionConfig(
- "${text_element_prediction_flag.text_font}",
- "${data_config.font_num}",
- "nn.Linear",
- "cre",
- 1.0,
- -1,
- "cl",
- "semantic",
- )
- text_font_emb: PredictionConfig = PredictionConfig(
- "${text_element_prediction_flag.text_font_emb}",
- "${data_config.font_emb_dim}",
- "nn.Linear",
- "mfc_gan",
- 1.0,
- -10000,
- "emb",
- "semantic",
- )
- text_font_size: PredictionConfig = PredictionConfig(
- "${text_element_prediction_flag.text_font_size}",
- "${data_config.small_spatio_bin}",
- "nn.Linear",
- "cre",
- 1.0,
- -1,
- "cl",
- "geometry",
- )
- text_font_size_raw: PredictionConfig = PredictionConfig(
- "${text_element_prediction_flag.text_font_size_raw}",
- "1",
- "nn.Linear",
- "l1",
- 1.0,
- -1,
- "scalar",
- "geometry",
- )
- text_font_color: PredictionConfig = PredictionConfig(
- "${text_element_prediction_flag.text_font_color}",
- "${data_config.color_bin}",
- "nn.Linear",
- "cre",
- 1.0,
- -1,
- "cl",
- "semantic",
- )
- text_angle: PredictionConfig = PredictionConfig(
- "${text_element_prediction_flag.text_angle}",
- "${data_config.small_spatio_bin}",
- "nn.Linear",
- "cre",
- 1.0,
- -1,
- "cl",
- "geometry",
- )
- text_letter_spacing: PredictionConfig = PredictionConfig(
- "${text_element_prediction_flag.text_letter_spacing}",
- "${data_config.small_spatio_bin}",
- "nn.Linear",
- "cre",
- 1.0,
- -1,
- "cl",
- "geometry",
- )
- text_line_height_scale: PredictionConfig = PredictionConfig(
- "${text_element_prediction_flag.text_line_height_scale}",
- "${data_config.small_spatio_bin}",
- "nn.Linear",
- "cre",
- 1.0,
- -1,
- "cl",
- "geometry",
- )
- text_line_height_size: PredictionConfig = PredictionConfig(
- "${text_element_prediction_flag.text_line_height_size}",
- "${data_config.small_spatio_bin}",
- "nn.Linear",
- "cre",
- 1.0,
- -1,
- "cl",
- "geometry",
- )
- text_capitalize: PredictionConfig = PredictionConfig(
- "${text_element_prediction_flag.text_capitalize}",
- 2,
- "nn.Linear",
- "cre",
- 1.0,
- -1,
- "cl",
- "semantic",
- )
- text_align_type: PredictionConfig = PredictionConfig(
- "${text_element_prediction_flag.text_align_type}",
- 3,
- "nn.Linear",
- "cre",
- 1.0,
- -1,
- "cl",
- "semantic",
- )
- text_center_y: PredictionConfig = PredictionConfig(
- "${text_element_prediction_flag.text_center_y}",
- "${data_config.large_spatio_bin}",
- "nn.Linear",
- "cre",
- 1.0,
- -1,
- "cl",
- "geometry",
- )
- text_center_x: PredictionConfig = PredictionConfig(
- "${text_element_prediction_flag.text_center_x}",
- "${data_config.large_spatio_bin}",
- "nn.Linear",
- "cre",
- 1.0,
- -1,
- "cl",
- "geometry",
- )
- text_height: PredictionConfig = PredictionConfig(
- "${text_element_prediction_flag.text_height}",
- "${data_config.small_spatio_bin}",
- "nn.Linear",
- "cre",
- 1.0,
- -1,
- "cl",
- "geometry",
- )
- text_width: PredictionConfig = PredictionConfig(
- "${text_element_prediction_flag.text_width}",
- "${data_config.large_spatio_bin}",
- "nn.Linear",
- "cre",
- 1.0,
- -1,
- "cl",
- "geometry",
- )
-
-
-@dataclass
-class TextElementContextEmbeddingAttributeConfig:
- text_font: EmbeddingConfig = EmbeddingConfig(
- "${text_element_embedding_flag.text_font}",
- "${data_config.font_num}",
- "text_font",
- "nn.Embedding",
- {
- "num_embeddings": "${data_config.font_num}",
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- text_font_size: EmbeddingConfig = EmbeddingConfig(
- "${text_element_embedding_flag.text_font_size}",
- "${data_config.small_spatio_bin}",
- "text_font_size",
- "nn.Embedding",
- {
- "num_embeddings": "${data_config.small_spatio_bin}",
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- text_font_size_raw: EmbeddingConfig = EmbeddingConfig(
- "${text_element_embedding_flag.text_font_size_raw}",
- "${data_config.small_spatio_bin}",
- "text_font_size",
- "nn.Embedding",
- {
- "num_embeddings": "${data_config.small_spatio_bin}",
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- text_font_color: EmbeddingConfig = EmbeddingConfig(
- "${text_element_embedding_flag.text_font_color}",
- "${data_config.color_bin}",
- "text_font_color",
- "nn.Embedding",
- {
- "num_embeddings": "${data_config.color_bin}",
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- text_line_count: EmbeddingConfig = EmbeddingConfig(
- "${text_element_embedding_flag.text_line_count}",
- "${data_config.max_text_line_count}",
- "text_line_count",
- "nn.Embedding",
- {
- "num_embeddings": "${data_config.max_text_line_count}",
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- text_char_count: EmbeddingConfig = EmbeddingConfig(
- "${text_element_embedding_flag.text_char_count}",
- "${data_config.max_text_char_count}",
- "text_char_count",
- "nn.Embedding",
- {
- "num_embeddings": "${data_config.max_text_char_count}",
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- text_height: EmbeddingConfig = EmbeddingConfig(
- "${text_element_embedding_flag.text_height}",
- "${data_config.large_spatio_bin}",
- "text_height",
- "nn.Embedding",
- {
- "num_embeddings": "${data_config.large_spatio_bin}",
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- text_width: EmbeddingConfig = EmbeddingConfig(
- "${text_element_embedding_flag.text_width}",
- "${data_config.large_spatio_bin}",
- "text_width",
- "nn.Embedding",
- {
- "num_embeddings": "${data_config.large_spatio_bin}",
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- text_top: EmbeddingConfig = EmbeddingConfig(
- "${text_element_embedding_flag.text_top}",
- "${data_config.large_spatio_bin}",
- "text_top",
- "nn.Embedding",
- {
- "num_embeddings": "${data_config.large_spatio_bin}",
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- text_left: EmbeddingConfig = EmbeddingConfig(
- "${text_element_embedding_flag.text_left}",
- "${data_config.large_spatio_bin}",
- "text_left",
- "nn.Embedding",
- {
- "num_embeddings": "${data_config.large_spatio_bin}",
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- text_center_y: EmbeddingConfig = EmbeddingConfig(
- "${text_element_embedding_flag.text_center_y}",
- "${data_config.large_spatio_bin}",
- "text_center_y",
- "nn.Embedding",
- {
- "num_embeddings": "${data_config.large_spatio_bin}",
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- text_center_x: EmbeddingConfig = EmbeddingConfig(
- "${text_element_embedding_flag.text_center_x}",
- "${data_config.large_spatio_bin}",
- "text_center_x",
- "nn.Embedding",
- {
- "num_embeddings": "${data_config.large_spatio_bin}",
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- text_align_type: EmbeddingConfig = EmbeddingConfig(
- "${text_element_embedding_flag.text_align_type}",
- "3",
- "text_align_type",
- "nn.Embedding",
- {
- "num_embeddings": "3",
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- text_angle: EmbeddingConfig = EmbeddingConfig(
- "${text_element_embedding_flag.text_angle}",
- "${data_config.small_spatio_bin}",
- "text_angle",
- "nn.Embedding",
- {
- "num_embeddings": "${data_config.small_spatio_bin}",
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- text_letter_spacing: EmbeddingConfig = EmbeddingConfig(
- "${text_element_embedding_flag.text_letter_spacing}",
- "${data_config.small_spatio_bin}",
- "text_letter_spacing",
- "nn.Embedding",
- {
- "num_embeddings": "${data_config.small_spatio_bin}",
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- text_line_height_scale: EmbeddingConfig = EmbeddingConfig(
- "${text_element_embedding_flag.text_line_height_scale}",
- "${data_config.small_spatio_bin}",
- "text_line_height_scale",
- "nn.Embedding",
- {
- "num_embeddings": "${data_config.small_spatio_bin}",
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- text_line_height_size: EmbeddingConfig = EmbeddingConfig(
- "${text_element_embedding_flag.text_line_height_size}",
- "${data_config.small_spatio_bin}",
- "text_line_height_size",
- "nn.Embedding",
- {
- "num_embeddings": "${data_config.small_spatio_bin}",
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
-
- text_capitalize: EmbeddingConfig = EmbeddingConfig(
- "${text_element_embedding_flag.text_capitalize}",
- "2",
- "text_capitalize",
- "nn.Embedding",
- {
- "num_embeddings": "2",
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- text_emb: EmbeddingConfig = EmbeddingConfig(
- "${text_element_embedding_flag.text_emb}",
- "${model_config.clip_dim}",
- "text_emb",
- "nn.Linear",
- {
- "in_features": "${model_config.clip_dim}",
- "out_features": "${model_config.d_model}",
- },
- None,
- None,
- )
- text_local_img_emb: EmbeddingConfig = EmbeddingConfig(
- "${text_element_embedding_flag.text_local_img_emb}",
- "${model_config.clip_dim}",
- "text_local_img_emb",
- "nn.Linear",
- {
- "in_features": "${model_config.clip_dim}",
- "out_features": "${model_config.d_model}",
- },
- None,
- None,
- )
-
-
-@dataclass
-class CanvasContextEmbeddingAttributeConfig:
- canvas_bg_img: EmbeddingConfig = EmbeddingConfig(
- "${canvas_embedding_flag.canvas_bg_img}",
- 2048,
- "canvas_bg_img",
- None,
- None,
- "build_resnet_feat_extractor",
- "get_feat",
- )
- canvas_aspect_ratio: EmbeddingConfig = EmbeddingConfig(
- "${canvas_embedding_flag.canvas_aspect_ratio}",
- "${data_config.small_spatio_bin}",
- "canvas_aspect_ratio",
- "nn.Embedding",
- {
- "num_embeddings": "${data_config.small_spatio_bin}",
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- canvas_text_num: EmbeddingConfig = EmbeddingConfig(
- "${canvas_embedding_flag.canvas_text_num}",
- 50,
- "canvas_text_num",
- "nn.Embedding",
- {
- "num_embeddings": 50,
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- canvas_group: EmbeddingConfig = EmbeddingConfig(
- "${canvas_embedding_flag.canvas_group}",
- 6,
- "canvas_group",
- "nn.Embedding",
- {
- "num_embeddings": 6,
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- canvas_format: EmbeddingConfig = EmbeddingConfig(
- "${canvas_embedding_flag.canvas_format}",
- 67,
- "canvas_format",
- "nn.Embedding",
- {
- "num_embeddings": 67,
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- canvas_category: EmbeddingConfig = EmbeddingConfig(
- "${canvas_embedding_flag.canvas_category}",
- 23,
- "canvas_category",
- "nn.Embedding",
- {
- "num_embeddings": 23,
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- canvas_height: EmbeddingConfig = EmbeddingConfig(
- "${canvas_embedding_flag.canvas_height}",
- 46,
- "canvas_height",
- "nn.Embedding",
- {
- "num_embeddings": 46,
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- canvas_width: EmbeddingConfig = EmbeddingConfig(
- "${canvas_embedding_flag.canvas_width}",
- 41,
- "canvas_width",
- "nn.Embedding",
- {
- "num_embeddings": 41,
- "embedding_dim": "${model_config.d_model}",
- },
- None,
- None,
- )
- canvas_bg_img_emb: EmbeddingConfig = EmbeddingConfig(
- "${canvas_embedding_flag.canvas_bg_img_emb}",
- "${model_config.clip_dim}",
- "canvas_bg_img_emb",
- "nn.Linear",
- {
- "in_features": "${model_config.clip_dim}",
- "out_features": "${model_config.d_model}",
- },
- None,
- "canvas_bg_img_emb_layer",
- )
diff --git a/src/typography_generation/config/base_config_object.py b/src/typography_generation/config/base_config_object.py
deleted file mode 100644
index ece0512..0000000
--- a/src/typography_generation/config/base_config_object.py
+++ /dev/null
@@ -1,147 +0,0 @@
-from dataclasses import dataclass
-from typography_generation.config.attribute_config import (
- CanvasContextEmbeddingAttributeConfig,
- TextElementContextEmbeddingAttributeConfig,
- TextElementContextPredictionAttributeConfig,
-)
-
-
-@dataclass
-class MetaInfo:
- model_name: str = "bart"
- dataset: str = "crello"
- data_dir: str = "crello"
-
-
-@dataclass
-class DataConfig:
- font_num: int = 288
- large_spatio_bin: int = 64
- small_spatio_bin: int = 16
- color_bin: int = 64
- max_text_char_count: int = 50
- max_text_line_count: int = 50
- font_emb_type: str = "label"
- font_emb_weight: float = 1.0
- font_emb_dim: int = 40
- font_emb_name: str = "mfc"
- order_type: str = "raster_scan_order"
- seq_length: int = 50
-
-
-@dataclass
-class ModelConfig:
- d_model: int = 256
- num_encoder_layers: int = 4
- num_decoder_layers: int = 4
- n_head: int = 8
- dropout: float = 0.1
- bert_dim: int = 768
- clip_dim: int = 512
- mlp_dim: int = 3584
- mfc_dim: int = 3584
- bypass: bool = True
- seq_length: int = 50
- std_ratio: float = 1.0
-
-
-@dataclass
-class TrainConfig:
- epochs: int = 31
- save_epoch: int = 5
- batch_size: int = 32
- num_worker: int = 2
- train_only: bool = False
- show_interval: int = 100
- learning_rate: float = 0.0002
- weight_decay: float = 0.01
- optimizer: str = "adam"
-
-
-@dataclass
-class TestConfig:
- autoregressive_prediction: bool = False
- sampling_mode: str = "topp"
- sampling_param: float = 0
- sampling_param_geometry: float = 0
- sampling_param_semantic: float = 0
- sampling_num: int = 10
-
-
-@dataclass
-class TextElementEmbeddingFlag:
- text_font: bool = True
- text_font_size: bool = True
- text_font_size_raw: bool = False
- text_font_color: bool = True
- text_line_count: bool = True
- text_char_count: bool = True
- text_height: bool = False
- text_width: bool = False
- text_top: bool = False
- text_left: bool = False
- text_center_y: bool = True
- text_center_x: bool = True
- text_align_type: bool = False
- text_angle: bool = False
- text_letter_spacing: bool = False
- text_line_height_scale: bool = False
- text_line_height_size: bool = False
- text_capitalize: bool = False
- text_emb: bool = True
- text_local_img_emb: bool = True
-
-
-@dataclass
-class CanvasEmbeddingFlag:
- canvas_bg_img: bool = False
- canvas_bg_img_emb: bool = True
- canvas_aspect_ratio: bool = True
- canvas_text_num: bool = True
- canvas_group: bool = False
- canvas_format: bool = False
- canvas_category: bool = False
- canvas_width: bool = False
- canvas_height: bool = False
-
-
-@dataclass
-class TextElementPredictionTargetFlag:
- text_font: bool = True
- text_font_emb: bool = False
- text_font_size: bool = True
- text_font_size_raw: bool = False
- text_font_color: bool = True
- text_angle: bool = True
- text_letter_spacing: bool = True
- text_line_height_scale: bool = True
- text_line_height_size: bool = False
- text_capitalize: bool = True
- text_align_type: bool = True
- text_center_y: bool = False
- text_center_x: bool = False
- text_height: bool = False
- text_width: bool = False
-
-
-@dataclass
-class GlobalConfig:
- meta_info: MetaInfo = MetaInfo()
- model_config: ModelConfig = ModelConfig()
- data_config: DataConfig = DataConfig()
- train_config: TrainConfig = TrainConfig()
- test_config: TestConfig = TestConfig()
- text_element_embedding_flag: TextElementEmbeddingFlag = TextElementEmbeddingFlag()
- canvas_embedding_flag: CanvasEmbeddingFlag = CanvasEmbeddingFlag()
- text_element_prediction_flag: TextElementPredictionTargetFlag = (
- TextElementPredictionTargetFlag()
- )
- canvas_embedding_attribute_config: CanvasContextEmbeddingAttributeConfig = (
- CanvasContextEmbeddingAttributeConfig()
- )
- text_element_embedding_attribute_config: TextElementContextEmbeddingAttributeConfig = (
- TextElementContextEmbeddingAttributeConfig()
- )
- text_element_prediction_attribute_config: TextElementContextPredictionAttributeConfig = (
- TextElementContextPredictionAttributeConfig()
- )
diff --git a/src/typography_generation/config/config_args_util.py b/src/typography_generation/config/config_args_util.py
deleted file mode 100644
index fbe8adb..0000000
--- a/src/typography_generation/config/config_args_util.py
+++ /dev/null
@@ -1,199 +0,0 @@
-from typing import Any, Dict, Tuple
-
-import omegaconf
-from logzero import logger
-
-from typography_generation.config.base_config_object import GlobalConfig
-from typography_generation.config.default import (
- build_config,
- get_bindata,
- get_datapreprocess_config,
- get_font_config,
- get_model_input_prefix_list,
- get_target_prefix_list,
-)
-from typography_generation.io.data_object import (
- BinsData,
- DataPreprocessConfig,
- FontConfig,
- PrefixListObject,
- SamplingConfig,
-)
-
-
-def args2add_data_inputs(args: Any) -> Dict:
- add_data_inputs = {}
- add_data_inputs["data_dir"] = args.datadir
- add_data_inputs["global_config_input"] = get_global_config_input(args.datadir, args)
- add_data_inputs["gpu"] = args.gpu
- add_data_inputs["weight"] = args.weight
- add_data_inputs["jobdir"] = args.jobdir
- add_data_inputs["debug"] = args.debug
- return add_data_inputs
-
-
-def get_conf(yaml_file: str) -> omegaconf.DictConfig:
- logger.info(f"load {yaml_file}")
- conf = omegaconf.OmegaConf.load(yaml_file)
- return conf
-
-
-def get_global_config(
- data_dir: str,
- model_name: str,
- test_config_name: str = "test_config",
- model_config_name: str = "model_config",
- elementembeddingflag_config_name: str = "text_element_embedding_flag_config",
- elementpredictionflag_config_name: str = "text_element_prediction_flag_config",
- canvasembeddingflag_config_name: str = "canvas_embedding_flag_config",
- elementembeddingatt_config_name: str = "text_element_embedding_attribute_config",
- elementpredictionatt_config_name: str = "text_element_prediction_attribute_config",
- canvasembeddingatt_config_name: str = "canvas_embedding_attribute_config",
-) -> GlobalConfig:
- metainfo_conf = get_conf(f"{data_dir}/config/{model_name}/metainfo.yaml")
- data_conf = get_conf(f"{data_dir}/config/{model_name}/data_config.yaml")
- model_conf = get_conf(f"{data_dir}/config/{model_name}/{model_config_name}.yaml")
- train_conf = get_conf(f"{data_dir}/config/{model_name}/train_config.yaml")
- test_conf = get_conf(f"{data_dir}/config/{model_name}/{test_config_name}.yaml")
- text_element_emb_flag_conf = get_conf(
- f"{data_dir}/config/{model_name}/{elementembeddingflag_config_name}.yaml"
- )
- canvas_emb_flag_conf = get_conf(
- f"{data_dir}/config/{model_name}/{canvasembeddingflag_config_name}.yaml"
- )
- text_element_pred_flag_conf = get_conf(
- f"{data_dir}/config/{model_name}/{elementpredictionflag_config_name}.yaml"
- )
- text_element_emb_attribute_conf = get_conf(
- f"{data_dir}/config/{model_name}/{elementembeddingatt_config_name}.yaml"
- )
- canvas_emb_attribute_conf = get_conf(
- f"{data_dir}/config/{model_name}/{canvasembeddingatt_config_name}.yaml"
- )
- text_element_pred_attribute_conf = get_conf(
- f"{data_dir}/config/{model_name}/{elementpredictionatt_config_name}.yaml"
- )
- config = build_config(
- metainfo_conf,
- data_conf,
- model_conf,
- train_conf,
- test_conf,
- text_element_emb_flag_conf,
- canvas_emb_flag_conf,
- text_element_pred_flag_conf,
- text_element_emb_attribute_conf,
- canvas_emb_attribute_conf,
- text_element_pred_attribute_conf,
- )
- return config
-
-
-def get_data_config(
- config: GlobalConfig,
-) -> Tuple[BinsData, FontConfig, DataPreprocessConfig]:
- bin_data = get_bindata(config)
- font_config = get_font_config(config)
- data_preprocess_config = get_datapreprocess_config(config)
- return bin_data, font_config, data_preprocess_config
-
-
-def get_sampling_config(
- config: GlobalConfig,
-) -> SamplingConfig:
- sampling_param = config.test_config.sampling_param
- sampling_param_geometry = config.test_config.sampling_param_geometry
- sampling_param_semantic = config.test_config.sampling_param_semantic
- sampling_num = config.test_config.sampling_num
- return SamplingConfig(
- sampling_param, sampling_param_geometry, sampling_param_semantic, sampling_num
- )
-
-
-def get_global_config_input(data_dir: str, args: Any) -> Dict:
- global_config_input = {}
- global_config_input["data_dir"] = data_dir
- global_config_input["model_name"] = args.configname
- global_config_input["test_config_name"] = args.testconfigname
- global_config_input["model_config_name"] = args.modelconfigname
- global_config_input[
- "elementembeddingflag_config_name"
- ] = args.elementembeddingflagconfigname
- global_config_input[
- "canvasembeddingflag_config_name"
- ] = args.canvasembeddingflagconfigname
- global_config_input[
- "elementpredictionflag_config_name"
- ] = args.elementpredictionflagconfigname
-
- return global_config_input
-
-
-def get_model_config_input(config: GlobalConfig) -> Tuple[str, Dict]:
- model_kwargs = {}
-
- model_kwargs["prefix_list_element"] = get_target_prefix_list(
- config.text_element_embedding_attribute_config
- )
- model_kwargs["prefix_list_canvas"] = get_target_prefix_list(
- config.canvas_embedding_attribute_config
- )
- model_kwargs["prefix_list_target"] = get_target_prefix_list(
- config.text_element_prediction_attribute_config
- )
- model_kwargs[
- "embedding_config_element"
- ] = config.text_element_embedding_attribute_config
- model_kwargs["embedding_config_canvas"] = config.canvas_embedding_attribute_config
- model_kwargs[
- "prediction_config_element"
- ] = config.text_element_prediction_attribute_config
- model_kwargs["d_model"] = config.model_config.d_model
- model_kwargs["n_head"] = config.model_config.n_head
- model_kwargs["dropout"] = config.model_config.dropout
- model_kwargs["num_encoder_layers"] = config.model_config.num_encoder_layers
- model_kwargs["num_decoder_layers"] = config.model_config.num_decoder_layers
- model_kwargs["seq_length"] = config.model_config.seq_length
- model_kwargs["std_ratio"] = config.model_config.std_ratio
- model_kwargs["bypass"] = config.model_config.bypass
-
- model_name = config.meta_info.model_name
- return model_name, model_kwargs
-
-
-def get_train_config_input(config: GlobalConfig, debug: bool) -> Dict:
- train_kwargs = {}
- if debug is True:
- train_kwargs["epochs"] = 1
- train_kwargs["batch_size"] = 2
- else:
- train_kwargs["epochs"] = config.train_config.epochs
- train_kwargs["batch_size"] = config.train_config.batch_size
- train_kwargs["save_epoch"] = config.train_config.save_epoch
- train_kwargs["num_worker"] = config.train_config.num_worker
- train_kwargs["learning_rate"] = config.train_config.learning_rate
- train_kwargs["show_interval"] = config.train_config.show_interval
- train_kwargs["optimizer_option"] = config.train_config.optimizer
- train_kwargs["weight_decay"] = config.train_config.weight_decay
- train_kwargs["debug"] = debug
- return train_kwargs
-
-
-def get_prefix_lists(config: GlobalConfig) -> PrefixListObject:
- prefix_list_textelement = get_target_prefix_list(
- config.text_element_embedding_attribute_config
- )
- prefix_list_canvas = get_target_prefix_list(
- config.canvas_embedding_attribute_config
- )
- prefix_list_model_input = get_model_input_prefix_list(config)
- prefix_list_target = get_target_prefix_list(
- config.text_element_prediction_attribute_config
- )
- prefix_list_object = PrefixListObject(
- prefix_list_textelement,
- prefix_list_canvas,
- prefix_list_model_input,
- prefix_list_target,
- )
- return prefix_list_object
diff --git a/src/typography_generation/config/default.py b/src/typography_generation/config/default.py
deleted file mode 100644
index c10b5f1..0000000
--- a/src/typography_generation/config/default.py
+++ /dev/null
@@ -1,128 +0,0 @@
-from typing import Any, List
-from logzero import logger
-from omegaconf import OmegaConf
-from typography_generation.config.attribute_config import (
- CanvasContextEmbeddingAttributeConfig,
- TextElementContextEmbeddingAttributeConfig,
- TextElementContextPredictionAttributeConfig,
-)
-from typography_generation.config.base_config_object import (
- CanvasEmbeddingFlag,
- DataConfig,
- GlobalConfig,
- MetaInfo,
- ModelConfig,
- TestConfig,
- TextElementEmbeddingFlag,
- TextElementPredictionTargetFlag,
- TrainConfig,
-)
-from typography_generation.io.data_object import (
- BinsData,
- DataPreprocessConfig,
- FontConfig,
-)
-
-
-def get_bindata(config: GlobalConfig) -> BinsData:
- bin_data = BinsData(
- config.data_config.color_bin,
- config.data_config.small_spatio_bin,
- config.data_config.large_spatio_bin,
- )
- return bin_data
-
-
-def get_font_config(config: GlobalConfig) -> FontConfig:
- font_config = FontConfig(
- config.data_config.font_num,
- config.data_config.font_emb_type,
- config.data_config.font_emb_weight,
- config.data_config.font_emb_dim,
- config.data_config.font_emb_name,
- )
- return font_config
-
-
-def get_datapreprocess_config(config: GlobalConfig) -> DataPreprocessConfig:
- datapreprocess_config = DataPreprocessConfig(
- config.data_config.order_type,
- config.data_config.seq_length,
- )
- return datapreprocess_config
-
-
-def get_target_prefix_list(tar_cls: Any) -> List:
- all_text_prefix_list = dir(tar_cls)
- target_prefix_list = []
- for prefix in all_text_prefix_list:
- elm = getattr(tar_cls, prefix)
- if elm.flag is True:
- target_prefix_list.append(prefix)
- return target_prefix_list
-
-
-def get_model_input_prefix_list(conf: GlobalConfig) -> List:
- input_prefix_list = []
- input_prefix_list += get_target_prefix_list(
- conf.text_element_embedding_attribute_config
- )
- input_prefix_list += get_target_prefix_list(conf.canvas_embedding_attribute_config)
- input_prefix_list += get_target_prefix_list(
- conf.text_element_prediction_attribute_config
- )
- if "canvas_text_num" in input_prefix_list:
- pass
- else:
- input_prefix_list.append("canvas_text_num")
- return list(set(input_prefix_list))
-
-
-def plusone_num_embeddings(config: GlobalConfig) -> None:
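-    # Note: every nn.Embedding vocabulary is grown by one slot here; presumably
-    # this reserves an extra index (e.g., for a padding/unknown token) on top
-    # of the raw label range used by the get_*/denorm_* shifts.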
- prefix_lists = dir(config.text_element_embedding_attribute_config)
- for prefix in prefix_lists:
- elm = getattr(config.text_element_embedding_attribute_config, f"{prefix}")
- if elm.emb_layer == "nn.Embedding":
- elm.emb_layer_kwargs["num_embeddings"] = (
- int(elm.emb_layer_kwargs["num_embeddings"]) + 1
- )
-
-
-def show_class_attributes(_class: Any) -> None:
- class_dict = _class.__dict__["_content"]
- logger.info(f"{class_dict}")
-
-
-def build_config(
- metainfo_conf: MetaInfo,
- data_conf: DataConfig,
- model_conf: ModelConfig,
- train_conf: TrainConfig,
- test_conf: TestConfig,
- text_element_emb_flag_conf: TextElementEmbeddingFlag,
- canvas_emb_flag_conf: CanvasEmbeddingFlag,
- text_element_pred_flag_conf: TextElementPredictionTargetFlag,
- text_element_emb_attribute_conf: TextElementContextEmbeddingAttributeConfig,
- canvas_emb_attribute_conf: CanvasContextEmbeddingAttributeConfig,
- text_element_pred_attribute_conf: TextElementContextPredictionAttributeConfig,
-) -> GlobalConfig:
- conf = OmegaConf.structured(GlobalConfig)
- show_class_attributes(text_element_emb_flag_conf)
- show_class_attributes(canvas_emb_flag_conf)
- show_class_attributes(text_element_pred_flag_conf)
- conf = OmegaConf.merge(
- conf,
- metainfo_conf,
- data_conf,
- model_conf,
- train_conf,
- test_conf,
- text_element_emb_flag_conf,
- canvas_emb_flag_conf,
- text_element_pred_flag_conf,
- text_element_emb_attribute_conf,
- canvas_emb_attribute_conf,
- text_element_pred_attribute_conf,
- )
- plusone_num_embeddings(conf)
- return conf
diff --git a/src/typography_generation/io/__init__.py b/src/typography_generation/io/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/src/typography_generation/io/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/src/typography_generation/io/build_dataset.py b/src/typography_generation/io/build_dataset.py
deleted file mode 100644
index 14bacd3..0000000
--- a/src/typography_generation/io/build_dataset.py
+++ /dev/null
@@ -1,85 +0,0 @@
-from typing import Tuple
-from typography_generation.io.data_loader import CrelloLoader
-from typography_generation.io.data_object import (
- FontConfig,
- PrefixListObject,
-)
-from typography_generation.tools.tokenizer import Tokenizer
-import datasets
-from logzero import logger
-import os
-
-
-def build_train_dataset(
- data_dir: str,
- prefix_list_object: PrefixListObject,
- font_config: FontConfig,
- use_extended_dataset: bool = True,
- dataset_name: str = "crello",
- debug: bool = False,
-) -> Tuple[CrelloLoader, CrelloLoader]:
- tokenizer = Tokenizer(data_dir)
-
- if dataset_name == "crello":
- logger.info("load hugging dataset start")
- if use_extended_dataset is True:
- _dataset = datasets.load_from_disk(
- os.path.join(data_dir, "crello_map_features")
- )
- else:
- _dataset = datasets.load_dataset("cyberagent/crello", revision="3.1")
- logger.info("load hugging dataset done")
- dataset = CrelloLoader(
- data_dir,
- tokenizer,
- _dataset["train"],
- prefix_list_object,
- font_config,
- use_extended_dataset=use_extended_dataset,
- debug=debug,
- )
- dataset_val = CrelloLoader(
- data_dir,
- tokenizer,
- _dataset["validation"],
- prefix_list_object,
- font_config,
- use_extended_dataset=use_extended_dataset,
- debug=debug,
- )
- else:
- raise NotImplementedError()
- return dataset, dataset_val
-
-
-def build_test_dataset(
- data_dir: str,
- prefix_list_object: PrefixListObject,
- font_config: FontConfig,
- use_extended_dataset: bool = True,
- dataset_name: str = "crello",
- debug: bool = False,
-) -> CrelloLoader:
- tokenizer = Tokenizer(data_dir)
-
- if dataset_name == "crello":
- logger.info("load hugging dataset start")
- if use_extended_dataset is True:
- _dataset = datasets.load_from_disk(
- os.path.join(data_dir, "crello_map_features")
- )
- else:
- _dataset = datasets.load_dataset("cyberagent/crello", revision="3.1")
- logger.info("load hugging dataset done")
- dataset = CrelloLoader(
- data_dir,
- tokenizer,
- _dataset["test"],
- prefix_list_object,
- font_config,
- use_extended_dataset=use_extended_dataset,
- debug=debug,
- )
- else:
- raise NotImplementedError()
- return dataset
diff --git a/src/typography_generation/io/crello_util.py b/src/typography_generation/io/crello_util.py
deleted file mode 100644
index 85f1e7d..0000000
--- a/src/typography_generation/io/crello_util.py
+++ /dev/null
@@ -1,490 +0,0 @@
-import math
-import os
-import pickle
-from typing import Any, Dict, List, Tuple
-from einops import repeat
-import skia
-import numpy as np
-import PIL
-from PIL import Image
-import torch
-from typography_generation.io.data_object import FontConfig
-from typography_generation.tools.tokenizer import Tokenizer
-
-from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
-
-from typography_generation.visualization.renderer_util import (
- get_skia_font,
- get_text_actual_width,
- get_texts,
-)
-
-fontmgr = skia.FontMgr()
-
-
-class CrelloProcessor:
- def __init__(
- self,
- data_dir: str,
- tokenizer: Tokenizer,
- dataset: Any,
- font_config: FontConfig,
- use_extended_dataset: bool = True,
- seq_length: int = 50,
- ) -> None:
- self.data_dir = data_dir
- self.tokenizer = tokenizer
- self.font_config = font_config
- self.dataset = dataset
- self.seq_length = seq_length
- if font_config is not None:
- fn = os.path.join(
- self.data_dir, "font_emb", f"{font_config.font_emb_name}.pkl"
- )
- self.fontid2fontemb = pickle.load(open(fn, "rb"))
- self.use_extended_dataset = use_extended_dataset
- if not use_extended_dataset:
- fn = os.path.join(data_dir, "font2ttf.pkl")
- _font2ttf = pickle.load(open(fn, "rb"))
- font2ttf = {}
-            for key, path in _font2ttf.items():
-                tmp = path.split("/data/dataset/crello/")[1]
-                font2ttf[key] = os.path.join(data_dir, tmp)
- self.font2ttf = font2ttf
-
- fn = os.path.join(data_dir, "svgid2scaleinfo.pkl")
- self.svgid2scaleinfo = pickle.load(open(fn, "rb"))
- self.fontlabel2fontname = self.dataset.features["font"].feature.int2str
-
- self.processor = CLIPProcessor.from_pretrained(
- "openai/clip-vit-base-patch32"
- )
- self.text_tokenizer = CLIPTokenizer.from_pretrained(
- "openai/clip-vit-base-patch32"
- )
- self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
- # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- self.device = torch.device("cpu")
- self.model.to(self.device)
-
- def get_canvas_text_num(
- self, element_data: dict, **kwargs: Any
-    ) -> int:
- text_num = 0
- for k in range(len(element_data["text"])):
- if element_data["text"][k] == "":
- pass
- else:
- text_num += 1
- text_num = min(text_num, self.seq_length)
- return text_num
-
- def get_canvas_text_ids(
- self, element_data: dict, **kwargs: Any
-    ) -> List:
- text_ids = []
- for k in range(len(element_data["text"])):
- if element_data["text"][k] == "":
- pass
- else:
- text_ids.append(k)
- return text_ids
-
- def get_canvas_bg_size(self, element_data: dict) -> Tuple[int, int]:
- return element_data["canvas_bg_size"]
-
-    def get_scale_box(self, element_data: dict) -> Tuple:
- if self.use_extended_dataset:
- return tuple(element_data["scale_box"])
- else:
- svgid = element_data["id"]
- return self.svgid2scaleinfo[svgid]
-
- def get_text_font(self, element_data: dict, text_index: int, **kwargs: Any) -> int:
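-        # Dataset font ids appear to be 1-based (hence the -1 shift here and
-        # the +1 in denorm_text_font).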
- font = element_data["font"][text_index] - 1
- return int(font)
-
- def denorm_text_font(self, val: int, **kwargs: Any) -> int:
- val += 1
- return val
-
- def get_text_font_emb(
- self,
- element_data: dict,
- text_index: int,
- **kwargs: Any,
- ) -> np.array:
- font = element_data["font"][text_index]
- font_emb = self.fontid2fontemb[font - 1]
- return font_emb
-
- def raw2token_text_font_emb(
- self,
- val: np.array,
- **kwargs: Any,
- ) -> int:
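-        # Map a predicted font embedding back to a font id via nearest
-        # neighbour (squared L2 distance) over all known font embeddings.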
- vec = repeat(val, "c -> n c", n=len(self.fontid2fontemb))
- diff = (vec - self.fontid2fontemb) ** 2
- diff = diff.sum(1)
- font = int(np.argsort(diff)[0])
- return font
-
- def denorm_text_font_emb(self, val: int, **kwargs: Any) -> int:
- val += 1
- return val
-
- def get_skia_font(
- self, element_data: dict, text_index: int, scaleinfo: Tuple, **kwargs: Any
-    ) -> skia.Font:
- font_label = element_data["font"][text_index]
- font_name = self.fontlabel2fontname(int(font_label))
- font_name = font_name.replace(" ", "_")
- scale_h, _ = scaleinfo
- font_skia, _ = get_skia_font(
- self.font2ttf,
- fontmgr,
- element_data,
- text_index,
- font_name,
- scale_h,
- )
- return font_skia
-
- def get_text_font_size(
- self,
- element_data: dict,
- text_index: int,
- img_size: Tuple,
- scaleinfo: Tuple,
- **kwargs: Any,
- ) -> float:
- fs = element_data["font_size"][text_index]
- scale_h, _ = scaleinfo
- h, _ = img_size
- val = fs * scale_h / h
- return float(val)
-
- def denorm_text_font_size(
- self, val: float, img_height: int, scale_h: float, **kwargs: Any
- ) -> float:
- val = val * img_height / scale_h
- return val
-
- def get_text_font_size_raw(
- self,
- element_data: dict,
- text_index: int,
- img_size: Tuple,
- scaleinfo: Tuple,
- **kwargs: Any,
- ) -> float:
- val = self.get_text_font_size(element_data, text_index, img_size, scaleinfo)
- return val
-
- def raw2token_text_font_size_raw(
- self,
- val: float,
- **kwargs: Any,
- ) -> int:
- val = self.tokenizer.tokenize("text_font_size", float(val))
- return val
-
- def denorm_text_font_size_raw(
- self, val: float, img_height: int, scale_h: float, **kwargs: Any
- ) -> float:
- val = val * img_height / scale_h
- return val
-
- def get_text_font_color(
- self, element_data: dict, text_index: int, **kwargs: Any
- ) -> Tuple[int, int, int]:
- B, G, R = element_data["color"][text_index]
- return (R, G, B)
-
- def get_text_height(
- self, element_data: dict, text_index: int, **kwargs: Any
- ) -> float:
- height = element_data["height"][text_index]
- return float(height)
-
- def denorm_text_height(self, val: float, img_height: int, **kwargs: Any) -> float:
- val = val * img_height
- return val
-
- def get_text_width(
- self, element_data: dict, text_index: int, **kwargs: Any
- ) -> float:
- width = element_data["width"][text_index]
- return float(width)
-
- def denorm_text_width(self, val: float, img_width: int, **kwargs: Any) -> float:
- val = val * img_width
- return val
-
- def get_text_top(self, element_data: dict, text_index: int, **kwargs: Any) -> float:
- top = element_data["top"][text_index]
- return float(top)
-
- def denorm_text_top(self, val: float, img_height: int, **kwargs: Any) -> float:
- val = val * img_height
- return val
-
- def get_text_left(
- self, element_data: dict, text_index: int, **kwargs: Any
- ) -> float:
- left = element_data["left"][text_index]
- return float(left)
-
- def denorm_text_left(self, val: float, img_width: int, **kwargs: Any) -> float:
- val = val * img_width
- return val
-
- def get_text_actual_width(
- self,
- element_data: dict,
- text_index: int,
- texts: List[str],
- font_skia: skia.Font,
- scaleinfo: Tuple[float, float],
- **kwargs: Any,
- ) -> float:
- _, scale_w = scaleinfo
- text_width = get_text_actual_width(
- element_data, text_index, texts, font_skia, scale_w
- )
- return text_width
-
- def get_text_center_y(
- self, element_data: dict, text_index: int, **kwargs: Any
- ) -> float:
- if self.use_extended_dataset:
- return element_data["text_center_y"][text_index]
- else:
- text_height = element_data["height"][text_index]
- top = element_data["top"][text_index]
- center_y = (top + top + text_height) / 2.0
- return float(center_y)
-
- def get_text_center_x(
- self,
- element_data: dict,
- text_index: int,
- scaleinfo: Tuple,
- **kwargs: Any,
- ) -> float:
- if self.use_extended_dataset:
- return element_data["text_center_x"][text_index]
- else:
- left = element_data["left"][text_index]
- w = element_data["width"][text_index]
- textAlign = element_data["text_align"][text_index]
- texts = get_texts(element_data, text_index)
- font_skia = self.get_skia_font(element_data, text_index, scaleinfo)
- text_actual_width = self.get_text_actual_width(
- element_data, text_index, texts, font_skia, scaleinfo
- )
- right = left + w
-            if textAlign == 1:
-                center_x = (left + right) / 2.0
-            elif textAlign == 3:
-                center_x = right - text_actual_width / 2.0
-            elif textAlign == 2:
-                center_x = left + text_actual_width / 2.0
-            else:
-                # Fallback assumption: treat unrecognized align labels as
-                # centered rather than raising UnboundLocalError.
-                center_x = (left + right) / 2.0
- return float(center_x)
-
- def get_text_align_type(
- self, element_data: dict, text_index: int, **kwargs: Any
- ) -> int:
- align_type = element_data["text_align"][text_index] - 1
- return int(align_type)
-
- def denorm_text_align_type(self, val: int, **kwargs: Any) -> int:
- val += 1
- return val
-
- def get_text_capitalize(
- self, element_data: dict, text_index: int, **kwargs: Any
- ) -> int:
- capitalize = element_data["capitalize"][text_index]
- return int(capitalize)
-
- def get_text_angle(
- self, element_data: dict, text_index: int, **kwargs: Any
- ) -> float:
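-        # Convert radians to a fraction of a full turn and keep only the
-        # fractional part, normalizing any rotation into (-1, 1);
-        # denorm_text_angle inverts this mapping.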
- angle = element_data["angle"][text_index]
- angle = (float(angle) * 180 / math.pi) / 360.0
- angle = math.modf(angle)[0]
- return float(angle)
-
- def denorm_text_angle(self, val: float, **kwargs: Any) -> float:
- val = val * 360 / 180.0 * math.pi
- return val
-
- def get_text_letter_spacing(
- self,
- element_data: dict,
- text_index: int,
- scaleinfo: Tuple,
- img_size: Tuple,
- **kwargs: Any,
- ) -> float:
- _, scale_w = scaleinfo
- _, W = img_size
- letter_space = element_data["letter_spacing"][text_index] * scale_w / W
- return float(letter_space)
-
- def denorm_text_letter_spacing(
- self, val: float, img_width: int, scale_w: float, **kwargs: Any
- ) -> float:
- val = val * img_width / scale_w
- return val
-
- def get_text_line_height_scale(
- self, element_data: dict, text_index: int, **kwargs: Any
- ) -> float:
- line_height_scale = element_data["line_height"][text_index]
- return float(line_height_scale)
-
- def get_text_char_count(
- self, element_data: dict, text_index: int, **kwargs: Any
- ) -> int:
- text = self.get_text(element_data, text_index)
- texts = text.split(os.linesep)
- max_char_count = 0
- for t in texts:
- max_char_count = max(max_char_count, len(t))
- return min(max_char_count, 50 - 1)
-
- def get_text_line_count(
- self, element_data: dict, text_index: int, **kwargs: Any
- ) -> int:
- texts = element_data["text"][text_index].split(os.linesep)
- line_count = 0
- for t in texts:
- if t == "":
- pass
- else:
- line_count += 1
- return min(line_count, 49)
-
- def get_text(self, element_data: dict, text_index: int) -> str:
- text = element_data["text"][text_index]
- return str(text)
-
- def get_canvas_aspect_ratio(
- self, element_data: dict, bg_img: Any, **kwargs: Any
- ) -> float:
- h, w = bg_img.size[1], bg_img.size[0]
- ratio = float(h) / float(w)
- return ratio
-
- def get_canvas_group(self, element_data: dict, **kwargs: Any) -> int:
- return element_data["group"]
-
- def get_canvas_format(self, element_data: dict, **kwargs: Any) -> int:
- return element_data["format"]
-
- def get_canvas_category(self, element_data: dict, **kwargs: Any) -> int:
- return element_data["category"]
-
- def get_canvas_width(self, element_data: dict, **kwargs: Any) -> int:
- return element_data["canvas_width"]
-
- def get_canvas_height(self, element_data: dict, **kwargs: Any) -> int:
- return element_data["canvas_height"]
-
- def get_canvas_bg_img_emb(
- self, element_data: dict, bg_img: PIL.Image, **kwargs: Any
- ) -> np.array:
- if self.use_extended_dataset:
- return np.array(element_data["canvas_bg_img_emb"])
- else:
- inputs = self.processor(images=[bg_img], return_tensors="pt")
- inputs["pixel_values"] = inputs["pixel_values"].to(self.device)
- image_feature = self.model.get_image_features(**inputs)
- return image_feature.data.numpy()
-
- def get_text_emb(
- self,
- element_data: dict,
- text_index: int,
- **kwargs: Any,
- ) -> np.array:
- if self.use_extended_dataset:
- return np.array(element_data["text_emb"][text_index])
- else:
- text = element_data["text"][text_index]
- inputs = self.text_tokenizer([text], padding=True, return_tensors="pt")
- if inputs["input_ids"].shape[1] > 77:
- inp = inputs["input_ids"][:, :77]
- else:
- inp = inputs["input_ids"]
- text_features = self.model.get_text_features(inp).data.numpy()[0]
- return text_features
-
- def get_text_local_img(
- self,
- img: Any,
- text_center_y: float,
- text_center_x: float,
- H: int,
- W: int,
- ) -> np.array:
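-        # Crop a 320x320 window centered on the text from the background
-        # resized to 640x640; the zero padding keeps crops near the image
-        # border at full size.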
- text_center_y = text_center_y * H
- text_center_x = text_center_x * W
-
- text_center_y = min(max(text_center_y, 0), H)
- text_center_x = min(max(text_center_x, 0), W)
- img = img.resize((640, 640))
- img = np.array(img)
- local_img_size = 64 * 5
- local_img_size_half = local_img_size // 2
- img_pad = np.zeros((640 + local_img_size, 640 + local_img_size, 3))
- img_pad[
- local_img_size_half : 640 + local_img_size_half,
- local_img_size_half : 640 + local_img_size_half,
- ] = img
- h_rate = 640 / float(H)
- w_rate = 640 / float(W)
- text_center_y = int(np.round(text_center_y * h_rate + local_img_size_half))
- text_center_x = int(np.round(text_center_x * w_rate + local_img_size_half))
- local_img = img_pad[
- text_center_y - local_img_size_half : text_center_y + local_img_size_half,
- text_center_x - local_img_size_half : text_center_x + local_img_size_half,
- ]
- return local_img
-
- def get_text_local_img_emb(
- self,
- element_data: dict,
- text_index: int,
- scaleinfo: Tuple,
- img_size: Tuple,
- bg_img: Any,
- **kwargs: Any,
- ) -> np.array:
- if self.use_extended_dataset:
- return np.array(element_data["text_local_img_emb"][text_index])
- else:
- H, W = img_size
- text_center_x = self.get_text_center_x(element_data, text_index, scaleinfo)
- text_center_y = self.get_text_center_y(element_data, text_index)
- local_img = self.get_text_local_img(
- bg_img.copy(), text_center_y, text_center_x, H, W
- )
- local_img = Image.fromarray(local_img.astype(np.uint8)).resize((224, 224))
- inputs = self.processor(images=[local_img], return_tensors="pt")
- inputs["pixel_values"] = inputs["pixel_values"].to(self.device)
- image_feature = self.model.get_image_features(**inputs)
- return image_feature.data.numpy()
-
- def load_samples(self, index: int) -> Tuple[Dict, Any, str, int]:
- element_data = self.dataset[index]
- svg_id = element_data["id"]
- fn = os.path.join(self.data_dir, "generate_bg_png", f"{svg_id}.png")
- bg = Image.open(fn).convert("RGB") # background image
- return element_data, bg, svg_id, index
-
- def __len__(self) -> int:
- return len(self.dataset)
diff --git a/src/typography_generation/io/data_loader.py b/src/typography_generation/io/data_loader.py
deleted file mode 100644
index 89b1771..0000000
--- a/src/typography_generation/io/data_loader.py
+++ /dev/null
@@ -1,123 +0,0 @@
-import time
-from typing import Any, Dict, List, Tuple
-import numpy as np
-
-import torch
-from logzero import logger
-from typography_generation.io.crello_util import CrelloProcessor
-from typography_generation.io.data_object import (
- DesignContext,
- FontConfig,
- PrefixListObject,
-)
-from typography_generation.io.data_utils import get_canvas_context, get_element_context
-from typography_generation.tools.tokenizer import Tokenizer
-
-
-class CrelloLoader(torch.utils.data.Dataset):
- def __init__(
- self,
- data_dir: str,
- tokenizer: Tokenizer,
- dataset: Any,
- prefix_list_object: PrefixListObject,
- font_config: FontConfig,
- use_extended_dataset: bool = True,
- seq_length: int = 50,
- debug: bool = False,
- ) -> None:
- super().__init__()
- self.data_dir = data_dir
- self.prefix_list_object = prefix_list_object
- self.debug = debug
- logger.debug("create crello dataset processor")
- self.dataset = CrelloProcessor(
- data_dir,
- tokenizer,
- dataset,
- font_config,
- use_extended_dataset=use_extended_dataset,
- )
- logger.debug("create crello dataset processor done")
- self.seq_length = seq_length
-
-    def get_ordered_text_ids(self, element_data: Dict, order_list: List) -> List:
-        text_ids = []
-        for i in order_list:
-            if element_data["text"][i] != "":
-                text_ids.append(i)
-        return text_ids
-
- def get_order_list(self, elm: Dict[str, Any]) -> List[int]:
- if self.dataset.use_extended_dataset:
- return elm["order_list"]
- else:
- """
- Sort elments based on the raster scan order.
- """
- center_y = []
- center_x = []
- scaleinfo = self.dataset.get_scale_box(elm)
- for text_id in range(len(elm["text"])):
- center_y.append(self.dataset.get_text_center_y(elm, text_id))
- center_x.append(self.dataset.get_text_center_x(elm, text_id, scaleinfo))
- center_y = np.array(center_y)
- center_x = np.array(center_x)
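-            # Scaling y by 1000 before adding x makes argsort behave like a
-            # lexicographic (y, x) key for normalized coordinates.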
- sortedid = np.argsort(center_y * 1000 + center_x)
- return list(sortedid)
-
- def load_data(self, index: int) -> Tuple:
- logger.debug("load samples")
- element_data, bg_img, svg_id, index = self.dataset.load_samples(index)
-
- # extract text element indexes
- text_num = self.dataset.get_canvas_text_num(element_data)
-
- logger.debug("order elements")
- order_list = self.get_order_list(element_data)
- text_ids = self.get_ordered_text_ids(element_data, order_list)
-        element_prefix_list = (
- self.prefix_list_object.textelement + self.prefix_list_object.target
- )
- logger.debug("get_element_context")
- text_context = get_element_context(
- element_data,
- bg_img,
- self.dataset,
-            element_prefix_list,
- text_num,
- text_ids,
- )
-
- logger.debug("get_canvas_context")
- canvas_context = get_canvas_context(
- element_data,
- self.dataset,
- self.prefix_list_object.canvas,
- bg_img,
- text_num,
- )
- logger.debug("build design context object")
- design_context = DesignContext(text_context, canvas_context)
- return design_context, svg_id, element_data
-
-    def __getitem__(self, index: int) -> Tuple[DesignContext, List, str, int]:
- logger.debug("load data")
- start = time.time()
- design_context, svg_id, element_data = self.load_data(index)
- logger.debug(f"load data {time.time() -start}")
- logger.debug("get model input list")
- model_input_list = design_context.get_model_inputs_from_prefix_list(
- self.prefix_list_object.model_input
- )
- logger.debug(f"get model input list {time.time() -start}")
- logger.debug("get model input list done")
- return design_context, model_input_list, svg_id, index
-
- def __len__(self) -> int:
- if self.debug is True:
- return 2
- else:
- return len(self.dataset)
diff --git a/src/typography_generation/io/data_object.py b/src/typography_generation/io/data_object.py
deleted file mode 100644
index f0fb0eb..0000000
--- a/src/typography_generation/io/data_object.py
+++ /dev/null
@@ -1,266 +0,0 @@
-from dataclasses import dataclass
-from typing import Any, Dict, List, Tuple, Union
-
-import numpy as np
-import torch
-from logzero import logger
-from torch import Tensor
-
-from typography_generation.config.attribute_config import (
- TextElementContextEmbeddingAttributeConfig,
-)
-
-
-@dataclass
-class PrefixListObject:
- textelement: List
- canvas: List
- model_input: List
- target: List
-
-
-@dataclass
-class BinsData:
- color_bin: int
- small_spatio_bin: int
- large_spatio_bin: int
-
-
-@dataclass
-class FontConfig:
- font_num: int
- font_emb_type: str = "label"
- font_emb_weight: float = 1.0
- font_emb_dim: int = 40
- font_emb_name: str = "mfc"
-
-
-@dataclass
-class DataPreprocessConfig:
- order_type: str
- seq_length: int
-
-
-@dataclass
-class SamplingConfig:
- sampling_param: str
- sampling_param_geometry: float
- sampling_param_semantic: float
- sampling_num: int
-
-
-class ModelInput:
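-    # Wraps a batch of model inputs, exposing each prefix as an attribute and
-    # providing helpers to zero out or restore attributes during
-    # autoregressive inference.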
- def __init__(self, design_context_list: List, model_input: List, gpu: bool) -> None:
- self.design_context_list = design_context_list
- self.model_input = model_input
- self.prefix_list = self.design_context_list[0].model_input_prefix_list
- self.gpu = gpu
- self.reset()
-
- def reset(self) -> None:
- self.setgt()
- if self.gpu is True:
- self.cuda()
-
- def setgt(self) -> None:
- if len(self.prefix_list) != len(self.model_input):
- raise ValueError("The length between list and input is different.")
- self.batch_num = self.model_input[0].shape[0]
- for prefix, elm in zip(self.prefix_list, self.model_input):
- setattr(self, f"{prefix}", elm.clone())
- if prefix == "canvas_text_num":
- self.canvas_text_num = elm.clone()
-
- def cuda(self) -> None:
- for prefix in self.prefix_list:
- tar = getattr(self, f"{prefix}")
- if type(tar) == Tensor:
- setattr(self, f"{prefix}", tar.cuda())
-
- def target_register(self, prefix: str, elm: Any) -> None:
- setattr(self, f"{prefix}", elm)
-
- def zeroinitialize_style_attributes(self, prefix_list: List) -> None:
- for prefix in prefix_list:
- if prefix == "canvas_text_num":
- continue
- tar = getattr(self, f"{prefix}")
- tar_rep = torch.zeros_like(tar)
- setattr(self, f"{prefix}", tar_rep)
-
- def zeroinitialize_specific_attribute(self, prefix: str) -> None:
- tar = getattr(self, f"{prefix}")
- if prefix != "canvas_text_num":
- setattr(self, f"{prefix}", torch.zeros_like(tar))
-
- def setgt_specific_attribute(self, prefix_tar: str) -> None:
- for prefix, elm in zip(self.prefix_list, self.model_input):
- if prefix == prefix_tar:
- tar = elm.clone()
- if self.gpu is True and type(tar) == Tensor:
- tar = tar.cuda()
- setattr(self, f"{prefix}", tar)
-
- def update_th_style_attributes(
- self,
- embedding_config_element: TextElementContextEmbeddingAttributeConfig,
- prefix_list: List,
- model_output: Tensor,
- text_index: int,
- batch_index: int = 0,
- ) -> None:
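-        # Write the labels predicted for one text element back into the
-        # inputs so that the next decoding step can condition on them.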
- for prefix in prefix_list:
- target_embedding_config = getattr(embedding_config_element, prefix)
- tar = getattr(self, f"{prefix}")
- out = model_output[f"{prefix}"]
- tar[batch_index, text_index] = out
- setattr(self, f"{target_embedding_config.input_prefix}", tar)
-
- def zeroinitialize_th_style_attributes(
- self,
- prefix_list: List,
- text_index: int,
- batch_index: int = 0,
- ) -> None:
- for prefix in prefix_list:
- tar = getattr(self, f"{prefix}")
- tar[batch_index, text_index] = 0
- setattr(self, f"{prefix}", tar)
-
- def additional_input_from_design_context_list(
- self, design_context_list: List
- ) -> None:
- self.texts = []
- for design_context in design_context_list:
- self.texts.append(design_context.text_context.texts)
-
-
-@dataclass
-class ElementContext:
- prefix_list: List
- rawdata2token: Dict
- img_size: Tuple
- scaleinfo: Tuple
- seq_length: int
-
- def __post_init__(self) -> None:
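-        # Allocate fixed-length model-input buffers per attribute; sentinel
-        # values (-1 for scalar tokens, -10000 for font embeddings) mark the
-        # padded slots beyond the real text count.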
- for prefix in self.prefix_list:
- setattr(self, prefix, [])
- if prefix == "text_local_img":
- self.text_local_img_model_input = np.zeros(
- (self.seq_length, 3, 224, 224), dtype=np.float32
- )
- elif prefix == "text_local_img_emb":
- self.text_local_img_emb_model_input = np.zeros(
- (self.seq_length, 512), dtype=np.float32
- )
- elif prefix == "text_emb":
- self.text_emb_model_input = np.zeros(
- (self.seq_length, 512), dtype=np.float32
- )
- elif prefix == "text_font_emb":
- self.text_font_emb_model_input = (
- np.zeros((self.seq_length, 40), dtype=np.float32) - 10000
- )
- else:
- setattr(
- self,
- f"{prefix}_model_input",
- np.zeros((self.seq_length), dtype=np.float32) - 1,
- )
-
- if prefix in self.rawdata2token.keys():
- setattr(
- self,
- f"{self.rawdata2token[prefix]}_model_input",
- np.zeros((self.seq_length), dtype=np.float32) - 1,
- )
-
-
-@dataclass
-class CanvasContext:
- canvas_bg_img: np.array
- canvas_text_num: int
- img_size: Tuple
- scale_box: Tuple
- prefix_list: List
-
-
-@dataclass
-class DesignContext:
- element_context: ElementContext
- canvas_context: CanvasContext
-
- def __post_init__(self) -> None:
- self.prepare_keys()
-
- def prepare_keys(self) -> None:
- self.canvas_context_keys = dir(self.canvas_context)
- self.element_context_keys = dir(self.element_context)
-
- def get_text_num(self) -> int:
- return self.canvas_context.canvas_text_num
-
- def convert_target_to_torch_format(
- self,
- tar: Any,
- ) -> Any:
- if type(tar) == np.ndarray:
- tar = torch.from_numpy(tar)
- elif type(tar) == float or type(tar) == int:
- tar = torch.Tensor([tar])
- elif type(tar) == str or type(tar) == list:
- pass
- else:
- logger.info(tar)
- raise NotImplementedError()
- return tar
-
- def search_class(self, prefix: str) -> Union[ElementContext, CanvasContext]:
- if prefix in self.canvas_context_keys:
- return self.canvas_context
- elif prefix in self.element_context_keys:
- return self.element_context
- else:
- logger.info(
- f"{prefix}, {self.canvas_context_keys}, {self.element_context_keys}"
- )
- raise NotImplementedError()
-
- def get_model_inputs_from_prefix_list(self, prefix_list: List) -> List:
- self.model_input_prefix_list = prefix_list
- model_inputs = []
- for prefix in prefix_list:
- logger.debug(f"convert_target_to_torch_format {prefix}")
- tar_cls = self.search_class(f"{prefix}")
- tar = getattr(tar_cls, f"{prefix}_model_input")
- tar = self.convert_target_to_torch_format(tar)
- model_inputs.append(tar)
- return model_inputs
-
- def get_data(self, prefix: str) -> Any:
- tar_cls = self.search_class(f"{prefix}")
- tar = getattr(tar_cls, f"{prefix}")
- return tar
-
- def get_canvas_size(self) -> Tuple:
- canvas_size = (
- self.canvas_context.canvas_img_size_h,
- self.canvas_context.canvas_img_size_w,
- )
- return canvas_size
-
- def get_text_context(self) -> ElementContext:
- return self.element_context
-
- def get_bg(self) -> np.array:
- return self.canvas_context.canvas_bg_img
-
- def get_scaleinfo(self) -> Tuple:
- return (self.canvas_context.canvas_h_scale, self.canvas_context.canvas_w_scale)
-
- def convert_torch_format(self, prefix: str) -> None:
- tar_cls = self.search_class(prefix)
- tar = getattr(tar_cls, prefix)
- tar = self.convert_target_to_torch_format(tar)
- setattr(tar_cls, prefix, tar)
diff --git a/src/typography_generation/io/data_utils.py b/src/typography_generation/io/data_utils.py
deleted file mode 100644
index 3d76d71..0000000
--- a/src/typography_generation/io/data_utils.py
+++ /dev/null
@@ -1,89 +0,0 @@
-import time
-from typing import List, Tuple
-import PIL
-import numpy as np
-from typography_generation.io.crello_util import CrelloProcessor
-from typography_generation.io.data_object import CanvasContext, ElementContext
-from logzero import logger
-
-
-def get_canvas_context(
- element_data: dict,
- dataset: CrelloProcessor,
- canvas_prefix_list: List,
- bg_img: np.array,
- text_num: int,
-) -> CanvasContext:
- img_size = (bg_img.size[1], bg_img.size[0]) # (h,w)
- scale_box = dataset.get_scale_box(element_data)
- canvas_context = CanvasContext(
- bg_img, text_num, img_size, scale_box, canvas_prefix_list
- )
- for prefix in canvas_prefix_list:
- data_info = {"element_data": element_data, "bg_img": bg_img}
- data = getattr(dataset, f"get_{prefix}")(**data_info)
- setattr(canvas_context, f"{prefix}", data)
- setattr(canvas_context, f"{prefix}_model_input", data)
-
- return canvas_context
-
-
-def get_element_context(
- element_data: dict,
- bg_img: PIL.Image,
- dataset: CrelloProcessor,
- element_prefix_list: List,
- text_num: int,
- text_ids: List,
-) -> ElementContext:
- scaleinfo = dataset.get_scale_box(element_data)
- img_size = (bg_img.size[1], bg_img.size[0]) # (h,w)
- element_context = ElementContext(
- element_prefix_list,
- dataset.tokenizer.rawdata2token,
- img_size,
- scaleinfo,
- dataset.seq_length,
- )
- for i in range(text_num):
- text_index = text_ids[i]
- for prefix in element_prefix_list:
- start = time.time()
-
- data_info = {
- "element_data": element_data,
- "text_index": text_index,
- "img_size": img_size,
- "scaleinfo": scaleinfo,
- "bg_img": bg_img,
- }
- if prefix == "text_emb":
- data = getattr(dataset, f"get_{prefix}")(**data_info)
- element_context.text_emb_model_input[i] = data
- elif prefix == "text_local_img_emb":
- data = getattr(dataset, f"get_{prefix}")(**data_info)
- element_context.text_local_img_emb_model_input[i] = data
- elif prefix == "text_font_emb":
- data = dataset.get_text_font_emb(element_data, text_index)
- getattr(element_context, prefix).append(data)
- element_context.text_font_emb_model_input[i] = data
- else:
- data = getattr(dataset, f"get_{prefix}")(**data_info)
- getattr(element_context, prefix).append(data)
- if prefix in dataset.tokenizer.prefix_list:
- model_input = dataset.tokenizer.tokenize(prefix, data)
- else:
- model_input = data
- getattr(element_context, f"{prefix}_model_input")[i] = model_input
-
- if prefix in dataset.tokenizer.rawdata_list:
- token = getattr(dataset, f"raw2token_{prefix}")(data)
- getattr(
- element_context,
- f"{dataset.tokenizer.rawdata2token[prefix]}_model_input",
- )[i] = token
-
- end = time.time()
- logger.debug(f"{prefix} {end - start}")
-
- return element_context
diff --git a/src/typography_generation/model/__init__.py b/src/typography_generation/model/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/src/typography_generation/model/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/src/typography_generation/model/bart.py b/src/typography_generation/model/bart.py
deleted file mode 100644
index 77a8fe7..0000000
--- a/src/typography_generation/model/bart.py
+++ /dev/null
@@ -1,522 +0,0 @@
-import random
-import time
-from typing import Any, Dict, List, Tuple, Union
-
-import numpy as np
-import torch
-from logzero import logger
-from torch import Tensor, nn
-from torch.functional import F
-from typography_generation.config.attribute_config import (
- CanvasContextEmbeddingAttributeConfig,
- TextElementContextEmbeddingAttributeConfig,
- TextElementContextPredictionAttributeConfig,
-)
-from typography_generation.io.crello_util import CrelloProcessor
-from typography_generation.io.data_object import ModelInput
-from typography_generation.model.decoder import Decoder, MultiTask
-from typography_generation.model.embedding import Embedding
-from typography_generation.model.encoder import Encoder
-
-FILTER_VALUE = -float("Inf")
-
-
-def top_p(
- prior: Tensor,
- text_index: int,
- sampling_param: float,
-) -> int:
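-    # Nucleus (top-p) sampling variant: find the smallest set of top-ranked
-    # labels whose cumulative probability exceeds sampling_param, then draw
-    # uniformly among them.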
- prior = F.softmax(prior, 1)
- sorted_prob, sorted_label = torch.sort(input=prior, dim=1, descending=True)
- prior = sorted_prob[text_index]
- sum_p = 0
- for k in range(len(prior)):
- sum_p += prior[k].item()
-        if sum_p > sampling_param:  # nucleus probability mass reached
- break
-
- range_class = k
- if range_class == 0:
- index = 0
- else:
- index = random.randint(0, range_class)
- out_label = sorted_label[text_index][index].item()
- return out_label
-
-
-def top_p_weight(
- logits: Tensor,
- text_index: int,
- sampling_param: float,
-) -> int:
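-    # Probability-weighted top-p sampling: logits outside the nucleus are
-    # masked to -inf (the top-1 logit is always kept), the remainder is
-    # renormalized, and a label is drawn with torch.multinomial.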
- sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=1)
- cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=1), dim=1)
- S = logits.size(1)
- indices = torch.arange(S).view(1, S).to(logits.device)
- # make sure to keep the first logit (most likely one)
- sorted_logits[(cumulative_probs > sampling_param) & (indices > 0)] = FILTER_VALUE
- logits = sorted_logits.gather(dim=1, index=sorted_indices.argsort(dim=1))
- probs = F.softmax(logits, dim=1)
- output = torch.multinomial(probs, num_samples=1) # (B, 1)
- logger.debug(f"{output.shape}=")
- return output[text_index][0].item()
-
-
-sampler_dict = {"top_p": top_p, "top_p_weight": top_p_weight}
-
-
-def sample_label(
- logits: Tensor, text_index: int, sampling_param: float, mode: str = "top_p"
-) -> int:
- return sampler_dict[mode](logits, text_index, sampling_param)
-
-
-def get_structure(pred: np.array, indexes: List) -> Dict:
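-    # Group element indexes that received the same predicted label; linked
-    # indexes are later forced to share a single sampled label.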
-    index2samestructureindexes: Dict[int, List]
- index2samestructureindexes = {}
- for i in indexes:
- index2samestructureindexes[i] = []
- label_i = pred[i]
- for j in indexes:
- label_j = pred[j]
- if label_i == label_j:
- index2samestructureindexes[i].append(j)
- return index2samestructureindexes
-
-
-def get_structure_dict(prefix_list: List, preds: Dict, indexes: List) -> Dict:
- index2samestructureindexes = {}
- for prefix in prefix_list:
- index2samestructureindexes[prefix] = get_structure(
- preds[prefix],
- indexes,
- )
- return index2samestructureindexes
-
-
-def get_init_label_link(indexes: List) -> Dict:
-    label_link: Dict[int, Union[int, None]]
- label_link = {}
- for i in indexes:
- label_link[i] = None
- return label_link
-
-
-def initialize_link(prefix_list: List, indexes: List) -> Tuple[Dict, Dict]:
- label_link: Dict[str, Dict]
- label_link = {}
- used_labels: Dict[str, List]
- used_labels = {}
- for prefix in prefix_list:
- label_link[prefix] = get_init_label_link(indexes)
- used_labels[prefix] = []
- return label_link, used_labels
-
-
-def label_linkage(
- prefix_list: List,
- text_num: int,
- preds: Dict,
-) -> Tuple[Dict, Dict, Dict]:
- indexes = list(range(text_num))
- index2samestructureindexes = get_structure_dict(prefix_list, preds, indexes)
- label_link, used_labels = initialize_link(prefix_list, indexes)
- return index2samestructureindexes, label_link, used_labels
-
-
-def update_label_link(label_link: Dict, samestructureindexes: List, label: int) -> Dict:
- for i in samestructureindexes:
- label_link[i] = label
- return label_link
-
-
-class BART(nn.Module):
- def __init__(
- self,
- prefix_list_element: List,
- prefix_list_canvas: List,
- prefix_list_target: List,
- embedding_config_element: TextElementContextEmbeddingAttributeConfig,
- embedding_config_canvas: CanvasContextEmbeddingAttributeConfig,
- prediction_config_element: TextElementContextPredictionAttributeConfig,
- d_model: int = 256,
- n_head: int = 8,
- dropout: float = 0.1,
- num_encoder_layers: int = 4,
- num_decoder_layers: int = 4,
- seq_length: int = 50,
- bypass: bool = True,
- **kwargs: Any,
- ) -> None:
- super().__init__()
- logger.info(f"BART settings")
- logger.info(f"d_model: {d_model}")
- logger.info(f"n_head: {n_head}")
- logger.info(f"num_encoder_layers: {num_encoder_layers}")
- logger.info(f"num_decoder_layers: {num_decoder_layers}")
- logger.info(f"seq_length: {seq_length}")
- self.embedding_config_element = embedding_config_element
-
- self.emb = Embedding(
- prefix_list_element,
- prefix_list_canvas,
- embedding_config_element,
- embedding_config_canvas,
- d_model=d_model,
- dropout=dropout,
- seq_length=seq_length,
- )
- self.enc = Encoder(
- d_model=d_model,
- n_head=n_head,
- dropout=dropout,
- num_encoder_layers=num_encoder_layers,
- )
- self.dec = Decoder(
- prefix_list_target,
- embedding_config_element,
- d_model=d_model,
- n_head=n_head,
- dropout=dropout,
- num_decoder_layers=num_decoder_layers,
- seq_length=seq_length,
- )
- self.head = MultiTask(
- prefix_list_target,
- prediction_config_element,
- d_model,
- bypass=bypass,
- )
- self.initialize_weights()
-
- for prefix in prefix_list_target:
- target_prediction_config = getattr(prediction_config_element, prefix)
- setattr(self, f"{prefix}_att_type", target_prediction_config.att_type)
-
-    def forward(self, model_inputs: ModelInput) -> Dict:
- start = time.time()
- (
- src,
- text_mask_src,
- feat_cat,
- ) = self.emb(model_inputs)
- logger.debug(f"{time.time()-start} sec emb")
- start = time.time()
- z = self.enc(src, text_mask_src)
- logger.debug(f"{time.time()-start} sec enc")
- start = time.time()
- zd = self.dec(feat_cat, z, model_inputs)
- logger.debug(f"{time.time()-start} sec dec")
- start = time.time()
- outs = self.head(zd, feat_cat)
- logger.debug(f"{time.time()-start} sec head")
- return outs
-
- def tokenize_model_out(
- self,
- dataset: CrelloProcessor,
- prefix: str,
- model_out: Tensor,
-        batch_index: int,
-        text_index: int,
- ) -> int:
- if prefix in dataset.tokenizer.rawdata_list:
- data = model_out[batch_index][text_index].data.cpu().numpy()
- out_label = getattr(dataset, f"raw2token_{prefix}")(data)
- else:
- sorted_label = torch.sort(
- input=model_out[batch_index], dim=1, descending=True
- )[1]
- target_label = 0 # top1
- out_label = sorted_label[text_index][target_label].item()
- return out_label
-
- def get_labels(
- self,
- model_outs: Dict,
- dataset: CrelloProcessor,
- target_prefix_list: List,
- text_index: int,
- batch_index: int = 0,
- ) -> Dict:
- out_labels = {}
- for prefix in target_prefix_list:
- out = model_outs[f"{prefix}"]
- out_label = self.tokenize_model_out(
- dataset, prefix, out, batch_index, text_index
- )
-
- out_labels[f"{prefix}"] = out_label
-
- return out_labels
-
- def get_outs(
- self,
- model_outs: Dict,
- dataset: CrelloProcessor,
- target_prefix_list: List,
- text_index: int,
- batch_index: int = 0,
- ) -> Dict:
- outs = {}
- for prefix in target_prefix_list:
- out = model_outs[f"{prefix}"]
- if prefix in dataset.tokenizer.rawdata_list:
- data = out[batch_index][text_index].data.cpu().numpy()
- else:
- sorted_label = torch.sort(
- input=out[batch_index], dim=1, descending=True
- )[1]
- target_label = 0 # top1
- data = sorted_label[text_index][target_label].item()
-
- outs[f"{prefix}"] = data
-
- return outs
-
- def store(
- self,
- outs_all: Dict,
- outs: Dict,
- target_prefix_list: List,
- text_index: int,
- ) -> Dict:
- for prefix in target_prefix_list:
- outs_all[f"{prefix}"][text_index] = outs[f"{prefix}"]
- return outs_all
-
- def prediction(
- self,
- model_inputs: ModelInput,
- dataset: CrelloProcessor,
- target_prefix_list: List,
- start_index: int = 0,
-    ) -> Dict:
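-        # Greedy autoregressive inference: zero the target attributes from
-        # start_index on, encode the canvas once, then decode one text
-        # element at a time, feeding each prediction back into the inputs.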
- target_text_num = int(model_inputs.canvas_text_num[0].item())
- start_index = min(start_index, target_text_num)
- for t in range(start_index, target_text_num):
- model_inputs.zeroinitialize_th_style_attributes(target_prefix_list, t)
- src, text_mask_src, feat_cat = self.emb(model_inputs)
- z = self.enc(src, text_mask_src)
-
- outs_all = {}
- for prefix in target_prefix_list:
- outs_all[prefix] = {}
- for t in range(0, start_index):
- tar = getattr(model_inputs, f"{prefix}")[0, t].item()
- outs_all[prefix][t] = tar
- for t in range(start_index, target_text_num):
- zd = self.dec(feat_cat, z, model_inputs)
- model_outs = self.head(zd, feat_cat)
- out_labels = self.get_labels(model_outs, dataset, target_prefix_list, t)
- outs = self.get_outs(model_outs, dataset, target_prefix_list, t)
- model_inputs.update_th_style_attributes(
- self.embedding_config_element, target_prefix_list, out_labels, t
- )
- outs_all = self.store(outs_all, outs, target_prefix_list, t)
- return outs_all
-
- def get_transformer_weight(
- self,
- model_inputs: ModelInput,
- dataset: CrelloProcessor,
- target_prefix_list: List,
- start_index: int = 0,
- ) -> Tensor:
- target_text_num = int(model_inputs.canvas_text_num[0].item())
- start_index = min(start_index, target_text_num)
- for t in range(start_index, target_text_num):
- model_inputs.zeroinitialize_th_style_attributes(target_prefix_list, t)
- src, text_mask_src, feat_cat = self.emb(model_inputs)
- z = self.enc(src, text_mask_src)
- out_labels_all = {}
- for prefix in target_prefix_list:
- out_labels_all[prefix] = np.zeros((target_text_num, 1))
- for t in range(0, start_index):
- tar = getattr(model_inputs, f"{prefix}")[0, t].item()
- out_labels_all[prefix][t, 0] = tar
- weights = []
- for t in range(start_index, target_text_num):
- zd, _weights = self.dec.get_transformer_weight(feat_cat, z, model_inputs)
- weights.append(_weights[0, t])
- model_outs = self.head(zd, feat_cat)
- out_labels = self.get_labels(model_outs, dataset, target_prefix_list, t)
- model_inputs.update_th_style_attributes(
- self.embedding_config_element, target_prefix_list, out_labels, t
- )
- out_labels_all = self.store(
- out_labels_all, out_labels, target_prefix_list, t
- )
- if len(weights) > 0:
- weights = torch.stack(weights, dim=0)
- return weights
- else:
- dummy_weights = torch.zeros((1, 1)).to(src.device)
- return dummy_weights
-
- def sample_labels(
- self,
- model_outs: Dict,
- target_prefix_list: List,
- text_index: int,
- sampling_param_geometry: float = 0.5,
- sampling_param_semantic: float = 0.9,
- batch_index: int = 0,
- ) -> Dict:
- out_labels = {}
- for prefix in target_prefix_list:
- out = model_outs[f"{prefix}"][batch_index]
- if getattr(self, f"{prefix}_att_type") == "semantic":
- sampling_param = sampling_param_semantic
- elif getattr(self, f"{prefix}_att_type") == "geometry":
- sampling_param = sampling_param_geometry
- out_label = sample_label(out, text_index, sampling_param)
- out_labels[f"{prefix}"] = out_label
-
- return out_labels
-
- def sample(
- self,
- model_inputs: ModelInput,
- target_prefix_list: List,
- sampling_param_geometry: float = 0.7,
- sampling_param_semantic: float = 0.7,
- start_index: int = 0,
- **kwargs: Any,
-    ) -> Dict:
- target_text_num = int(model_inputs.canvas_text_num[0].item())
- start_index = min(start_index, target_text_num)
- for t in range(start_index, target_text_num):
- model_inputs.zeroinitialize_th_style_attributes(target_prefix_list, t)
- src, text_mask_src, feat_cat = self.emb(model_inputs)
- z = self.enc(src, text_mask_src)
-
- outs_all = {}
- for prefix in target_prefix_list:
- outs_all[prefix] = {}
- for t in range(start_index, target_text_num):
- zd = self.dec(feat_cat, z, model_inputs)
- model_outs = self.head(zd, feat_cat)
- out_labels = self.sample_labels(
- model_outs,
- target_prefix_list,
- t,
- sampling_param_geometry,
- sampling_param_semantic,
- )
- model_inputs.update_th_style_attributes(
- self.embedding_config_element, target_prefix_list, out_labels, t
- )
- outs_all = self.store(outs_all, out_labels, target_prefix_list, t)
- return outs_all
-
- def sample_labels_with_structure(
- self,
- model_outs: Dict,
- target_prefix_list: List,
- text_index: int,
- label_link: Dict,
- used_labels: Dict,
- index2samestructureindexes: Dict,
- sampling_param_geometry: float = 0.5,
- sampling_param_semantic: float = 0.9,
- batch_index: int = 0,
- ) -> Dict:
- outs_all = {}
- for prefix in target_prefix_list:
- _label_link = label_link[prefix][text_index]
- _used_labels = used_labels[prefix]
- if _label_link is None:
- out = model_outs[f"{prefix}"][batch_index]
- cnt = 0
- sampling_type = getattr(self, f"{prefix}_att_type")
- if sampling_type == "semantic":
- sampling_param = sampling_param_semantic
- elif sampling_type == "geometry":
- sampling_param = sampling_param_geometry
-
- out_label = sample_label(out, text_index, sampling_param)
- max_val = max(
- torch.sum(F.softmax(out, 1)[text_index]).item(), sampling_param
- )
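-                # Resample until an unused label is drawn; after 10 tries the
-                # nucleus is widened towards the full probability mass, and
-                # after 1000 tries sampling_param is doubled to force escape.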
- while out_label in _used_labels:
- out_label = sample_label(out, text_index, sampling_param)
- cnt += 1
- if cnt > 10:
- sampling_param += abs((max_val - sampling_param) * 0.1)
- if cnt > 1000:
- sampling_param *= 2
-
- samestructureindexes = index2samestructureindexes[prefix][text_index]
- label_link[prefix] = update_label_link(
- label_link[prefix], samestructureindexes, out_label
- )
- used_labels[prefix].append(out_label)
- else:
- out_label = _label_link
- outs_all[f"{prefix}"] = out_label
-
- return outs_all
-
- def structure_preserved_sample(
- self,
- model_inputs: ModelInput,
- dataset: CrelloProcessor,
- target_prefix_list: List,
- sampling_param_geometry: float = 0.7,
- sampling_param_semantic: float = 0.7,
- start_index: int = 0,
-    ) -> Dict:
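-        # Structure-preserving sampling: greedy predictions first determine
-        # which elements share a label; each linked group then shares one
-        # sampled label, and labels are not reused across groups.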
- target_text_num = int(model_inputs.canvas_text_num[0].item())
- preds = self.prediction(model_inputs, dataset, target_prefix_list)
- index2samestructureindexes, label_link, used_labels = label_linkage(
- target_prefix_list, target_text_num, preds
- )
-
- for t in range(start_index, target_text_num):
- model_inputs.zeroinitialize_th_style_attributes(target_prefix_list, t)
- src, text_mask_src, feat_cat = self.emb(model_inputs)
- z = self.enc(src, text_mask_src)
-
- outs_all = {}
- for prefix in target_prefix_list:
- outs_all[prefix] = {}
- for t in range(start_index, target_text_num):
- zd = self.dec(feat_cat, z, model_inputs)
- model_outs = self.head(zd, feat_cat)
- out_labels = self.sample_labels_with_structure(
- model_outs,
- target_prefix_list,
- t,
- label_link,
- used_labels,
- index2samestructureindexes,
- sampling_param_geometry,
- sampling_param_semantic,
- )
- model_inputs.update_th_style_attributes(
- self.embedding_config_element, target_prefix_list, out_labels, t
- )
- outs_all = self.store(outs_all, out_labels, target_prefix_list, t)
- return outs_all
-
- def initialize_weights(self) -> None:
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.xavier_normal_(m.weight)
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.BatchNorm2d):
- m.weight.data.fill_(1)
- m.bias.data.zero_()
- elif isinstance(m, nn.Linear):
- m.weight.data.normal_(mean=0.0, std=0.02)
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.LayerNorm):
- if m.weight is not None:
- m.weight.data.fill_(1)
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.Embedding):
- m.weight.data.normal_(mean=0.0, std=0.02)
diff --git a/src/typography_generation/model/baseline.py b/src/typography_generation/model/baseline.py
deleted file mode 100644
index 2fd67a2..0000000
--- a/src/typography_generation/model/baseline.py
+++ /dev/null
@@ -1,343 +0,0 @@
-import pickle
-import random
-from typing import Any, Dict, List, Tuple, Union
-import numpy as np
-
-import torch
-from logzero import logger
-from torch import Tensor, nn
-from typography_generation.config.attribute_config import (
- CanvasContextEmbeddingAttributeConfig,
- TextElementContextEmbeddingAttributeConfig,
- TextElementContextPredictionAttributeConfig,
-)
-from typography_generation.io.data_object import ModelInput
-from typography_generation.model.bottleneck import ImlevelLF
-from typography_generation.model.decoder import Decoder, MultiTask
-from typography_generation.model.embedding import Embedding
-from typography_generation.model.encoder import Encoder
-
-
-class Baseline(nn.Module):
- def __init__(
- self,
- prefix_list_element: List,
- prefix_list_canvas: List,
- prefix_list_target: List,
- embedding_config_element: TextElementContextEmbeddingAttributeConfig,
- embedding_config_canvas: CanvasContextEmbeddingAttributeConfig,
- prediction_config_element: TextElementContextPredictionAttributeConfig,
- d_model: int = 256,
- n_head: int = 8,
- dropout: float = 0.1,
- num_encoder_layers: int = 4,
- num_decoder_layers: int = 4,
- seq_length: int = 50,
- std_ratio: float = 1.0,
- **kwargs: Any,
- ) -> None:
- super().__init__()
- logger.info(f"CanvasVAE settings")
- logger.info(f"d_model: {d_model}")
- logger.info(f"n_head: {n_head}")
- logger.info(f"num_encoder_layers: {num_encoder_layers}")
- logger.info(f"num_decoder_layers: {num_decoder_layers}")
- logger.info(f"seq_length: {seq_length}")
-
- self.emb = Embedding(
- prefix_list_element,
- prefix_list_canvas,
- embedding_config_element,
- embedding_config_canvas,
- d_model,
- dropout,
- seq_length,
- )
- self.enc = Encoder(
- d_model=d_model,
- n_head=n_head,
- dropout=dropout,
- num_encoder_layers=num_encoder_layers,
- )
- self.lf = ImlevelLF(vae=True, std_ratio=std_ratio)
-
- self.dec = Decoder(
- prefix_list_target,
- embedding_config_element,
- d_model=d_model,
- n_head=n_head,
- dropout=dropout,
- num_decoder_layers=num_decoder_layers,
- seq_length=seq_length,
- autoregressive_scheme=False,
- )
- self.head = MultiTask(
- prefix_list_target, prediction_config_element, d_model, bypass=False
- )
- self.initialize_weights()
-
-    def forward(self, model_inputs: ModelInput) -> Dict:
- (
- src,
- text_mask_src,
- feat_cat,
- ) = self.emb(model_inputs)
- z = self.enc(src, text_mask_src)
- z, vae_data = self.lf(z, text_mask_src)
- zd = self.dec(feat_cat, z, model_inputs)
- outs = self.head(zd, feat_cat)
- outs["vae_data"] = vae_data
- return outs
-
- def store(
- self,
- out_labels_all: Dict,
- out_labels: Dict,
- target_prefix_list: List,
- text_index: int,
- ) -> Dict:
- for prefix in target_prefix_list:
- out_labels_all[f"{prefix}"][text_index, 0] = out_labels[f"{prefix}"]
- return out_labels_all
-
- def prediction(
- self,
- model_inputs: ModelInput,
- target_prefix_list: List,
- start_index: int = 0,
-    ) -> Dict:
- target_text_num = int(model_inputs.canvas_text_num[0].item())
- start_index = min(start_index, target_text_num)
- for t in range(start_index, target_text_num):
- model_inputs.zeroinitialize_th_style_attributes(target_prefix_list, t)
- (
- src,
- text_mask_src,
- feat_cat,
- ) = self.emb(model_inputs)
- z = self.enc(src, text_mask_src)
- z = self.lf.prediction(z, text_mask_src)
- zd = self.dec(feat_cat, z, model_inputs)
- model_outs = self.head(zd, feat_cat)
-
- out_labels_all = {}
- for prefix in target_prefix_list:
- out_labels_all[prefix] = np.zeros((target_text_num, 1))
- for t in range(0, start_index):
- tar = getattr(model_inputs, f"{prefix}")[0, t].item()
- out_labels_all[prefix][t, 0] = tar
-
- for t in range(start_index, target_text_num):
- out_labels = self.get_labels(model_outs, target_prefix_list, t)
- out_labels_all = self.store(
- out_labels_all, out_labels, target_prefix_list, t
- )
- return out_labels_all
-
- def sample(
- self,
- model_inputs: ModelInput,
- target_prefix_list: List,
- start_index: int = 0,
- **kwargs: Any,
-    ) -> Dict:
- target_text_num = int(model_inputs.canvas_text_num[0].item())
- start_index = min(start_index, target_text_num)
- for t in range(start_index, target_text_num):
- model_inputs.zeroinitialize_th_style_attributes(target_prefix_list, t)
- (
- src,
- text_mask_src,
- feat_cat,
- ) = self.emb(model_inputs)
- z = self.enc(src, text_mask_src)
- z = self.lf.sample(z, text_mask_src)
- zd = self.dec(feat_cat, z, model_inputs)
- model_outs = self.head(zd, feat_cat)
- out_labels_all = {}
- for prefix in target_prefix_list:
- out_labels_all[prefix] = np.zeros((target_text_num, 1))
- for t in range(0, start_index):
- tar = getattr(model_inputs, f"{prefix}")[0, t].item()
- out_labels_all[prefix][t, 0] = tar
-
- for t in range(start_index, target_text_num):
- out_labels = self.get_labels(model_outs, target_prefix_list, t)
- out_labels_all = self.store(
- out_labels_all, out_labels, target_prefix_list, t
- )
- return out_labels_all
-
- def initialize_weights(self) -> None:
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.xavier_normal_(m.weight)
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.BatchNorm2d):
- m.weight.data.fill_(1)
- m.bias.data.zero_()
- elif isinstance(m, nn.Linear):
- m.weight.data.normal_(mean=0.0, std=0.02)
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.LayerNorm):
- if m.weight is not None:
- m.weight.data.fill_(1)
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.Embedding):
- m.weight.data.normal_(mean=0.0, std=0.02)
-
-
-class AllZero(Baseline):
- def __init__(
- self,
- prefix_list_element: List,
- prefix_list_canvas: List,
- prefix_list_target: List,
- embedding_config_element: TextElementContextEmbeddingAttributeConfig,
- embedding_config_canvas: CanvasContextEmbeddingAttributeConfig,
- prediction_config_element: TextElementContextPredictionAttributeConfig,
- d_model: int = 256,
- n_head: int = 8,
- dropout: float = 0.1,
- num_encoder_layers: int = 4,
- num_decoder_layers: int = 4,
- seq_length: int = 50,
- std_ratio: float = 1.0,
- **kwargs: Any,
- ) -> None:
- super().__init__(
- prefix_list_element,
- prefix_list_canvas,
- prefix_list_target,
- embedding_config_element,
- embedding_config_canvas,
- prediction_config_element,
- d_model,
- n_head,
- dropout,
- num_encoder_layers,
- num_decoder_layers,
- seq_length,
- std_ratio,
- )
-
- def get_labels(
- self,
- model_outs: Union[Dict, None],
- target_prefix_list: List,
- text_index: int,
- batch_index: int = 0,
- ) -> Dict:
- out_labels = {}
- for prefix in target_prefix_list:
- out_labels[f"{prefix}"] = 0
-
- return out_labels
-
-
-class AllRandom(Baseline):
- def __init__(
- self,
- prefix_list_element: List,
- prefix_list_canvas: List,
- prefix_list_target: List,
- embedding_config_element: TextElementContextEmbeddingAttributeConfig,
- embedding_config_canvas: CanvasContextEmbeddingAttributeConfig,
- prediction_config_element: TextElementContextPredictionAttributeConfig,
- d_model: int = 256,
- n_head: int = 8,
- dropout: float = 0.1,
- num_encoder_layers: int = 4,
- num_decoder_layers: int = 4,
- seq_length: int = 50,
- std_ratio: float = 1.0,
- **kwargs: Any,
- ) -> None:
- super().__init__(
- prefix_list_element,
- prefix_list_canvas,
- prefix_list_target,
- embedding_config_element,
- embedding_config_canvas,
- prediction_config_element,
- d_model,
- n_head,
- dropout,
- num_encoder_layers,
- num_decoder_layers,
- seq_length,
- std_ratio,
- )
-
- def get_labels(
- self,
- model_outs: Dict,
- target_prefix_list: List,
- text_index: int,
- batch_index: int = 0,
- ) -> Dict:
- out_labels = {}
- for prefix in target_prefix_list:
- out = model_outs[f"{prefix}"]
- label_range = len(out[batch_index][text_index])
- out_labels[f"{prefix}"] = random.randint(0, label_range - 1)
-
- return out_labels
-
-
-class Mode(Baseline):
- def __init__(
- self,
- prefix_list_element: List,
- prefix_list_canvas: List,
- prefix_list_target: List,
- embedding_config_element: TextElementContextEmbeddingAttributeConfig,
- embedding_config_canvas: CanvasContextEmbeddingAttributeConfig,
- prediction_config_element: TextElementContextPredictionAttributeConfig,
- d_model: int = 256,
- n_head: int = 8,
- dropout: float = 0.1,
- num_encoder_layers: int = 4,
- num_decoder_layers: int = 4,
- seq_length: int = 50,
- std_ratio: float = 1.0,
- **kwargs: Any,
- ) -> None:
- super().__init__(
- prefix_list_element,
- prefix_list_canvas,
- prefix_list_target,
- embedding_config_element,
- embedding_config_canvas,
- prediction_config_element,
- d_model,
- n_head,
- dropout,
- num_encoder_layers,
- num_decoder_layers,
- seq_length,
- std_ratio,
- )
-        with open("prefix2mode.pkl", "rb") as f:
-            self.prefix2mode = pickle.load(f)
-
- def get_labels(
- self,
- model_outs: Dict,
- target_prefix_list: List,
- text_index: int,
- batch_index: int = 0,
- ) -> Dict:
-        out_labels = {}
-        for prefix in target_prefix_list:
-            out_labels[f"{prefix}"] = self.prefix2mode[prefix]
-
- return out_labels
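
The `Mode` baseline above loads a `prefix2mode.pkl` mapping each target attribute to a single label. One plausible way to produce that file, sketched under the assumption that per-attribute training labels are available as plain lists (the `train_labels` dict here is illustrative):

```python
import pickle
from collections import Counter

# Hypothetical input: {attribute prefix: list of training labels}.
train_labels = {"text_font": [2, 2, 5, 2], "text_align_type": [0, 1, 0]}

# Most frequent label per attribute, i.e. the mode.
prefix2mode = {
    prefix: Counter(labels).most_common(1)[0][0]
    for prefix, labels in train_labels.items()
}

with open("prefix2mode.pkl", "wb") as f:
    pickle.dump(prefix2mode, f)
```
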
diff --git a/src/typography_generation/model/bottleneck.py b/src/typography_generation/model/bottleneck.py
deleted file mode 100644
index 0c7569e..0000000
--- a/src/typography_generation/model/bottleneck.py
+++ /dev/null
@@ -1,100 +0,0 @@
-from typing import Tuple
-
-import torch
-from torch import Tensor, nn
-
-
-class VAE(nn.Module):
- def __init__(
- self,
- d_model: int = 256,
- reparametric: bool = True,
- std_ratio: float = 1.0,
- ):
- super(VAE, self).__init__()
- self.d_model = d_model
- self.enc_mu_fcn = nn.Linear(d_model, d_model)
- self.enc_sigma_fcn = nn.Linear(d_model, d_model)
-        self.reparametric = reparametric
-        self.std_ratio = std_ratio
-
- def _init_embeddings(self) -> None:
- nn.init.normal_(self.enc_mu_fcn.weight, std=0.001)
- nn.init.constant_(self.enc_mu_fcn.bias, 0)
- nn.init.normal_(self.enc_sigma_fcn.weight, std=0.001)
- nn.init.constant_(self.enc_sigma_fcn.bias, 0)
-
- def forward(self, z: Tensor) -> Tuple:
- mu, logsigma = self.enc_mu_fcn(z), self.enc_sigma_fcn(z)
- sigma = torch.exp(logsigma / 2.0)
-        z = mu + sigma * torch.randn_like(sigma) * self.std_ratio
- return z, (mu, logsigma)
-
- def prediction(self, z: Tensor) -> Tensor:
- mu = self.enc_mu_fcn(z)
- z = mu
- return z
-
- def sample(self, z: Tensor) -> Tensor:
- mu, logsigma = self.enc_mu_fcn(z), self.enc_sigma_fcn(z)
- sigma = torch.exp(logsigma / 2.0)
-        z = mu + sigma * torch.randn_like(sigma) * self.std_ratio
- return z
-
-
-class ImlevelLF(nn.Module):
- def __init__(
- self,
- vae: bool = False,
- std_ratio: float = 1.0,
- ):
- super().__init__()
- self.vae_flag = vae
- if vae is True:
- self.vae = VAE(std_ratio=std_ratio)
-
- def forward(self, z: Tensor, text_mask: Tensor) -> Tensor:
- text_mask = (
- text_mask.permute(1, 0)
- .view(z.shape[0], z.shape[1], 1)
- .repeat(1, 1, z.shape[2])
- )
- z_tmp = torch.sum(z * text_mask, dim=0) / (torch.sum(text_mask, dim=0) + 1e-20)
-
- vae_item = []
- if self.vae_flag is True:
- z_tmp, vae_item_iml = self.vae(z_tmp)
- vae_item.append(vae_item_iml)
- return z_tmp.unsqueeze(0), vae_item
-
- def prediction(self, z: Tensor, text_mask: Tensor) -> Tensor:
- text_mask = (
- text_mask.permute(1, 0)
- .view(z.shape[0], z.shape[1], 1)
- .repeat(1, 1, z.shape[2])
- )
- z_tmp = torch.sum(z * text_mask, dim=0) / (torch.sum(text_mask, dim=0) + 1e-20)
-
- if self.vae_flag is True:
- z_tmp = self.vae.prediction(z_tmp)
- return z_tmp.unsqueeze(0)
-
- def sample(self, z: Tensor, text_mask: Tensor) -> Tensor:
- text_mask = (
- text_mask.permute(1, 0)
- .view(z.shape[0], z.shape[1], 1)
- .repeat(1, 1, z.shape[2])
- )
- z_tmp = torch.sum(z * text_mask, dim=0) / (torch.sum(text_mask, dim=0) + 1e-20)
-
- z_tmp = self.vae.sample(z_tmp)
- return z_tmp.unsqueeze(0)
-
-
-class Bottleneck(nn.Module):
- def __init__(self) -> None:
- super(Bottleneck, self).__init__()
- pass
-
- def forward(self, z: Tensor) -> Tensor:
- return z
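
`ImlevelLF` mean-pools the encoder sequence under the text mask and, when `vae=True`, applies the standard reparameterization trick from `VAE`. A self-contained sketch of those two steps:

```python
import torch

def masked_mean(z: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor:
    # z: (S, B, C); text_mask: (B, S) with 1.0 = valid, 0.0 = padded.
    m = text_mask.permute(1, 0).unsqueeze(-1)            # (S, B, 1)
    return (z * m).sum(dim=0) / (m.sum(dim=0) + 1e-20)   # (B, C)

def reparameterize(mu, logsigma, std_ratio=1.0):
    # z = mu + sigma * eps; training would add the usual KL term
    # -0.5 * (1 + logsigma - mu**2 - exp(logsigma)).sum().
    sigma = torch.exp(logsigma / 2.0)
    return mu + sigma * torch.randn_like(sigma) * std_ratio

z = torch.randn(50, 2, 256)
mask = torch.zeros(2, 50)
mask[0, :3] = 1  # first canvas has 3 text elements
mask[1, :5] = 1
pooled = masked_mean(z, mask)                            # (2, 256)
sampled = reparameterize(pooled, torch.zeros_like(pooled))
```
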
diff --git a/src/typography_generation/model/canvas_vae.py b/src/typography_generation/model/canvas_vae.py
deleted file mode 100644
index b05bb8f..0000000
--- a/src/typography_generation/model/canvas_vae.py
+++ /dev/null
@@ -1,207 +0,0 @@
-from typing import Any, Dict, List, Tuple
-import numpy as np
-
-import torch
-from logzero import logger
-from torch import Tensor, nn
-from typography_generation.config.attribute_config import (
- CanvasContextEmbeddingAttributeConfig,
- TextElementContextEmbeddingAttributeConfig,
- TextElementContextPredictionAttributeConfig,
-)
-from typography_generation.io.crello_util import CrelloProcessor
-from typography_generation.io.data_object import ModelInput
-from typography_generation.model.bottleneck import ImlevelLF
-from typography_generation.model.decoder import Decoder, MultiTask
-from typography_generation.model.embedding import Embedding
-from typography_generation.model.encoder import Encoder
-
-
-class CanvasVAE(nn.Module):
- def __init__(
- self,
- prefix_list_element: List,
- prefix_list_canvas: List,
- prefix_list_target: List,
- embedding_config_element: TextElementContextEmbeddingAttributeConfig,
- embedding_config_canvas: CanvasContextEmbeddingAttributeConfig,
- prediction_config_element: TextElementContextPredictionAttributeConfig,
- d_model: int = 256,
- n_head: int = 8,
- dropout: float = 0.1,
- num_encoder_layers: int = 4,
- num_decoder_layers: int = 4,
- seq_length: int = 50,
- std_ratio: float = 1.0,
- **kwargs: Any,
- ) -> None:
- super().__init__()
- logger.info(f"CanvasVAE settings")
- logger.info(f"d_model: {d_model}")
- logger.info(f"n_head: {n_head}")
- logger.info(f"num_encoder_layers: {num_encoder_layers}")
- logger.info(f"num_decoder_layers: {num_decoder_layers}")
- logger.info(f"seq_length: {seq_length}")
-
- self.emb = Embedding(
- prefix_list_element,
- prefix_list_canvas,
- embedding_config_element,
- embedding_config_canvas,
- d_model,
- dropout,
- seq_length,
- )
- self.enc = Encoder(
- d_model=d_model,
- n_head=n_head,
- dropout=dropout,
- num_encoder_layers=num_encoder_layers,
- )
- self.lf = ImlevelLF(vae=True, std_ratio=std_ratio)
-
- self.dec = Decoder(
- prefix_list_target,
- embedding_config_element,
- d_model=d_model,
- n_head=n_head,
- dropout=dropout,
- num_decoder_layers=num_decoder_layers,
- seq_length=seq_length,
- autoregressive_scheme=False,
- )
- self.head = MultiTask(
- prefix_list_target, prediction_config_element, d_model, bypass=False
- )
- self.initialize_weights()
-
-    def forward(self, model_inputs: ModelInput) -> Dict:
- (
- src,
- text_mask_src,
- feat_cat,
- ) = self.emb(model_inputs)
- z = self.enc(src, text_mask_src)
- z, vae_data = self.lf(z, text_mask_src)
- zd = self.dec(feat_cat, z, model_inputs)
- outs = self.head(zd, feat_cat)
- outs["vae_data"] = vae_data
- return outs
-
- def get_labels(
- self,
- model_outs: Dict,
- target_prefix_list: List,
- text_index: int,
- batch_index: int = 0,
- ) -> Dict:
- out_labels = {}
- for prefix in target_prefix_list:
- out = model_outs[f"{prefix}"]
- sorted_label = torch.sort(input=out[batch_index], dim=1, descending=True)[1]
- target_label = 0 # top1
- out_label = sorted_label[text_index][target_label]
- out_labels[f"{prefix}"] = out_label
-
- return out_labels
-
- def store(
- self,
- out_labels_all: Dict,
- out_labels: Dict,
- target_prefix_list: List,
- text_index: int,
- ) -> Dict:
- for prefix in target_prefix_list:
- out_labels_all[f"{prefix}"][text_index] = out_labels[f"{prefix}"].item()
- return out_labels_all
-
- def prediction(
- self,
- model_inputs: ModelInput,
- dataset: CrelloProcessor,
- target_prefix_list: List,
- start_index: int = 0,
-    ) -> Dict:
- target_text_num = int(model_inputs.canvas_text_num[0].item())
- start_index = min(start_index, target_text_num)
- for t in range(start_index, target_text_num):
- model_inputs.zeroinitialize_th_style_attributes(target_prefix_list, t)
- (
- src,
- text_mask_src,
- feat_cat,
- ) = self.emb(model_inputs)
- z = self.enc(src, text_mask_src)
- z = self.lf.prediction(z, text_mask_src)
- zd = self.dec(feat_cat, z, model_inputs)
- model_outs = self.head(zd, feat_cat)
-
- out_labels_all = {}
- for prefix in target_prefix_list:
- out_labels_all[prefix] = np.zeros((target_text_num, 1))
- for t in range(0, start_index):
- tar = getattr(model_inputs, f"{prefix}")[0, t].item()
- out_labels_all[prefix][t, 0] = tar
-
- for t in range(start_index, target_text_num):
- out_labels = self.get_labels(model_outs, target_prefix_list, t)
- out_labels_all = self.store(
- out_labels_all, out_labels, target_prefix_list, t
- )
- return out_labels_all
-
- def sample(
- self,
- model_inputs: ModelInput,
- target_prefix_list: List,
- start_index: int = 0,
- **kwargs: Any,
-    ) -> Dict:
- target_text_num = int(model_inputs.canvas_text_num[0].item())
- start_index = min(start_index, target_text_num)
- for t in range(start_index, target_text_num):
- model_inputs.zeroinitialize_th_style_attributes(target_prefix_list, t)
- (
- src,
- text_mask_src,
- feat_cat,
- ) = self.emb(model_inputs)
- z = self.enc(src, text_mask_src)
- z = self.lf.sample(z, text_mask_src)
- zd = self.dec(feat_cat, z, model_inputs)
- model_outs = self.head(zd, feat_cat)
- out_labels_all = {}
- for prefix in target_prefix_list:
- out_labels_all[prefix] = np.zeros((target_text_num, 1))
- for t in range(0, start_index):
- tar = getattr(model_inputs, f"{prefix}")[0, t].item()
- out_labels_all[prefix][t, 0] = tar
-
- for t in range(start_index, target_text_num):
- out_labels = self.get_labels(model_outs, target_prefix_list, t)
- out_labels_all = self.store(
- out_labels_all, out_labels, target_prefix_list, t
- )
- return out_labels_all
-
- def initialize_weights(self) -> None:
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.xavier_normal_(m.weight)
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.BatchNorm2d):
- m.weight.data.fill_(1)
- m.bias.data.zero_()
- elif isinstance(m, nn.Linear):
- m.weight.data.normal_(mean=0.0, std=0.02)
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.LayerNorm):
- if m.weight is not None:
- m.weight.data.fill_(1)
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.Embedding):
- m.weight.data.normal_(mean=0.0, std=0.02)
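
`get_labels` above decodes top-1 by fully sorting the logits; that is equivalent to an argmax, which a quick check confirms:

```python
import torch

out = torch.randn(1, 50, 7)  # (batch, seq, num_labels)
batch_index, text_index = 0, 3
via_sort = torch.sort(out[batch_index], dim=1, descending=True)[1][text_index][0]
via_argmax = out[batch_index, text_index].argmax()
assert bool(via_sort == via_argmax)
```
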
diff --git a/src/typography_generation/model/common.py b/src/typography_generation/model/common.py
deleted file mode 100644
index 30296b2..0000000
--- a/src/typography_generation/model/common.py
+++ /dev/null
@@ -1,446 +0,0 @@
-import copy
-from collections import OrderedDict
-from typing import Any, Callable, Optional, Union
-
-import torch
-from torch import Tensor, nn
-from torch.functional import F
-
-
-def _get_clones(module: Any, N: int) -> nn.ModuleList:
- return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
-
-
-def _get_activation_fn(activation: Any) -> Any:
- if activation == "relu":
- return F.relu
- elif activation == "gelu":
- return F.gelu
- raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
-
-
-def _get_seq_len(src: Tensor, batch_first: bool) -> Optional[int]:
- if src.is_nested:
- return None
- else:
- src_size = src.size()
- if len(src_size) == 2:
- # unbatched: S, E
- return src_size[0]
- else:
- # batched: B, S, E if batch_first else S, B, E
- seq_len_pos = 1 if batch_first else 0
- return src_size[seq_len_pos]
-
-
-def _generate_square_subsequent_mask(
- sz: int,
- device: torch.device = torch.device(
- torch._C._get_default_device()
- ), # torch.device('cpu'),
- dtype: torch.dtype = torch.get_default_dtype(),
-) -> Tensor:
- r"""Generate a square causal mask for the sequence. The masked positions are filled with float('-inf').
- Unmasked positions are filled with float(0.0).
- """
- return torch.triu(
- torch.full((sz, sz), float("-inf"), dtype=dtype, device=device),
- diagonal=1,
- )
-
-
-def _detect_is_causal_mask(
- mask: Optional[Tensor],
- is_causal: Optional[bool] = None,
- size: Optional[int] = None,
-) -> bool:
- # Prevent type refinement
- make_causal = is_causal is True
-
- if is_causal is None and mask is not None:
- sz = size if size is not None else mask.size(-2)
- causal_comparison = _generate_square_subsequent_mask(
- sz, device=mask.device, dtype=mask.dtype
- )
-
- # Do not use `torch.equal` so we handle batched masks by
- # broadcasting the comparison.
- if mask.size() == causal_comparison.size():
- make_causal = bool((mask == causal_comparison).all())
- else:
- make_causal = False
-
- return make_causal
-
-
-class MyTransformerDecoder(nn.Module):
- __constants__ = ["norm"]
-
- def __init__(self, decoder_layer, num_layers, norm=None):
- super().__init__()
- torch._C._log_api_usage_once(f"torch.nn.modules.{self.__class__.__name__}")
- self.layers = _get_clones(decoder_layer, num_layers)
- self.num_layers = num_layers
- self.norm = norm
-
- def forward(
- self,
- tgt: Tensor,
- memory: Tensor,
- tgt_mask: Optional[Tensor] = None,
- memory_mask: Optional[Tensor] = None,
- tgt_key_padding_mask: Optional[Tensor] = None,
- memory_key_padding_mask: Optional[Tensor] = None,
- get_weight: bool = False,
- ) -> Tensor:
- output = tgt
-
- seq_len = _get_seq_len(tgt, self.layers[0].self_attn.batch_first)
-
- for mod in self.layers:
- output, w = mod(
- output,
- memory,
- tgt_mask=tgt_mask,
- memory_mask=memory_mask,
- tgt_key_padding_mask=tgt_key_padding_mask,
- memory_key_padding_mask=memory_key_padding_mask,
- )
-
- if self.norm is not None:
- output = self.norm(output)
-
- if get_weight is True:
- return output, w
- else:
- return output
-
-
-class MyTransformerDecoderLayer(nn.Module):
- __constants__ = ["batch_first"]
-
- def __init__(
- self,
- d_model: float = 256,
- nhead: int = 8,
- dim_feedforward: int = 2048,
- dropout: float = 0.1,
- activation: str = "relu",
- layer_norm_eps: float = 1e-5,
- batch_first: bool = False,
- device: Any = None,
- dtype: Any = None,
- ) -> None:
- factory_kwargs = {"device": device, "dtype": dtype}
- super().__init__()
- self.self_attn = nn.MultiheadAttention(
- d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs
- )
- self.multihead_attn = nn.MultiheadAttention(
- d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs
- )
- # Implementation of Feedforward model
- self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
- self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
- self.dropout = nn.Dropout(dropout)
- self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
- self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
- self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
- self.dropout1 = nn.Dropout(dropout)
- self.dropout2 = nn.Dropout(dropout)
- self.dropout3 = nn.Dropout(dropout)
-
- self.activation = _get_activation_fn(activation)
-
- def __setstate__(self, state: Any) -> None:
- if "activation" not in state:
- state["activation"] = F.relu
- super().__setstate__(state)
-
- def forward(
- self,
- tgt: torch.Tensor,
- memory: torch.Tensor,
- tgt_mask: Optional[torch.Tensor] = None,
- memory_mask: Optional[torch.Tensor] = None,
- tgt_key_padding_mask: Optional[torch.Tensor] = None,
- memory_key_padding_mask: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- tgt2 = self.self_attn(
- tgt, tgt, tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
- )[0]
- tgt = tgt + self.dropout1(tgt2)
- tgt = self.norm1(tgt)
- tgt2, weight = self.multihead_attn(
- tgt,
- memory,
- memory,
- attn_mask=memory_mask,
- key_padding_mask=memory_key_padding_mask,
- )
- tgt = tgt + self.dropout2(tgt2)
- tgt = self.norm2(tgt)
- tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
- tgt = tgt + self.dropout3(tgt2)
- tgt = self.norm3(tgt)
- return tgt, weight
-
-
-class MyTransformerEncoder(nn.Module):
- __constants__ = ["norm"]
-
- def __init__(
- self, encoder_layer: nn.Module, num_layers: int, norm: Any = None
- ) -> None:
- super(MyTransformerEncoder, self).__init__()
- self.layers = _get_clones(encoder_layer, num_layers)
- self.num_layers = num_layers
- self.norm = norm
-
- def forward(
- self,
- src: Tensor,
- mask: Optional[Tensor] = None,
- src_key_padding_mask: Optional[Tensor] = None,
- get_weight: bool = False,
- ) -> Tensor:
- output = src
- for mod in self.layers:
- output, w = mod(
- output,
- src_mask=mask,
- src_key_padding_mask=src_key_padding_mask,
- )
-
- if self.norm is not None:
- output = self.norm(output)
- if get_weight is True:
- return output, w
- else:
- return output
-
-
-class MyTransformerEncoderLayer(nn.Module):
- __constants__ = ["batch_first"]
-
- def __init__(
- self,
- d_model: int,
- nhead: int,
- dim_feedforward: int = 2048,
- dropout: float = 0.1,
- activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
- layer_norm_eps: float = 1e-5,
- batch_first: bool = False,
- norm_first: bool = False,
- device: Any = None,
- dtype: Any = None,
- ) -> None:
- factory_kwargs = {"device": device, "dtype": dtype}
- super(MyTransformerEncoderLayer, self).__init__()
- self.self_attn = nn.MultiheadAttention(
- d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs
- )
- # Implementation of Feedforward model
- self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
- self.dropout = nn.Dropout(dropout)
- self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
-
- self.norm_first = norm_first
- self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
- self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
- self.dropout1 = nn.Dropout(dropout)
- self.dropout2 = nn.Dropout(dropout)
-
- # Legacy string support for activation function.
- if isinstance(activation, str):
- self.activation = _get_activation_fn(activation)
- else:
- self.activation = activation
-
- def __setstate__(self, state: Any) -> None:
- if "activation" not in state:
- state["activation"] = F.relu
- super(MyTransformerEncoderLayer, self).__setstate__(state)
-
- def forward(
- self,
- src: Tensor,
- src_mask: Optional[Tensor] = None,
- src_key_padding_mask: Optional[Tensor] = None,
- ) -> Tensor:
- x = src
- if self.norm_first:
- sa, w = self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
- x = x + sa
- x = x + self._ff_block(self.norm2(x))
- else:
- sa, w = self._sa_block(x, src_mask, src_key_padding_mask)
- x = self.norm1(x + sa)
- x = self.norm2(x + self._ff_block(x))
-
- return x, w
-
- # self-attention block
- def _sa_block(
- self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]
- ) -> Tensor:
- x, w = self.self_attn(
- x,
- x,
- x,
- attn_mask=attn_mask,
- key_padding_mask=key_padding_mask,
- need_weights=True,
- )
- return self.dropout1(x), w
-
- # feed forward block
- def _ff_block(self, x: Tensor) -> Tensor:
- x = self.linear2(self.dropout(self.activation(self.linear1(x))))
- return self.dropout2(x)
-
-
-def _conv3x3_bn_relu(
- in_channels: int,
- out_channels: int,
- dilation: int = 1,
- kernel_size: int = 3,
- stride: int = 1,
-) -> nn.Sequential:
- if dilation == 0:
- dilation = 1
- padding = 0
- else:
- padding = dilation
- return nn.Sequential(
- OrderedDict(
- [
- (
- "conv",
- nn.Conv2d(
- in_channels,
- out_channels,
- kernel_size=kernel_size,
- stride=stride,
- padding=padding,
- dilation=dilation,
- bias=False,
- ),
- ),
- ("bn", nn.BatchNorm2d(out_channels)),
- ("relu", nn.ReLU()),
- ]
- )
- )
-
-
-class LinearView(nn.Module):
- def __init__(self, channel: int) -> None:
- super(LinearView, self).__init__()
- self.channel = channel
-
- def forward(self, x: Tensor) -> Tensor:
- if len(x.shape) == 3:
- BN, CHARN, CN = x.shape
- if CHARN == 1:
- x = x.view(BN, self.channel)
- return x
-
-
-def fn_ln_relu(fn: Any, out_channels: int, dp: float = 0.1) -> nn.Sequential:
- return nn.Sequential(
- OrderedDict(
- [
- ("fn", fn),
- ("ln", nn.LayerNorm(out_channels)),
- ("view", LinearView(out_channels)),
- ("dp", nn.Dropout(dp)),
- ]
- )
- )
-
-
-class Linearx2(nn.Module):
- def __init__(self, in_ch: int, out_ch: int) -> None:
- super().__init__()
- self.fcn = nn.Sequential(
- OrderedDict(
- [
- ("l1", nn.Linear(in_ch, in_ch)),
- ("r1", nn.LeakyReLU(0.2)),
- ("l2", nn.Linear(in_ch, out_ch)),
- ]
- )
- )
-
- def forward(self, out: Tensor) -> Tensor:
-        logits = self.fcn(out)  # shape (..., out_ch)
- return logits
-
-
-class Linearx3(nn.Module):
- def __init__(self, in_ch: int, out_ch: int) -> None:
- super().__init__()
- self.fcn = nn.Sequential(
- OrderedDict(
- [
- ("l1", nn.Linear(in_ch, in_ch)),
- ("r1", nn.LeakyReLU(0.2)),
- ("l1", nn.Linear(in_ch, in_ch)),
- ("r1", nn.LeakyReLU(0.2)),
- ("l2", nn.Linear(in_ch, out_ch)),
- ]
- )
- )
-
- def forward(self, out: Tensor) -> Tensor:
-        logits = self.fcn(out)  # shape (..., out_ch)
- return logits
-
-
-class ConstEmbedding(nn.Module):
- def __init__(
- self, d_model: int = 256, seq_len: int = 50, positional_encoding: bool = True
- ):
- super().__init__()
- self.d_model = d_model
- self.seq_len = seq_len
- self.PE = PositionalEncodingLUT(
- d_model, max_len=seq_len, positional_encoding=positional_encoding
- )
-
-    def forward(self, z: Tensor) -> Tensor:
-        if len(z.shape) == 2:
-            N = z.size(0)
-        elif len(z.shape) == 3:
-            N = z.size(1)
-        else:
-            raise ValueError(f"unexpected z shape: {tuple(z.shape)}")
-        pos = self.PE(z.new_zeros(self.seq_len, N, self.d_model))
-        return pos
-
-
-class PositionalEncodingLUT(nn.Module):
- def __init__(
- self,
- d_model: int = 256,
- dropout: float = 0.1,
- max_len: int = 50,
- positional_encoding: bool = True,
- ):
- super(PositionalEncodingLUT, self).__init__()
- self.PS = positional_encoding
- self.dropout = nn.Dropout(p=dropout)
- position = torch.arange(0, max_len, dtype=torch.long).unsqueeze(1)
- self.register_buffer("position", position)
- self.pos_embed = nn.Embedding(max_len, d_model)
- self._init_embeddings()
-
- def _init_embeddings(self) -> None:
- nn.init.kaiming_normal_(self.pos_embed.weight, mode="fan_in")
-
- def forward(self, x: Tensor, inp_ignore: bool = False) -> Tensor:
- pos = self.position[: x.size(0)]
- x = self.pos_embed(pos).repeat(1, x.size(1), 1)
- return self.dropout(x)
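
`_generate_square_subsequent_mask` builds the usual additive causal mask: `-inf` strictly above the diagonal, so position i can only attend to positions <= i. For `sz = 4` it produces:

```python
import torch

sz = 4
mask = torch.triu(torch.full((sz, sz), float("-inf")), diagonal=1)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
```
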
diff --git a/src/typography_generation/model/decoder.py b/src/typography_generation/model/decoder.py
deleted file mode 100644
index a2b33ff..0000000
--- a/src/typography_generation/model/decoder.py
+++ /dev/null
@@ -1,210 +0,0 @@
-from typing import Any, List, Optional, Tuple
-
-import torch
-from torch import Tensor, nn
-from typography_generation.config.attribute_config import (
- TextElementContextEmbeddingAttributeConfig,
- TextElementContextPredictionAttributeConfig,
-)
-from typography_generation.io.data_object import ModelInput
-from typography_generation.model.common import (
- ConstEmbedding,
- Linearx2,
- MyTransformerDecoder,
- MyTransformerDecoderLayer,
- fn_ln_relu,
-)
-
-
-class Decoder(nn.Module):
- def __init__(
- self,
- prefix_list_target: List,
- embedding_config_element: TextElementContextEmbeddingAttributeConfig,
- d_model: int = 256,
- n_head: int = 8,
- dropout: float = 0.1,
- num_decoder_layers: int = 4,
- seq_length: int = 50,
- positional_encoding: bool = True,
- autoregressive_scheme: bool = True,
- ):
- super().__init__()
-
- self.prefix_list_target = prefix_list_target
- self.d_model = d_model
- self.dropout_element = dropout
- self.seq_length = seq_length
- self.autoregressive_scheme = autoregressive_scheme
- dim_feedforward = d_model * 2
-
- # Positional encoding in transformer
- position = torch.arange(0, 1, dtype=torch.long).unsqueeze(1)
- self.register_buffer("position", position)
- self.pos_embed = nn.Embedding(1, d_model)
- self.embedding = ConstEmbedding(d_model, seq_length, positional_encoding)
-
-        # Decoder layer
-        decoder_norm = nn.LayerNorm(d_model)
- decoder_layer = MyTransformerDecoderLayer(
- d_model, n_head, dim_feedforward, dropout
- )
- self.decoder = MyTransformerDecoder(
- decoder_layer, num_decoder_layers, decoder_norm
- )
-
- # mask config in training
- mask_tgt = torch.ones((seq_length, seq_length))
- mask_tgt = torch.triu(mask_tgt == 1).transpose(0, 1)
- self.mask_tgt = mask_tgt.float().masked_fill(mask_tgt == 0, float("-inf"))
- # Build bart embedding
- self.build_bart_embedding(prefix_list_target, embedding_config_element)
-
- def build_bart_embedding(
- self,
- prefix_list_target: List,
- embedding_config_element: TextElementContextEmbeddingAttributeConfig,
- ) -> None:
- for prefix in prefix_list_target:
- target_embedding_config = getattr(embedding_config_element, prefix)
- kwargs = target_embedding_config.emb_layer_kwargs # dict of args
- emb_layer = nn.Embedding(**kwargs)
- emb_layer = fn_ln_relu(emb_layer, self.d_model, self.dropout_element)
- setattr(self, f"{prefix}_emb", emb_layer)
- setattr(self, f"{prefix}_flag", target_embedding_config.flag)
-
- def get_features_via_fn(
- self, fn: Any, inputs: List, batch_num: int, text_num: Tensor
- ) -> Tensor:
- inputs_fn = []
- for b in range(batch_num):
- tn = int(text_num[b].item())
- for t in range(tn):
- inputs_fn.append(inputs[b][t].view(-1))
- outs = None
- if len(inputs_fn) > 0:
- inputs_fn = torch.stack(inputs_fn)
- outs = fn(inputs_fn)
- outs = outs.view(len(inputs_fn), self.d_model)
- feat = torch.zeros(self.seq_length, batch_num, self.d_model)
- feat = feat.to(text_num.device).float()
- cnt = 0
- for b in range(batch_num):
- tn = int(text_num[b].item())
- for t in range(tn):
- if outs is not None:
- feat[t, b] = outs[cnt]
- cnt += 1
- return feat
-
- def get_style_context_embedding(
- self, model_inputs: ModelInput, batch_num: int, text_num: Tensor
- ) -> Tensor:
- feat = torch.zeros(self.seq_length, batch_num, self.d_model)
- feat = feat.to(text_num.device).float()
- for prefix in self.prefix_list_target:
- inp = getattr(model_inputs, prefix).long()
- layer = getattr(self, f"{prefix}_emb")
- f = self.get_features_via_fn(layer, inp, batch_num, text_num)
- feat = feat + f
- feat = feat / len(self.prefix_list_target)
- return feat
-
- def shift_context_feat(
- self, context_feat: Tensor, batch_num: int, text_num: Tensor
- ) -> Tensor:
- shifted_context_feat = torch.zeros(self.seq_length, batch_num, self.d_model)
- shifted_context_feat = shifted_context_feat.to(text_num.device).float()
- for b in range(batch_num):
- tn = int(text_num[b].item())
- for t in range(tn - 1):
- shifted_context_feat[t + 1, b] = context_feat[t, b]
- return shifted_context_feat
-
- def get_bart_embedding(self, src: Tensor, model_inputs: ModelInput) -> Tensor:
- batch_num, text_num = model_inputs.batch_num, model_inputs.canvas_text_num
- context_feat = self.get_style_context_embedding(
- model_inputs, batch_num, text_num
- )
- shifted_context_feat = self.shift_context_feat(
- context_feat, batch_num, text_num
- )
- position_feat = self.embedding(src)
- return shifted_context_feat + position_feat
-
- def forward(
- self,
- src: Tensor,
- z: Tensor,
- model_inputs: ModelInput,
- ) -> Tensor:
- if self.autoregressive_scheme is True:
- src = self.get_bart_embedding(src, model_inputs)
- else:
- src = self.embedding(src)
- mask_tgt = self.mask_tgt.to(src.device)
- out = self.decoder(src, z, mask_tgt, tgt_key_padding_mask=None)
- return out
-
- def get_transformer_weight(
- self,
- src: Tensor,
- z: Tensor,
- model_inputs: ModelInput,
- ) -> Tuple:
- if self.autoregressive_scheme is True:
- src = self.get_bart_embedding(src, model_inputs)
- else:
- src = self.embedding(src)
- mask_tgt = self.mask_tgt.to(src.device)
- out, weights = self.decoder(
- src, z, mask_tgt, tgt_key_padding_mask=None, get_weight=True
- )
- return out, weights
-
-
-class FCN(nn.Module):
- def __init__(self, d_model: int, label_num: int, bypass: bool) -> None:
- super().__init__()
- self.bypass = bypass
- if bypass is True:
- self.fcn = Linearx2(d_model * 2, label_num)
- else:
- self.fcn = Linearx2(d_model, label_num)
-
-    def forward(self, inp: Tensor, elm_agg_emb: Optional[Tensor] = None) -> Tensor:
-        if self.bypass is True:
-            inp = torch.cat((inp, elm_agg_emb), 2)
-        logits = inp.permute(1, 0, 2)  # (S, B, C) -> (B, S, C)
-        logits = self.fcn(logits)  # (B, S, out_dim)
-        return logits
-
-
-class MultiTask(nn.Module):
- def __init__(
- self,
- prefix_list_target: List,
- prediction_config: TextElementContextPredictionAttributeConfig,
- d_model: int = 256,
- bypass: bool = True,
- ):
- super().__init__()
- self.prefix_list_target = prefix_list_target
- for prefix in self.prefix_list_target:
- target_prediction_config = getattr(prediction_config, prefix)
- layer = FCN(d_model, target_prediction_config.out_dim, bypass)
- setattr(self, f"{prefix}_layer", layer)
-
- def forward(self, z: Tensor, elm_agg_emb: Tensor) -> dict:
- outputs = {}
- for prefix in self.prefix_list_target:
- layer = getattr(self, f"{prefix}_layer")
- out = layer(z, elm_agg_emb)
- outputs[prefix] = out
- return outputs
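
`shift_context_feat` implements BART-style teacher forcing: each position sees the previous element's style embedding, with position 0 zero-filled. The double loop can be written as a single slice shift; a sketch, assuming (as `get_features_via_fn` guarantees) that padded rows of `context_feat` are already zero:

```python
import torch

context_feat = torch.randn(50, 2, 256)   # (S, B, C)
text_num = torch.tensor([3, 5])          # valid elements per canvas

shifted = torch.zeros_like(context_feat)
shifted[1:] = context_feat[:-1]          # shift right by one position
# Re-zero rows at/after each canvas's text_num so the first padded row
# does not inherit the last valid element's embedding.
s = torch.arange(context_feat.size(0)).unsqueeze(1)            # (S, 1)
shifted = shifted * (s < text_num.view(1, -1)).float().unsqueeze(-1)
```
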
diff --git a/src/typography_generation/model/embedding.py b/src/typography_generation/model/embedding.py
deleted file mode 100644
index 725ddfa..0000000
--- a/src/typography_generation/model/embedding.py
+++ /dev/null
@@ -1,407 +0,0 @@
-from collections import OrderedDict
-import time
-from typing import Any, List, Tuple
-from einops import rearrange, repeat
-from logzero import logger
-import torch
-from torch import Tensor, nn
-from torch.functional import F
-from typography_generation.config.attribute_config import (
- CanvasContextEmbeddingAttributeConfig,
- TextElementContextEmbeddingAttributeConfig,
- EmbeddingConfig,
-)
-from typography_generation.io.data_object import ModelInput
-from typography_generation.model.common import (
- _conv3x3_bn_relu,
- fn_ln_relu,
-)
-
-
-def set_tensor_type(inp: Tensor, tensor_type: str) -> Tensor:
- if tensor_type == "float":
- inp = inp.float()
- elif tensor_type == "long":
- inp = inp.long()
- else:
- raise NotImplementedError()
- return inp
-
-
-def setup_emb_layer(
- self: nn.Module, prefix: str, target_embedding_config: EmbeddingConfig
-) -> None:
- if target_embedding_config.emb_layer is not None:
- kwargs = target_embedding_config.emb_layer_kwargs # dict of args
- if target_embedding_config.emb_layer == "nn.Embedding":
- emb_layer = nn.Embedding(**kwargs)
- setattr(self, f"{prefix}_type", "long")
- elif target_embedding_config.emb_layer == "nn.Linear":
- emb_layer = nn.Linear(**kwargs)
- setattr(self, f"{prefix}_type", "float")
- else:
- raise NotImplementedError()
- emb_layer = fn_ln_relu(emb_layer, self.d_model, self.dropout)
- setattr(self, f"{prefix}_emb", emb_layer)
- else:
- setattr(self, f"{prefix}_type", "float")
- setattr(self, f"{prefix}_flag", target_embedding_config.flag)
- setattr(self, f"{prefix}_specific", target_embedding_config.specific_func)
- setattr(self, f"{prefix}_inp", target_embedding_config.input_prefix)
-
-
-def get_output(
- self: nn.Module, prefix: str, model_inputs: ModelInput, batch_num: int
-) -> Tensor:
- inp_prefix = getattr(self, f"{prefix}_inp")
- specific_func = getattr(self, f"{prefix}_specific")
- tensor_type = getattr(self, f"{prefix}_type")
- inp = getattr(model_inputs, inp_prefix)
- text_num = getattr(model_inputs, "canvas_text_num")
- inp = set_tensor_type(inp, tensor_type)
- if specific_func is not None:
- fn = getattr(self, specific_func)
- out = fn(
- inputs=inp,
- batch_num=batch_num,
- text_num=text_num,
- )
- else:
- fn = getattr(self, f"{prefix}_emb")
- out = self.get_features_via_fn(
- fn=fn, inputs=inp, batch_num=batch_num, text_num=text_num
- )
- return out
-
-
-class Down(nn.Module):
- def __init__(self, input_dim: int, output_dim: int) -> None:
- super(Down, self).__init__()
- self.conv1 = _conv3x3_bn_relu(input_dim, output_dim, kernel_size=3)
- self.drop = nn.Dropout()
-
- def forward(self, feat: Tensor) -> Tensor:
- feat = self.conv1(feat)
- return self.drop(feat)
-
-
-class Embedding(nn.Module):
- def __init__(
- self,
- prefix_list_element: List,
- prefix_list_canvas: List,
- embedding_config_element: TextElementContextEmbeddingAttributeConfig,
- embedding_config_canvas: CanvasContextEmbeddingAttributeConfig,
- d_model: int = 256,
- dropout: float = 0.1,
- seq_length: int = 50,
- ) -> None:
- super(Embedding, self).__init__()
- self.emb_element = EmbeddingElementContext(
- prefix_list_element, embedding_config_element, d_model, dropout, seq_length
- )
- self.emb_canvas = EmbeddingCanvasContext(
- prefix_list_canvas,
- embedding_config_canvas,
- d_model,
- dropout,
- )
- self.prefix_list_element = prefix_list_element
- self.prefix_list_canvas = prefix_list_canvas
-
- self.d_model = d_model
- self.seq_length = seq_length
-
- mlp_dim = (
- self.compute_modality_num(embedding_config_canvas, embedding_config_element)
- * d_model
- )
- self.mlp = nn.Sequential(
- OrderedDict(
- [
- ("l1", nn.Linear(mlp_dim, d_model)),
- ("r1", nn.LeakyReLU(0.2)),
- ("l2", nn.Linear(d_model, d_model)),
- ("r2", nn.LeakyReLU(0.2)),
- ]
- )
- )
- self.mlp_dim = mlp_dim
-
- def compute_modality_num(
- self,
- embedding_config_canvas: CanvasContextEmbeddingAttributeConfig,
- embedding_config_text: TextElementContextEmbeddingAttributeConfig,
- ) -> int:
- modality_num = 0
- for prefix in self.prefix_list_element:
- target_embedding_config = getattr(embedding_config_text, prefix)
- flag = target_embedding_config.flag
- if flag is False:
- pass
- else:
- modality_num += 1
- for prefix in self.prefix_list_canvas:
- target_embedding_config = getattr(embedding_config_canvas, prefix)
- flag = target_embedding_config.flag
- if flag is True:
- modality_num += 1
- return modality_num
-
-    def flatten_feature(
-        self,
-        feats: List,
-    ) -> List:
-        # shallow copy of the per-modality feature list
-        return list(feats)
-
- def reshape_tbc(
- self,
- feat_elements: List, # modality num(M) x [B, S, C]
- feat_canvas: List, # canvas num(CV) x [B, C]
- batch_num: int,
- text_num: Tensor,
- ) -> Tensor:
- feat_elements = torch.stack(feat_elements) # M, B, S, C
- feat_elements = rearrange(feat_elements, "m b s c -> s b (m c)")
- feat_canvas = torch.stack(feat_canvas) # CV, B, C
- feat_canvas = rearrange(feat_canvas, "cv b c -> 1 b (cv c)")
- feat_canvas = repeat(feat_canvas, "1 b c -> s b c", s=self.seq_length)
- feat = torch.cat((feat_canvas, feat_elements), dim=2)
- feat = self.mlp(feat)
- return feat
-
- def get_transformer_inputs(
- self,
- feat_elements: List[Tensor],
- feat_canvas: Tensor,
- batch_num: int,
- text_num: Tensor,
- ) -> Tuple[Tensor, Tensor]:
- feat_element, text_mask_element = self.lineup_features(
- feat_elements, batch_num, text_num
- )
- feat_canvas = torch.stack(feat_canvas)
- feat_canvas = feat_canvas.view(
- feat_canvas.shape[0], feat_element.shape[1], feat_element.shape[2]
- )
- src = torch.cat((feat_canvas, feat_element), dim=0)
-        canvas_mask = torch.ones(src.shape[1], feat_canvas.shape[0])
-        canvas_mask = canvas_mask.float().to(text_num.device)
- text_mask_src = torch.cat((canvas_mask, text_mask_element), dim=1)
- return src, text_mask_src
-
-    def lineup_features(
-        self, feats: List, batch_num: int, text_num: Tensor
-    ) -> Tuple[Tensor, Tensor]:
-        modality_num = len(feats)
- device = text_num.device
- feat = (
- torch.zeros(self.seq_length, modality_num, batch_num, self.d_model)
- .float()
- .to(device)
- )
- for m in range(modality_num):
- logger.debug(f"{feats[m].shape=}")
- feat[:, m, :, :] = rearrange(feats[m], "b t c -> t b c")
- feat = rearrange(feat, "t m b c -> (t m) b c")
- indices = rearrange(torch.arange(self.seq_length), "s -> 1 s").to(device)
- mask = (
- (indices < text_num).to(device).float()
- ) # (B, S), indicating valid attribute locations
- text_mask = repeat(mask, "b s -> b (s m)", m=modality_num)
- return feat, text_mask
-
- def forward(self, model_inputs: ModelInput) -> Tuple[Tensor, Tensor, Tensor]:
- feat_elements = self.emb_element(model_inputs)
- feat_canvas = self.emb_canvas(model_inputs)
-
- src, text_mask_src = self.get_transformer_inputs(
- feat_elements,
- feat_canvas,
- model_inputs.batch_num,
- model_inputs.canvas_text_num,
- )
- feat_cat = self.reshape_tbc(
- feat_elements,
- feat_canvas,
- model_inputs.batch_num,
- model_inputs.canvas_text_num,
- )
-
- return (src, text_mask_src, feat_cat)
-
-
-class EmbeddingElementContext(nn.Module):
- def __init__(
- self,
- prefix_list_element: List,
- embedding_config_element: TextElementContextEmbeddingAttributeConfig,
- d_model: int,
- dropout: float = 0.1,
- seq_length: int = 50,
- ):
- super(EmbeddingElementContext, self).__init__()
- self.d_model = d_model
- self.dropout = dropout
- self.seq_length = seq_length
- self.prefix_list_element = prefix_list_element
-
- for prefix in self.prefix_list_element:
- target_embedding_config = getattr(embedding_config_element, prefix)
- setup_emb_layer(self, prefix, target_embedding_config)
- if target_embedding_config.specific_build is not None:
- build_func = getattr(self, target_embedding_config.specific_build)
- build_func()
-
- def get_local_feat(
- self, inputs: Tensor, batch_num: int, text_num: Tensor, **kwargs: Any
- ) -> Tensor:
- inputs_fn = []
- for b in range(batch_num):
- tn = int(text_num[b].item())
- for t in range(tn):
- inputs_fn.append(inputs[b, t])
- outs = None
- if len(inputs_fn) > 0:
- inputs_fn = torch.stack(inputs_fn)
- with torch.no_grad():
- outs = self.imgencoder(inputs_fn).detach()
- outs = self.downdim(outs)
- outs = outs.view(len(outs), -1)
- outs = self.fcimg(outs)
- return outs
-
- def get_features_via_fn(
- self, fn: Any, inputs: Tensor, batch_num: int, text_num: Tensor
- ) -> Tensor:
- device = inputs.device
- feat = torch.zeros(batch_num, self.seq_length, self.d_model).float().to(device)
- indices = rearrange(torch.arange(self.seq_length), "s -> 1 s").to(device)
- mask = (indices < text_num).to(
- device
- ) # (B, S), indicating valid attribute locations
- feat[mask] = fn(inputs[mask]) # (B, S, C)
- return feat
-
- def text_emb_layer(
- self, inputs: Tensor, batch_num: int, text_num: Tensor, **kwargs: Any
- ) -> Tensor:
- inputs_fn = []
- for b in range(batch_num):
- tn = int(text_num[b].item())
- for t in range(tn):
- inputs_fn.append(inputs[b, t].view(-1))
- outs = None
- if len(inputs_fn) > 0:
- inputs_fn = torch.stack(inputs_fn)
- outs = self.text_emb_emb(inputs_fn)
- outs = outs.view(len(inputs_fn), self.d_model)
- return outs
-
- def text_local_img_emb_layer(
- self, inputs: Tensor, batch_num: int, text_num: Tensor, **kwargs: Any
- ) -> Tensor:
- inputs_fn = []
- for b in range(batch_num):
- tn = int(text_num[b].item())
- for t in range(tn):
- inputs_fn.append(inputs[b, t].view(-1))
- outs = None
- if len(inputs_fn) > 0:
- inputs_fn = torch.stack(inputs_fn)
- outs = self.text_local_img_emb_emb(inputs_fn)
- outs = outs.view(len(inputs_fn), self.d_model)
- return outs
-
- def text_font_emb_layer(
- self, inputs: Tensor, batch_num: int, text_num: Tensor
- ) -> Tensor:
- inputs_fn = []
- for b in range(batch_num):
- tn = int(text_num[b].item())
- for t in range(tn):
- inputs_fn.append(inputs[b, t].view(-1))
- outs = None
- if len(inputs_fn) > 0:
- inputs_fn = torch.stack(inputs_fn)
- outs = self.text_font_emb_emb(inputs_fn)
- outs = outs.view(len(inputs_fn), self.d_model)
- return outs
-
- def forward(self, model_inputs: ModelInput) -> Tensor:
- feats = []
- for prefix in self.prefix_list_element:
- flag = getattr(self, f"{prefix}_flag")
- if flag is True:
- start = time.time()
- logger.debug(f"{prefix=}")
- out = get_output(self, prefix, model_inputs, model_inputs.batch_num)
- logger.debug(f"{prefix=} {out.shape=} {time.time()-start}")
- feats.append(out)
- return feats
-
-
-class EmbeddingCanvasContext(nn.Module):
- def __init__(
- self,
- canvas_prefix_list: List,
- embedding_config_canvas: CanvasContextEmbeddingAttributeConfig,
- d_model: int,
- dropout: float,
- ):
- super(EmbeddingCanvasContext, self).__init__()
- self.d_model = d_model
- self.dropout = dropout
- self.prefix_list = canvas_prefix_list
- for prefix in self.prefix_list:
- target_embedding_config = getattr(embedding_config_canvas, prefix)
- setup_emb_layer(self, prefix, target_embedding_config)
- if target_embedding_config.specific_build is not None:
- build_func = getattr(self, target_embedding_config.specific_build)
- build_func()
-
- def get_feat(self, inputs: Tensor, **kwargs: Any) -> Tensor:
- with torch.no_grad():
- feat = self.imgencoder(inputs.float()).detach()
- feat = F.relu(self.avgpool(feat))
- feat = self.downdim(feat)
- feat = feat.view(feat.shape[0], feat.shape[1])
- return feat
-
- def canvas_bg_img_emb_layer(self, inputs: Tensor, **kwargs: Any) -> Tensor:
- feat = self.canvas_bg_img_emb_emb(inputs.float())
- feat = feat.view(feat.shape[0], feat.shape[1])
- return feat
-
- def get_features_via_fn(self, fn: Any, inputs: Tensor, **kwargs: Any) -> Tensor:
- out = fn(inputs)
- return out
-
- def forward(self, model_inputs: ModelInput) -> List:
- feats = []
- for prefix in self.prefix_list:
- flag = getattr(self, f"{prefix}_flag")
- if flag is True:
- logger.debug(f"canvas prefix {prefix}")
- out = get_output(self, prefix, model_inputs, model_inputs.batch_num)
- feats.append(out)
- return feats
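
The masking trick used throughout `Embedding` is a single broadcast comparison: row indices `(1, S)` against per-canvas element counts `(B, 1)` give a `(B, S)` validity mask with no Python loop:

```python
import torch

seq_length = 6
text_num = torch.tensor([[3], [5]])               # (B, 1) element counts
indices = torch.arange(seq_length).unsqueeze(0)   # (1, S)
mask = indices < text_num                         # (B, S), True at valid slots
print(mask)
# tensor([[ True,  True,  True, False, False, False],
#         [ True,  True,  True,  True,  True, False]])
```
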
diff --git a/src/typography_generation/model/encoder.py b/src/typography_generation/model/encoder.py
deleted file mode 100644
index a8ce865..0000000
--- a/src/typography_generation/model/encoder.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from typing import Tuple
-
-import torch
-from torch import Tensor, nn
-
-from typography_generation.model.common import (
- MyTransformerEncoder,
- MyTransformerEncoderLayer,
-)
-
-
-class Encoder(nn.Module):
- def __init__(
- self,
- d_model: int = 256,
- n_head: int = 8,
- dropout: float = 0.1,
- num_encoder_layers: int = 4,
- ):
- super(Encoder, self).__init__()
- encoder_norm = nn.LayerNorm(d_model)
- dim_feedforward = d_model * 2
- encoder_layer = MyTransformerEncoderLayer(
- d_model, n_head, dim_feedforward, dropout
- )
- self.encoder = MyTransformerEncoder(
- encoder_layer, num_encoder_layers, encoder_norm
- )
-
-    def forward(self, src: Tensor, text_mask: Tensor) -> Tensor:
-        # text_mask is 1.0 at valid slots, 0.0 at padding; after masked_fill,
-        # .bool() maps -inf -> True (ignored key) and 0.0 -> False (attended).
-        text_mask_fill = text_mask.masked_fill(
-            text_mask == 0, float("-inf")
-        ).masked_fill(text_mask == 1, float(0.0))
-        z = self.encoder(src, mask=None, src_key_padding_mask=text_mask_fill.bool())
-        if torch.sum(torch.isnan(z)) > 0:
-            z = z.masked_fill(torch.isnan(z), 0)
-        return z
-
-    def get_transformer_weight(self, src: Tensor, text_mask: Tensor) -> Tuple:
-        text_mask_fill = text_mask.masked_fill(
-            text_mask == 0, float("-inf")
-        ).masked_fill(text_mask == 1, float(0.0))
-        z, weights = self.encoder(
-            src, mask=None, src_key_padding_mask=text_mask_fill.bool(), get_weight=True
-        )
- if torch.sum(torch.isnan(z)) > 0:
- z = z.masked_fill(torch.isnan(z), 0)
- return z, weights
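
The float-to-bool conversion in `Encoder.forward` is subtle: after `masked_fill`, valid slots hold 0.0 (falsy) and padded slots hold `-inf` (truthy), so `.bool()` yields exactly PyTorch's `src_key_padding_mask` convention, where `True` means "ignore this key". A quick check:

```python
import torch

text_mask = torch.tensor([[1.0, 1.0, 0.0]])  # 1 = valid, 0 = padded
filled = text_mask.masked_fill(text_mask == 0, float("-inf")).masked_fill(
    text_mask == 1, 0.0
)
print(filled.bool())  # tensor([[False, False,  True]]) -> padding ignored
```
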
diff --git a/src/typography_generation/model/mfc.py b/src/typography_generation/model/mfc.py
deleted file mode 100644
index cc83c2e..0000000
--- a/src/typography_generation/model/mfc.py
+++ /dev/null
@@ -1,154 +0,0 @@
-from typing import Any, List, Dict
-import numpy as np
-
-import torch
-from torch import Tensor, nn
-from typography_generation.config.attribute_config import (
- CanvasContextEmbeddingAttributeConfig,
- TextElementContextEmbeddingAttributeConfig,
- TextElementContextPredictionAttributeConfig,
-)
-from typography_generation.io.crello_util import CrelloProcessor
-from typography_generation.io.data_object import ModelInput
-
-from logzero import logger
-
-from typography_generation.model.mlp import MLP
-
-
-class MFC(MLP):
- def __init__(
- self,
- prefix_list_element: List,
- prefix_list_canvas: List,
- prefix_list_target: List,
- embedding_config_element: TextElementContextEmbeddingAttributeConfig,
- embedding_config_canvas: CanvasContextEmbeddingAttributeConfig,
- prediction_config_element: TextElementContextPredictionAttributeConfig,
- d_model: int = 256,
- dropout: float = 0.1,
- seq_length: int = 50,
- **kwargs: Any,
- ) -> None:
- logger.info(f"MFC model")
- super().__init__(
- prefix_list_element,
- prefix_list_canvas,
- prefix_list_target,
- embedding_config_element,
- embedding_config_canvas,
- prediction_config_element,
- d_model,
- dropout,
- seq_length,
- )
- for prefix in prefix_list_target:
- target_prediction_config = getattr(prediction_config_element, prefix)
- setattr(self, f"{prefix}_loss_type", target_prediction_config.loss_type)
-
- def get_label(
- self,
- out: Tensor,
- text_index: int,
- batch_index: int = 0,
-    ) -> int:
- sorted_label = torch.sort(input=out[batch_index], dim=1, descending=True)[1]
- target_label = 0 # top1
- out_label = sorted_label[text_index][target_label].item()
- return out_label
-
- def get_out(
- self,
- out: Tensor,
- text_index: int,
- batch_index: int = 0,
-    ) -> np.ndarray:
- _out = out[batch_index, text_index].data.cpu().numpy()
- return _out
-
- def get_outs(
- self,
- model_outs: Dict,
- target_prefix_list: List,
- text_index: int,
- batch_index: int = 0,
- ) -> Dict:
- outs = {}
- for prefix in target_prefix_list:
- out = model_outs[f"{prefix}"]
- loss_type = getattr(self, f"{prefix}_loss_type")
- if loss_type == "cre":
- _out = self.get_label(out, text_index, batch_index)
- else:
- _out = self.get_out(out, text_index, batch_index)
- outs[f"{prefix}"] = _out
- return outs
-
- def store(
- self,
- out_all: Dict,
- out_labels: Dict,
- target_prefix_list: List,
- text_index: int,
- ) -> Dict:
- for prefix in target_prefix_list:
- out_all[f"{prefix}"][text_index, 0] = out_labels[f"{prefix}"]
- return out_all
-
- def prediction(
- self,
- model_inputs: ModelInput,
- dataset: CrelloProcessor,
- target_prefix_list: List,
- start_index: int = 0,
-    ) -> Dict:
- target_text_num = int(model_inputs.canvas_text_num[0].item())
- start_index = min(start_index, target_text_num)
- for t in range(start_index, target_text_num):
- model_inputs.zeroinitialize_th_style_attributes(target_prefix_list, t)
-
- out_all = {}
- for prefix in target_prefix_list:
- loss_type = getattr(self, f"{prefix}_loss_type")
- if loss_type == "cre":
- out_all[prefix] = np.zeros((target_text_num, 1))
- else:
- out_all[prefix] = []
- for t in range(0, start_index):
- tar = getattr(model_inputs, f"{prefix}")[0, t].item()
- if loss_type == "cre":
- out_all[prefix][t, 0] = tar
- else:
- out_all[prefix].append(tar)
- for t in range(start_index, target_text_num):
- _, _, feat_cat = self.emb(model_inputs)
- model_outs = self.head(feat_cat)
- out = self.get_outs(model_outs, target_prefix_list, t)
- for prefix in target_prefix_list:
- loss_type = getattr(self, f"{prefix}_loss_type")
- if loss_type == "cre":
- out_all[f"{prefix}"][t, 0] = out[f"{prefix}"]
- else:
- out_all[f"{prefix}"].append(out[f"{prefix}"])
- return out_all
-
- def initialize_weights(self) -> None:
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.xavier_normal_(m.weight)
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.BatchNorm2d):
- m.weight.data.fill_(1)
- m.bias.data.zero_()
- elif isinstance(m, nn.Linear):
- m.weight.data.normal_(mean=0.0, std=0.02)
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.LayerNorm):
- if m.weight is not None:
- m.weight.data.fill_(1)
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.Embedding):
- m.weight.data.normal_(mean=0.0, std=0.02)
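
`MFC.get_outs` dispatches on each attribute's loss type: cross-entropy heads (`"cre"`) decode a discrete label, anything else is returned as the raw regression output. A compact sketch of that branch:

```python
import torch

def decode(out: torch.Tensor, loss_type: str, text_index: int, batch_index: int = 0):
    if loss_type == "cre":
        return out[batch_index, text_index].argmax().item()         # label id
    return out[batch_index, text_index].detach().cpu().numpy()      # raw values

out = torch.randn(1, 50, 7)
print(decode(out, "cre", 0))  # an int label
print(decode(out, "mse", 0))  # a length-7 float vector
```
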
diff --git a/src/typography_generation/model/mlp.py b/src/typography_generation/model/mlp.py
deleted file mode 100644
index 81491f4..0000000
--- a/src/typography_generation/model/mlp.py
+++ /dev/null
@@ -1,175 +0,0 @@
-from typing import Any, List, Dict
-import numpy as np
-
-import torch
-from torch import Tensor, nn
-from typography_generation.config.attribute_config import (
- CanvasContextEmbeddingAttributeConfig,
- TextElementContextEmbeddingAttributeConfig,
- TextElementContextPredictionAttributeConfig,
-)
-from typography_generation.io.crello_util import CrelloProcessor
-from typography_generation.io.data_object import ModelInput
-from typography_generation.model.common import Linearx2
-from typography_generation.model.embedding import Embedding
-
-from logzero import logger
-
-
-class FCN(nn.Module):
- def __init__(self, d_model: int, label_num: int) -> None:
- super().__init__()
- self.fcn = Linearx2(d_model, label_num)
-
-    def forward(self, elm_agg_emb: Tensor) -> Tensor:
-        logits = elm_agg_emb.permute(1, 0, 2)  # (S, B, C) -> (B, S, C)
-        logits = self.fcn(logits)  # (B, S, out_dim)
-        return logits
-
-
-class MultiTask(nn.Module):
- def __init__(
- self,
- prefix_list_target: List,
- prediction_config: TextElementContextPredictionAttributeConfig,
- d_model: int = 256,
- ):
- super().__init__()
- self.prefix_list_target = prefix_list_target
- for prefix in self.prefix_list_target:
- target_prediction_config = getattr(prediction_config, prefix)
- layer = FCN(d_model, target_prediction_config.out_dim)
- setattr(self, f"{prefix}_layer", layer)
-
- def forward(self, elm_agg_emb: Tensor) -> dict:
- outputs = {}
- for prefix in self.prefix_list_target:
- layer = getattr(self, f"{prefix}_layer")
- out = layer(elm_agg_emb)
- outputs[prefix] = out
- return outputs
-
-
-class MLP(nn.Module):
- def __init__(
- self,
- prefix_list_element: List,
- prefix_list_canvas: List,
- prefix_list_target: List,
- embedding_config_element: TextElementContextEmbeddingAttributeConfig,
- embedding_config_canvas: CanvasContextEmbeddingAttributeConfig,
- prediction_config_element: TextElementContextPredictionAttributeConfig,
- d_model: int = 256,
- dropout: float = 0.1,
- seq_length: int = 50,
- **kwargs: Any,
- ) -> None:
- super().__init__()
- logger.info(f"MLP settings")
- logger.info(f"d_model: {d_model}")
- logger.info(f"seq_length: {seq_length}")
- self.embedding_config_element = embedding_config_element
-
- self.emb = Embedding(
- prefix_list_element,
- prefix_list_canvas,
- embedding_config_element,
- embedding_config_canvas,
- d_model,
- dropout,
- seq_length,
- )
- self.head = MultiTask(
- prefix_list_target,
- prediction_config_element,
- d_model,
- )
- self.initialize_weights()
-
- def forward(self, model_inputs: ModelInput) -> Dict:
- (
- _,
- _,
- feat_cat,
- ) = self.emb(model_inputs)
- outs = self.head(feat_cat)
- return outs
-
- def get_labels(
- self,
- model_outs: Dict,
- target_prefix_list: List,
- text_index: int,
- batch_index: int = 0,
- ) -> Dict:
- out_labels = {}
- for prefix in target_prefix_list:
- out = model_outs[f"{prefix}"]
- sorted_label = torch.sort(input=out[batch_index], dim=1, descending=True)[1]
- target_label = 0 # top1
- out_label = sorted_label[text_index][target_label]
- out_labels[f"{prefix}"] = out_label
-
- return out_labels
-
- def store(
- self,
- out_labels_all: Dict,
- out_labels: Dict,
- target_prefix_list: List,
- text_index: int,
- ) -> Dict:
- for prefix in target_prefix_list:
- out_labels_all[f"{prefix}"][text_index, 0] = out_labels[f"{prefix}"].item()
- return out_labels_all
-
- def prediction(
- self,
- model_inputs: ModelInput,
- dataset: CrelloProcessor,
- target_prefix_list: List,
- start_index: int = 0,
- ) -> Dict:
- target_text_num = int(model_inputs.canvas_text_num[0].item())
- start_index = min(start_index, target_text_num)
- for t in range(start_index, target_text_num):
- model_inputs.zeroinitialize_th_style_attributes(target_prefix_list, t)
-
- out_labels_all = {}
- for prefix in target_prefix_list:
- out_labels_all[prefix] = np.zeros((target_text_num, 1))
- for t in range(0, start_index):
- tar = getattr(model_inputs, f"{prefix}")[0, t].item()
- out_labels_all[prefix][t, 0] = tar
- for t in range(start_index, target_text_num):
- _, _, feat_cat = self.emb(model_inputs)
- model_outs = self.head(feat_cat)
- out_labels = self.get_labels(model_outs, target_prefix_list, t)
- model_inputs.update_th_style_attributes(
- self.embedding_config_element, target_prefix_list, out_labels, t
- )
- out_labels_all = self.store(
- out_labels_all, out_labels, target_prefix_list, t
- )
- return out_labels_all
-
- def initialize_weights(self) -> None:
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.xavier_normal_(m.weight)
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.BatchNorm2d):
- m.weight.data.fill_(1)
- m.bias.data.zero_()
- elif isinstance(m, nn.Linear):
- m.weight.data.normal_(mean=0.0, std=0.02)
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.LayerNorm):
- if m.weight is not None:
- m.weight.data.fill_(1)
- if m.bias is not None:
- m.bias.data.zero_()
- elif isinstance(m, nn.Embedding):
- m.weight.data.normal_(mean=0.0, std=0.02)
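`MLP.prediction` above is a fill-in loop: attributes from `start_index` onward are zeroed, then each text element `t` is re-embedded, the top-1 label per attribute is taken, and `update_th_style_attributes` writes it back so element `t + 1` is predicted conditioned on the labels chosen so far. A toy sketch of that control flow over plain dicts; `predict_step` is a hypothetical stand-in for the embedding/head forward pass:

```python
# Toy sketch of the element-by-element fill-in loop in MLP.prediction;
# predict_step is a hypothetical stand-in for self.emb + self.head + top-1.
from typing import Callable, Dict, List


def fill_in(
    attrs: Dict[str, List[int]],
    text_num: int,
    start_index: int,
    predict_step: Callable[[Dict[str, List[int]], int], Dict[str, int]],
) -> Dict[str, List[int]]:
    # zero-initialize everything that still has to be predicted
    for name in attrs:
        for t in range(start_index, text_num):
            attrs[name][t] = 0
    # predict one element at a time, feeding earlier predictions back in
    for t in range(start_index, text_num):
        out = predict_step(attrs, t)
        for name, label in out.items():
            attrs[name][t] = label
    return attrs


attrs = {"text_font": [3, 1, 2], "text_align_type": [0, 1, 1]}
print(fill_in(attrs, text_num=3, start_index=1,
              predict_step=lambda a, t: {k: t for k in a}))
# {'text_font': [3, 1, 2], 'text_align_type': [0, 1, 2]}
```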
diff --git a/src/typography_generation/model/model.py b/src/typography_generation/model/model.py
deleted file mode 100644
index 6b0d633..0000000
--- a/src/typography_generation/model/model.py
+++ /dev/null
@@ -1,25 +0,0 @@
-from typing import Any, Dict, Type
-from torch import nn
-
-from typography_generation.model.bart import BART
-from typography_generation.model.baseline import AllRandom, AllZero, Mode
-from typography_generation.model.canvas_vae import CanvasVAE
-from typography_generation.model.mfc import MFC
-from typography_generation.model.mlp import MLP
-
-MODEL_REGISTRY: Dict[str, Type[nn.Module]] = {
- "bart": BART,
- "mlp": MLP,
- "mfc": MFC,
- "canvasvae": CanvasVAE,
- "allzero": AllZero,
- "allrandom": AllRandom,
- "mode": Mode,
-}
-
-
-def create_model(model_name: str, **kwargs: Any) -> nn.Module:
- """Factory function to create a model instance."""
- model = MODEL_REGISTRY[model_name](**kwargs)
- model.model_name = model_name
- return model
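`create_model` is a plain registry lookup, so registering a new architecture is one import plus one dictionary entry. A self-contained sketch of the same factory pattern with a toy module (the real constructors take the config-derived kwargs used elsewhere in this repository):

```python
# Self-contained sketch of the registry/factory pattern; ToyA is a toy
# class, not one of the real models.
from typing import Any, Dict, Type

from torch import nn


class ToyA(nn.Module):
    def __init__(self, d_model: int = 256) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)


REGISTRY: Dict[str, Type[nn.Module]] = {"toy_a": ToyA}


def create(name: str, **kwargs: Any) -> nn.Module:
    model = REGISTRY[name](**kwargs)
    model.model_name = name  # tag the instance, as create_model does
    return model


m = create("toy_a", d_model=32)
print(m.model_name, m.proj.weight.shape)  # toy_a torch.Size([32, 32])
```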
diff --git a/src/typography_generation/preprocess/__init__.py b/src/typography_generation/preprocess/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/src/typography_generation/preprocess/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/src/typography_generation/preprocess/map_features.py b/src/typography_generation/preprocess/map_features.py
deleted file mode 100644
index ece5c2d..0000000
--- a/src/typography_generation/preprocess/map_features.py
+++ /dev/null
@@ -1,280 +0,0 @@
-import pickle
-from typing import Any, Dict, List, Tuple
-import datasets
-import os
-import numpy as np
-from PIL import Image
-from logzero import logger
-import torch
-import skia
-from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer
-
-from typography_generation.visualization.renderer_util import (
- get_skia_font,
- get_text_actual_width,
- get_texts,
-)
-
-
-def get_scaleinfo(
- element_data: Dict,
-) -> List:
- svgid = element_data["id"]
- scaleinfo = svgid2scaleinfo[svgid]
- return list(scaleinfo)
-
-
-def get_canvassize(
- bg_img: Any,
-) -> List:
- img_size = (bg_img.size[1], bg_img.size[0])
- return list(img_size)
-
-
-def get_canvas_bg_img_emb(bg_img: Any, **kwargs: Any) -> List:
- inputs = processor(images=[bg_img], return_tensors="pt")
- inputs["pixel_values"] = inputs["pixel_values"].to(device)
- image_feature = model.get_image_features(**inputs)
- return list(image_feature.data.cpu().numpy().flatten())
-
-
-def get_text_emb_list(
- element_data: Dict,
-) -> List:
- text_emb_list: List[List]
- text_emb_list = []
- for k in range(len(element_data["text"])):
- text = element_data["text"][k]
- inputs = text_tokenizer([text], padding=True, return_tensors="pt")
- if inputs["input_ids"].shape[1] > 77:
- inp = inputs["input_ids"][:, :77]
- else:
- inp = inputs["input_ids"]
- inp = inp.to(device)
- text_features = model.get_text_features(inp).data.cpu().numpy()[0]
- text_emb_list.append(text_features)
- return text_emb_list
-
-
-def _get_text_actual_width(
- element_data: Dict,
- W: int,
-) -> List:
- svgid = element_data["id"]
- scaleinfo = svgid2scaleinfo[svgid]
- scale_h, scale_w = scaleinfo
- text_actual_width_list = []
- element_num = len(element_data["text"])
-
- for i in range(element_num):
- if element_data["text"][i] == "":
- text_actual_width_list.append(None)
- else:
- texts = get_texts(element_data, i)
-
- font_label = element_data["font"][i]
- font_name = fontlabel2fontname(int(font_label))
- font_name = font_name.replace(" ", "_")
- font_skia, _ = get_skia_font(
- font2ttf, fontmgr, element_data, i, font_name, scale_h
- )
- text_width = get_text_actual_width(
- element_data, i, texts, font_skia, scale_w
- )
- text_actual_width_list.append(text_width / float(W))
- return text_actual_width_list
-
-
-def get_text_center_y(element_data: dict, text_index: int) -> float:
- text_height = element_data["height"][text_index]
- top = element_data["top"][text_index]
- center_y = (top + top + text_height) / 2.0
- return float(center_y)
-
-
-def get_text_center_y_list(
- element_data: dict,
-) -> List:
- text_center_y_list = []
- element_num = len(element_data["text"])
- for text_id in range(element_num):
- if element_data["text"][text_id] == "":
- text_center_y_list.append(None)
- else:
- text_center_y_list.append(get_text_center_y(element_data, text_id))
- return text_center_y_list
-
-
-def get_text_center_x(
- element_data: dict,
- text_index: int,
- text_actual_width: float,
-) -> float:
- left = element_data["left"][text_index]
- w = element_data["width"][text_index]
- textAlign = element_data["text_align"][text_index]
- right = left + w
- if textAlign == 1: # center alignment
- center_x = (left + right) / 2.0
- elif textAlign == 3: # right alignment
- center_x = right - text_actual_width / 2.0
- elif textAlign == 2: # left alignment
- center_x = left + text_actual_width / 2.0
- else: # unknown alignment label: fall back to the box center
- center_x = (left + right) / 2.0
- return float(center_x)
-
-
-def get_text_center_x_list(
- element_data: dict,
- text_actual_width: float,
-) -> List:
- svgid = element_data["id"]
- text_center_x_list = []
- element_num = len(element_data["text"])
- for text_id in range(element_num):
- if element_data["text"][text_id] == "":
- text_center_x_list.append(None)
- else:
- _text_actual_width = text_actual_width[text_id]
- text_center_x_list.append(
- get_text_center_x(element_data, text_id, _text_actual_width)
- )
- return text_center_x_list
-
-
-def get_text_local_img(
- img: Any,
- text_center_y: float,
- text_center_x: float,
- H: int,
- W: int,
-) -> np.array:
- text_center_y = text_center_y * H
- text_center_x = text_center_x * W
-
- text_center_y = min(max(text_center_y, 0), H)
- text_center_x = min(max(text_center_x, 0), W)
- img = img.resize((640, 640))
- img = np.array(img)
- local_img_size = 64 * 5
- local_img_size_half = local_img_size // 2
- img_pad = np.zeros((640 + local_img_size, 640 + local_img_size, 3))
- img_pad[
- local_img_size_half : 640 + local_img_size_half,
- local_img_size_half : 640 + local_img_size_half,
- ] = img
- h_rate = 640 / float(H)
- w_rate = 640 / float(W)
- text_center_y = int(np.round(text_center_y * h_rate + local_img_size_half))
- text_center_x = int(np.round(text_center_x * w_rate + local_img_size_half))
- local_img = img_pad[
- text_center_y - local_img_size_half : text_center_y + local_img_size_half,
- text_center_x - local_img_size_half : text_center_x + local_img_size_half,
- ]
- return local_img
-
-
-def get_text_local_img_emb_list(
- element_data: Dict,
- bg_img: Any,
- text_center_y: float,
- text_center_x: float,
-) -> List:
- text_local_img_emb_list: List[List]
- text_local_img_emb_list = []
- for k in range(len(element_data["text"])):
- if element_data["text"][k] == "":
- text_local_img_emb_list.append([])
- else:
- H, W = bg_img.size[1], bg_img.size[0]
- local_img = get_text_local_img(
- bg_img.copy(), text_center_y[k], text_center_x[k], H, W
- )
- local_img = Image.fromarray(local_img.astype(np.uint8)).resize((224, 224))
- inputs = processor(images=[local_img], return_tensors="pt")
- inputs["pixel_values"] = inputs["pixel_values"].to(device)
- image_feature = model.get_image_features(**inputs)
- text_local_img_emb_list.append(image_feature.data.cpu().numpy())
-
- return text_local_img_emb_list
-
-
-def get_orderlist(
- center_y: List[float],
- center_x: List[float],
-) -> List:
- """
- Sort elements in raster-scan order (top-to-bottom, then left-to-right).
- """
- center_y = [10000 if y is None else y for y in center_y]
- center_x = [10000 if x is None else x for x in center_x]
- center_y = np.array(center_y)
- center_x = np.array(center_x)
- sortedid = np.argsort(center_y * 1000 + center_x)
- return list(sortedid)
-
-
-def add_features(
- element_data: Dict,
-) -> Dict:
- svgid = element_data["id"]
- fn = os.path.join(data_dir, "generate_bg_png", f"{svgid}.png")
- bg_img = Image.open(fn).convert("RGB") # background image
- element_data["scale_box"] = get_scaleinfo(element_data)
- element_data["canvas_bg_size"] = get_canvassize(bg_img)
- element_data["canvas_bg_img_emb"] = get_canvas_bg_img_emb(bg_img)
- element_data["text_emb"] = get_text_emb_list(element_data)
- text_actual_width = _get_text_actual_width(element_data, bg_img.size[0])
- text_center_y = get_text_center_y_list(element_data)
- text_center_x = get_text_center_x_list(element_data, text_actual_width)
- element_data["text_center_y"] = text_center_y
- element_data["text_center_x"] = text_center_x
- element_data["text_actual_width"] = text_actual_width
- element_data["text_local_img_emb"] = get_text_local_img_emb_list(
- element_data, bg_img, text_center_y, text_center_x
- )
- element_data["order_list"] = get_orderlist(text_center_y, text_center_x)
- return element_data
-
-
-def map_features(
- _data_dir: str,
-):
- fn = os.path.join(_data_dir, "svgid2scaleinfo.pkl")
- global svgid2scaleinfo
- global processor
- global text_tokenizer
- global model
- global device
- global data_dir
- global font2ttf
- global fontmgr
- global fontlabel2fontname
-
- dataset = datasets.load_dataset("cyberagent/crello", revision="3.1")
-
- data_dir = _data_dir
- svgid2scaleinfo = pickle.load(open(fn, "rb"))
- processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
- text_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
- model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- model.to(device)
-
- fn = os.path.join(data_dir, "font2ttf.pkl")
- _font2ttf = pickle.load(open(fn, "rb"))
- font2ttf = {}
- for key in _font2ttf.keys():
- tmp = _font2ttf[key].split("/data/dataset/crello/")[1]
- fn = os.path.join(data_dir, tmp)
- font2ttf[key] = fn
- fontmgr = skia.FontMgr()
- fontlabel2fontname = dataset["train"].features["font"].feature.int2str
-
- dataset_new = {}
- for dataset_division in ["train", "validation", "test"]:
- logger.info(f"{dataset_division=}")
- _dataset = dataset[dataset_division]
- dataset_new[dataset_division] = _dataset.map(add_features)
- dataset_new = datasets.DatasetDict(dataset_new)
- dataset_new.save_to_disk(f"{_data_dir}/crello_map_features")
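`map_features` parks the heavy state (CLIP model, font tables, scale info) in module-level globals and then lets `datasets.Dataset.map` apply `add_features` one example at a time before saving the enriched splits to disk. A minimal sketch of that per-example `map` pattern on a toy dataset (toy column names, not the crello schema):

```python
# Minimal sketch of the Dataset.map pattern used above (toy columns,
# not the crello schema).
import datasets


def add_feature(example: dict) -> dict:
    example["n_chars"] = len(example["text"])  # derive a new column
    return example


ds = datasets.Dataset.from_dict({"text": ["hello", "typography"]})
ds = ds.map(add_feature)
print(ds["n_chars"])  # [5, 10]
```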
diff --git a/src/typography_generation/tools/__init__.py b/src/typography_generation/tools/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/src/typography_generation/tools/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/src/typography_generation/tools/color_func.py b/src/typography_generation/tools/color_func.py
deleted file mode 100644
index 2675bbf..0000000
--- a/src/typography_generation/tools/color_func.py
+++ /dev/null
@@ -1,457 +0,0 @@
-from typing import Any, Tuple
-import numpy as np
-
-_illuminants = {
- "A": {
- "2": (1.098466069456375, 1, 0.3558228003436005),
- "10": (1.111420406956693, 1, 0.3519978321919493),
- "R": (1.098466069456375, 1, 0.3558228003436005),
- },
- "B": {
- "2": (0.9909274480248003, 1, 0.8531327322886154),
- "10": (0.9917777147717607, 1, 0.8434930535866175),
- "R": (0.9909274480248003, 1, 0.8531327322886154),
- },
- "C": {
- "2": (0.980705971659919, 1, 1.1822494939271255),
- "10": (0.9728569189782166, 1, 1.1614480488951577),
- "R": (0.980705971659919, 1, 1.1822494939271255),
- },
- "D50": {
- "2": (0.9642119944211994, 1, 0.8251882845188288),
- "10": (0.9672062750333777, 1, 0.8142801513128616),
- "R": (0.9639501491621826, 1, 0.8241280285499208),
- },
- "D55": {
- "2": (0.956797052643698, 1, 0.9214805860173273),
- "10": (0.9579665682254781, 1, 0.9092525159847462),
- "R": (0.9565317453467969, 1, 0.9202554587037198),
- },
- "D65": {
- "2": (0.95047, 1.0, 1.08883), # This was: `lab_ref_white`
- "10": (0.94809667673716, 1, 1.0730513595166162),
- "R": (0.9532057125493769, 1, 1.0853843816469158),
- },
- "D75": {
- "2": (0.9497220898840717, 1, 1.226393520724154),
- "10": (0.9441713925645873, 1, 1.2064272211720228),
- "R": (0.9497220898840717, 1, 1.226393520724154),
- },
- "E": {"2": (1.0, 1.0, 1.0), "10": (1.0, 1.0, 1.0), "R": (1.0, 1.0, 1.0)},
-}
-xyz_from_rgb = np.array(
- [
- [0.412453, 0.357580, 0.180423],
- [0.212671, 0.715160, 0.072169],
- [0.019334, 0.119193, 0.950227],
- ]
-)
-
-
-def xyz_tristimulus_values(illuminant: str, observer: str, dtype: Any) -> np.array:
- """Get the CIE XYZ tristimulus values.
-
- Given an illuminant and observer, this function returns the CIE XYZ tristimulus
- values [2]_ scaled such that :math:`Y = 1`.
-
- Parameters
- ----------
- illuminant : {"A", "B", "C", "D50", "D55", "D65", "D75", "E"}
- The name of the illuminant (the function is NOT case sensitive).
- observer : {"2", "10", "R"}
- One of: 2-degree observer, 10-degree observer, or 'R' observer as in
- R function ``grDevices::convertColor`` [3]_.
- dtype: dtype, optional
- Output data type.
-
- Returns
- -------
- values : array
- Array with 3 elements :math:`X, Y, Z` containing the CIE XYZ tristimulus values
- of the given illuminant.
-
- Raises
- ------
- ValueError
- If either the illuminant or the observer angle are not supported or
- unknown.
-
- References
- ----------
- .. [1] https://en.wikipedia.org/wiki/Standard_illuminant#White_points_of_standard_illuminants
- .. [2] https://en.wikipedia.org/wiki/CIE_1931_color_space#Meaning_of_X,_Y_and_Z
- .. [3] https://www.rdocumentation.org/packages/grDevices/versions/3.6.2/topics/convertColor
-
- Notes
- -----
- The CIE XYZ tristimulus values are calculated from :math:`x, y` [1]_, using the
- formula
-
- .. math:: X = x / y
-
- .. math:: Y = 1
-
- .. math:: Z = (1 - x - y) / y
-
- The only exception is the illuminant "D65" with aperture angle 2° for
- backward-compatibility reasons.
-
- Examples
- --------
- Get the CIE XYZ tristimulus values for a "D65" illuminant for a 10 degree field of
- view
-
- >>> xyz_tristimulus_values(illuminant="D65", observer="10")
- array([0.94809668, 1. , 1.07305136])
- """
- illuminant = illuminant.upper()
- observer = observer.upper()
- try:
- return np.asarray(_illuminants[illuminant][observer], dtype=dtype)
- except KeyError:
- raise ValueError(
- f"Unknown illuminant/observer combination "
- f"(`{illuminant}`, `{observer}`)"
- )
-
-
-def xyz2lab(
- xyz: np.array, illuminant: str = "D65", observer: str = "2", channel_axis: int = -1
-) -> np.array:
- """XYZ to CIE-LAB color space conversion.
-
- Parameters
- ----------
- xyz : (..., 3, ...) array_like
- The image in XYZ format. By default, the final dimension denotes
- channels.
- illuminant : {"A", "B", "C", "D50", "D55", "D65", "D75", "E"}, optional
- The name of the illuminant (the function is NOT case sensitive).
- observer : {"2", "10", "R"}, optional
- One of: 2-degree observer, 10-degree observer, or 'R' observer as in
- R function grDevices::convertColor.
- channel_axis : int, optional
- This parameter indicates which axis of the array corresponds to
- channels.
-
- .. versionadded:: 0.19
- ``channel_axis`` was added in 0.19.
-
- Returns
- -------
- out : (..., 3, ...) ndarray
- The image in CIE-LAB format. Same dimensions as input.
-
- Raises
- ------
- ValueError
- If `xyz` is not at least 2-D with shape (..., 3, ...).
- ValueError
- If either the illuminant or the observer angle is unsupported or
- unknown.
-
- Notes
- -----
- By default Observer="2", Illuminant="D65". CIE XYZ tristimulus values
- x_ref=95.047, y_ref=100., z_ref=108.883. See function
- :func:`~.xyz_tristimulus_values` for a list of supported illuminants.
-
- References
- ----------
- .. [1] http://www.easyrgb.com/en/math.php
- .. [2] https://en.wikipedia.org/wiki/CIELAB_color_space
-
- Examples
- --------
- >>> from skimage import data
- >>> from skimage.color import rgb2xyz, xyz2lab
- >>> img = data.astronaut()
- >>> img_xyz = rgb2xyz(img)
- >>> img_lab = xyz2lab(img_xyz)
- """
- arr = np.array(xyz)
-
- xyz_ref_white = xyz_tristimulus_values(
- illuminant=illuminant, observer=observer, dtype=arr.dtype
- )
-
- # scale by CIE XYZ tristimulus values of the reference white point
- arr = arr / xyz_ref_white
-
- # Nonlinear distortion and linear transformation
- mask = arr > 0.008856
- arr[mask] = np.cbrt(arr[mask])
- arr[~mask] = 7.787 * arr[~mask] + 16.0 / 116.0
-
- x, y, z = arr[..., 0], arr[..., 1], arr[..., 2]
-
- # Vector scaling
- L = (116.0 * y) - 16.0
- a = 500.0 * (x - y)
- b = 200.0 * (y - z)
-
- return np.concatenate([x[..., np.newaxis] for x in [L, a, b]], axis=-1)
-
-
-def rgb2xyz(rgb: np.array, channel_axis: int = -1) -> np.array:
- """RGB to XYZ color space conversion.
-
- Parameters
- ----------
- rgb : (..., 3, ...) array_like
- The image in RGB format. By default, the final dimension denotes
- channels.
- channel_axis : int, optional
- This parameter indicates which axis of the array corresponds to
- channels.
-
- .. versionadded:: 0.19
- ``channel_axis`` was added in 0.19.
-
- Returns
- -------
- out : (..., 3, ...) ndarray
- The image in XYZ format. Same dimensions as input.
-
- Raises
- ------
- ValueError
- If `rgb` is not at least 2-D with shape (..., 3, ...).
-
- Notes
- -----
- The CIE XYZ color space is derived from the CIE RGB color space. Note
- however that this function converts from sRGB.
-
- References
- ----------
- .. [1] https://en.wikipedia.org/wiki/CIE_1931_color_space
-
- Examples
- --------
- >>> from skimage import data
- >>> img = data.astronaut()
- >>> img_xyz = rgb2xyz(img)
- """
- # Follow the algorithm from http://www.easyrgb.com/index.php
- # except we don't multiply/divide by 100 in the conversion
- arr = np.array(rgb)
- mask = arr > 0.04045
- arr[mask] = np.power((arr[mask] + 0.055) / 1.055, 2.4)
- arr[~mask] /= 12.92
- return arr @ xyz_from_rgb.T.astype(arr.dtype)
-
-
-def rgb2lab(
- rgb: np.array,
- illuminant: str = "D65",
- observer: str = "2",
- channel_axis: int = -1,
-) -> np.array:
- """Conversion from the sRGB color space (IEC 61966-2-1:1999)
- to the CIE Lab colorspace under the given illuminant and observer.
-
- Parameters
- ----------
- rgb : (..., 3, ...) array_like
- The image in RGB format. By default, the final dimension denotes
- channels.
- illuminant : {"A", "B", "C", "D50", "D55", "D65", "D75", "E"}, optional
- The name of the illuminant (the function is NOT case sensitive).
- observer : {"2", "10", "R"}, optional
- The aperture angle of the observer.
- channel_axis : int, optional
- This parameter indicates which axis of the array corresponds to
- channels.
-
- .. versionadded:: 0.19
- ``channel_axis`` was added in 0.19.
-
- Returns
- -------
- out : (..., 3, ...) ndarray
- The image in Lab format. Same dimensions as input.
-
- Raises
- ------
- ValueError
- If `rgb` is not at least 2-D with shape (..., 3, ...).
-
- Notes
- -----
- RGB is a device-dependent color space so, if you use this function, be
- sure that the image you are analyzing has been mapped to the sRGB color
- space.
-
- This function uses rgb2xyz and xyz2lab.
- By default Observer="2", Illuminant="D65". CIE XYZ tristimulus values
- x_ref=95.047, y_ref=100., z_ref=108.883. See function
- :func:`~.xyz_tristimulus_values` for a list of supported illuminants.
-
- References
- ----------
- .. [1] https://en.wikipedia.org/wiki/Standard_illuminant
- """
- return xyz2lab(rgb2xyz(rgb), illuminant, observer)
-
-
-def _cart2polar_2pi(x: np.array, y: np.array) -> Tuple[np.array, np.array]:
- """convert cartesian coordinates to polar (uses non-standard theta range!)
-
- NON-STANDARD RANGE! Maps to ``(0, 2*pi)`` rather than usual ``(-pi, +pi)``
- """
- r, t = np.hypot(x, y), np.arctan2(y, x)
- t += np.where(t < 0.0, 2 * np.pi, 0)
- return r, t
-
-
-def _float_inputs(
- lab1: np.array, lab2: np.array, allow_float32: bool = True
-) -> Tuple[np.array, np.array]:
- lab1 = np.asarray(lab1)
- lab2 = np.asarray(lab2)
- float_dtype = np.float64
- lab1 = lab1.astype(float_dtype, copy=False)
- lab2 = lab2.astype(float_dtype, copy=False)
- return lab1, lab2
-
-
-def deltaE_ciede2000(
- lab1: np.array,
- lab2: np.array,
- kL: int = 1,
- kC: int = 1,
- kH: int = 1,
- channel_axis: int = -1,
-) -> np.array:
- """Color difference as given by the CIEDE 2000 standard.
-
- CIEDE 2000 is a major revision of CIEDE94. The perceptual calibration is
- largely based on experience with automotive paint on smooth surfaces.
-
- Parameters
- ----------
- lab1 : array_like
- reference color (Lab colorspace)
- lab2 : array_like
- comparison color (Lab colorspace)
- kL : float (range), optional
- lightness scale factor, 1 for "acceptably close"; 2 for "imperceptible"
- see deltaE_cmc
- kC : float (range), optional
- chroma scale factor, usually 1
- kH : float (range), optional
- hue scale factor, usually 1
- channel_axis : int, optional
- This parameter indicates which axis of the arrays corresponds to
- channels.
-
- .. versionadded:: 0.19
- ``channel_axis`` was added in 0.19.
-
- Returns
- -------
- deltaE : array_like
- The distance between `lab1` and `lab2`
-
- Notes
- -----
- CIEDE 2000 assumes parametric weighting factors for the lightness, chroma,
- and hue (`kL`, `kC`, `kH` respectively). These default to 1.
-
- References
- ----------
- .. [1] https://en.wikipedia.org/wiki/Color_difference
- .. [2] http://www.ece.rochester.edu/~gsharma/ciede2000/ciede2000noteCRNA.pdf
- :DOI:`10.1364/AO.33.008069`
- .. [3] M. Melgosa, J. Quesada, and E. Hita, "Uniformity of some recent
- color metrics tested with an accurate color-difference tolerance
- dataset," Appl. Opt. 33, 8069-8077 (1994).
- """
- lab1, lab2 = _float_inputs(lab1, lab2, allow_float32=True)
-
- channel_axis = channel_axis % lab1.ndim
- unroll = False
- if lab1.ndim == 1 and lab2.ndim == 1:
- unroll = True
- if lab1.ndim == 1:
- lab1 = lab1[None, :]
- if lab2.ndim == 1:
- lab2 = lab2[None, :]
- channel_axis += 1
- L1, a1, b1 = np.moveaxis(lab1, source=channel_axis, destination=0)[:3]
- L2, a2, b2 = np.moveaxis(lab2, source=channel_axis, destination=0)[:3]
-
- # distort `a` based on average chroma
- # then convert to lch coordinates from distorted `a`
- # all subsequent calculations are in the new coordinates
- # (often denoted "prime" in the literature)
- Cbar = 0.5 * (np.hypot(a1, b1) + np.hypot(a2, b2))
- c7 = Cbar**7
- G = 0.5 * (1 - np.sqrt(c7 / (c7 + 25**7)))
- scale = 1 + G
- C1, h1 = _cart2polar_2pi(a1 * scale, b1)
- C2, h2 = _cart2polar_2pi(a2 * scale, b2)
- # recall that c, h are polar coordinates. c==r, h==theta
-
- # CIEDE 2000 has four terms in delta_e:
- # 1) Luminance term
- # 2) Hue term
- # 3) Chroma term
- # 4) hue Rotation term
-
- # lightness term
- Lbar = 0.5 * (L1 + L2)
- tmp = (Lbar - 50) ** 2
- SL = 1 + 0.015 * tmp / np.sqrt(20 + tmp)
- L_term = (L2 - L1) / (kL * SL)
-
- # chroma term
- Cbar = 0.5 * (C1 + C2) # new coordinates
- SC = 1 + 0.045 * Cbar
- C_term = (C2 - C1) / (kC * SC)
-
- # hue term
- h_diff = h2 - h1
- h_sum = h1 + h2
- CC = C1 * C2
-
- dH = h_diff.copy()
- dH[h_diff > np.pi] -= 2 * np.pi
- dH[h_diff < -np.pi] += 2 * np.pi
- dH[CC == 0.0] = 0.0 # if r == 0, dtheta == 0
- dH_term = 2 * np.sqrt(CC) * np.sin(dH / 2)
-
- Hbar = h_sum.copy()
- mask = np.logical_and(CC != 0.0, np.abs(h_diff) > np.pi)
- Hbar[mask * (h_sum < 2 * np.pi)] += 2 * np.pi
- Hbar[mask * (h_sum >= 2 * np.pi)] -= 2 * np.pi
- Hbar[CC == 0.0] *= 2
- Hbar *= 0.5
-
- T = (
- 1
- - 0.17 * np.cos(Hbar - np.deg2rad(30))
- + 0.24 * np.cos(2 * Hbar)
- + 0.32 * np.cos(3 * Hbar + np.deg2rad(6))
- - 0.20 * np.cos(4 * Hbar - np.deg2rad(63))
- )
- SH = 1 + 0.015 * Cbar * T
-
- H_term = dH_term / (kH * SH)
-
- # hue rotation
- c7 = Cbar**7
- Rc = 2 * np.sqrt(c7 / (c7 + 25**7))
- dtheta = np.deg2rad(30) * np.exp(-(((np.rad2deg(Hbar) - 275) / 25) ** 2))
- R_term = -np.sin(2 * dtheta) * Rc * C_term * H_term
-
- # put it all together
- dE2 = L_term**2
- dE2 += C_term**2
- dE2 += H_term**2
- dE2 += R_term
- ans = np.sqrt(np.maximum(dE2, 0))
- if unroll:
- ans = ans[0]
- return ans
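A small worked example of how the two conversions above compose into the font-color metric used later in `score_func.py`: convert both sRGB triplets (floats in [0, 1]) to CIELAB, then take the CIEDE2000 distance. `rgb2lab` and `deltaE_ciede2000` are the functions defined in this file; identical inputs give 0.0:

```python
# Worked example: CIEDE2000 distance between two sRGB colors using the
# functions defined in this file (inputs are floats in [0, 1]).
import numpy as np

red = np.array([1.0, 0.0, 0.0]).reshape(1, 1, 3)
dark_red = np.array([0.8, 0.0, 0.0]).reshape(1, 1, 3)

d = deltaE_ciede2000(rgb2lab(red), rgb2lab(dark_red))
print(d.item())  # modest perceptual distance; identical inputs give 0.0
```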
diff --git a/src/typography_generation/tools/denormalizer.py b/src/typography_generation/tools/denormalizer.py
deleted file mode 100644
index d0d06d2..0000000
--- a/src/typography_generation/tools/denormalizer.py
+++ /dev/null
@@ -1,154 +0,0 @@
-from typing import Dict, List, Tuple, Union
-
-import numpy as np
-from typography_generation.io.crello_util import CrelloProcessor
-from typography_generation.io.data_object import DesignContext
-from logzero import logger
-
-
-class Denormalizer:
- def __init__(self, dataset: CrelloProcessor):
- self.dataset = dataset
-
- def denormalize(
- self,
- prefix: str,
- text_num: int,
- prediction: np.array,
- design_context: DesignContext,
- ) -> Tuple:
- pred = prediction[prefix]
- gt = design_context.get_data(prefix)
- pred_token = []
- gt_token = []
- pred_denorm = []
- gt_denorm = []
- canvas_img_size_h, canvas_img_size_w = design_context.canvas_context.img_size
- canvas_h_scale, canvas_w_scale = design_context.canvas_context.scale_box
- for t in range(text_num):
- g = gt[t]
- if prefix in self.dataset.tokenizer.prefix_list:
- p_token = pred[t]
- p = self.dataset.tokenizer.detokenize(prefix, p_token)
- g_token = self.dataset.tokenizer.tokenize(prefix, g)
- elif prefix in self.dataset.tokenizer.rawdata_list:
- p = pred[t]
- p_token = getattr(self.dataset, f"raw2token_{prefix}")(p)
- g_token = getattr(self.dataset, f"raw2token_{prefix}")(g)
- if self.dataset.tokenizer.rawdata_out_format[prefix] == "token":
- p = p_token
- g = g_token
- else:
- p = pred[t]
- else:
- p_token = pred[t]
- p = p_token
- g_token = g
- p = self.denormalize_elm(
- prefix,
- p,
- canvas_img_size_h,
- canvas_img_size_w,
- canvas_h_scale,
- canvas_w_scale,
- )
- g = self.denormalize_elm(
- prefix,
- g,
- canvas_img_size_h,
- canvas_img_size_w,
- canvas_h_scale,
- canvas_w_scale,
- )
- pred_token.append([p_token])
- gt_token.append(g_token)
- pred_denorm.append([p])
- gt_denorm.append(g)
- return pred_token, gt_token, pred_denorm, gt_denorm
-
- def denormalize_gt(
- self,
- prefix: str,
- text_num: int,
- design_context: DesignContext,
- ) -> Tuple:
- gt = design_context.get_data(prefix)
- gt_token = []
- gt_denorm = []
- canvas_img_size_h, canvas_img_size_w = design_context.canvas_context.img_size
- canvas_h_scale, canvas_w_scale = design_context.canvas_context.scale_box
- logger.debug(f"{prefix} {gt}")
- for t in range(text_num):
- g = gt[t]
- if prefix in self.dataset.tokenizer.prefix_list:
- g_token = self.dataset.tokenizer.tokenize(prefix, g)
- elif prefix in self.dataset.tokenizer.rawdata_list:
- g_token = getattr(self.dataset, f"raw2token_{prefix}")(g)
- if self.dataset.tokenizer.rawdata_out_format[prefix] == "token":
- g = g_token
- else:
- g_token = g
- g = self.denormalize_elm(
- prefix,
- g,
- canvas_img_size_h,
- canvas_img_size_w,
- canvas_h_scale,
- canvas_w_scale,
- )
- gt_token.append(g_token)
- gt_denorm.append(g)
- return gt_token, gt_denorm
-
- def denormalize_elm(
- self,
- prefix: str,
- val: Union[int, float, Tuple],
- canvas_img_size_h: int,
- canvas_img_size_w: int,
- canvas_h_scale: float,
- canvas_w_scale: float,
- ) -> Union[int, float, Tuple]:
- if hasattr(self.dataset, f"denorm_{prefix}"):
- func = getattr(self.dataset, f"denorm_{prefix}")
- data_info = {
- "val": val,
- "img_height": canvas_img_size_h,
- "img_width": canvas_img_size_w,
- "scale_h": canvas_h_scale,
- "scale_w": canvas_w_scale,
- }
- val = func(**data_info)
- return val
-
- def convert_attributes(
- self,
- prefix: str,
- pred: List,
- element_data: Dict,
- text_ids: List,
- scale_h: float,
- ) -> Dict:
- converter = getattr(self.dataset, f"convert_{prefix}")
- converted_attributes = []
- for i, text_index in enumerate(text_ids):
- inputs = {
- "val": pred[i][0],
- "element_data": element_data,
- "text_index": text_index,
- "scale_h": scale_h,
- }
- _converted_attributes = converter(**inputs)
- converted_attributes.append(_converted_attributes)
- if len(text_ids) > 0:
- converted_attribute_prefixes = list(_converted_attributes.keys())
- converted_attribute_dict = {}
- for prefix in converted_attribute_prefixes:
- attribute_data = []
- for _converted_attributes in converted_attributes:
- val = _converted_attributes[prefix]
- attribute_data.append([val])
- converted_attribute_dict[prefix] = attribute_data
- return converted_attribute_dict
- else:
- return None
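`denormalize_elm` above dispatches by attribute name: if the dataset exposes a `denorm_<prefix>` method, it is called with the value plus the canvas geometry; otherwise the value passes through unchanged. A toy sketch of that `getattr` dispatch; `denorm_text_font_size` here is a hypothetical stand-in, not the real `CrelloProcessor` method:

```python
# Toy sketch of the getattr dispatch in denormalize_elm; the method name
# denorm_text_font_size is hypothetical, not the real CrelloProcessor API.
class ToyDataset:
    def denorm_text_font_size(self, val, img_height, img_width, scale_h, scale_w):
        return val * img_height / scale_h  # e.g. normalized size -> pixels


def denorm(dataset, prefix, val, img_h, img_w, scale_h, scale_w):
    if hasattr(dataset, f"denorm_{prefix}"):
        func = getattr(dataset, f"denorm_{prefix}")
        return func(val=val, img_height=img_h, img_width=img_w,
                    scale_h=scale_h, scale_w=scale_w)
    return val  # attributes without a denorm hook pass through unchanged


print(denorm(ToyDataset(), "text_font_size", 0.05, 1024, 768, 2.0, 2.0))  # 25.6
```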
diff --git a/src/typography_generation/tools/evaluator.py b/src/typography_generation/tools/evaluator.py
deleted file mode 100644
index 1099221..0000000
--- a/src/typography_generation/tools/evaluator.py
+++ /dev/null
@@ -1,225 +0,0 @@
-import os
-import pickle
-import time
-from typing import Dict, List
-
-import torch
-import torch.nn as nn
-import torch.utils.data
-from logzero import logger
-from typography_generation.io.data_loader import CrelloLoader
-from typography_generation.io.data_object import (
- DesignContext,
- ModelInput,
- PrefixListObject,
-)
-from typography_generation.tools.denormalizer import Denormalizer
-from typography_generation.tools.score_func import (
- EvalDataEntire,
- EvalDataInstance,
-)
-from typography_generation.tools.train import collate_batch
-from typography_generation.visualization.renderer import TextRenderer
-from typography_generation.visualization.visualizer import (
- get_text_ids,
-)
-
-show_classification_score_att = [
- "text_font",
- "text_font_emb",
- "text_align_type",
- "text_capitalize",
-]
-show_abs_errors_att = [
- "text_font_size",
- "text_font_size_raw",
- "text_center_y",
- "text_center_x",
- "text_letter_spacing",
- "text_angle",
- "text_line_height_scale",
-]
-
-show_bigram_score_att = [
- "text_font",
- "text_font_emb",
- "text_font_color",
- "text_align_type",
- "text_capitalize",
- "text_font_size",
- "text_font_size_raw",
- "text_center_y",
- "text_center_x",
- "text_angle",
- "text_letter_spacing",
- "text_line_height_scale",
-]
-
-
-############################################################
-# Evaluator
-############################################################
-class Evaluator:
- def __init__(
- self,
- model: nn.Module,
- gpu: bool,
- save_dir: str,
- dataset: CrelloLoader,
- prefix_list_object: PrefixListObject,
- batch_size: int = 1,
- num_worker: int = 2,
- show_interval: int = 100,
- dataset_division: str = "test",
- save_file_prefix: str = "score",
- debug: bool = False,
- ) -> None:
- self.gpu = gpu
- self.save_dir = save_dir
- self.batch_size = batch_size
- self.num_worker = num_worker
- self.show_interval = show_interval
- self.debug = debug
- self.prefix_list_target = prefix_list_object.target
- self.dataset_division = dataset_division
-
- self.dataset = dataset
-
- self.model = model
- if gpu is True:
- self.model.cuda()
-
- self.entire_data = EvalDataEntire(
- self.prefix_list_target,
- save_dir,
- save_file_prefix=save_file_prefix,
- )
- self.denormalizer = Denormalizer(self.dataset.dataset)
-
- self.save_data: Dict[str, Dict[str, List]]
- self.save_data = dict()
-
- fontlabel2fontname = dataset.dataset.dataset.features["font"].feature.int2str
-
- self.renderer = TextRenderer(dataset.data_dir, fontlabel2fontname)
-
- def eval_model(self) -> None:
- # Data generators
- dataloader = torch.utils.data.DataLoader(
- self.dataset,
- batch_size=self.batch_size,
- shuffle=False,
- num_workers=self.num_worker,
- pin_memory=True,
- collate_fn=collate_batch,
- )
- with torch.no_grad():
- self.steps = len(dataloader)
- self.step = 0
- self.cnt = 0
- self.model.eval()
- end = time.time()
- for inputs in dataloader:
- (
- design_context_list,
- model_input_batchdata,
- svg_id,
- index,
- ) = inputs
- self.eval_step(
- design_context_list,
- model_input_batchdata,
- svg_id,
- index,
- end,
- )
- end = time.time()
- self.step += 1
-
- self.show_scores()
- if self.dataset_division == "test":
- self.save_prediction()
-
- def eval_step(
- self,
- design_context_list: List[DesignContext],
- model_input_batchdata: List,
- svg_id: List,
- index: List,
- end: float,
- ) -> None:
- start = time.time()
- model_inputs = ModelInput(design_context_list, model_input_batchdata, self.gpu)
- predictions = self.model.prediction(
- model_inputs, self.dataset.dataset, self.prefix_list_target
- )
- data_index = svg_id[0]
-
- self.instance_data = EvalDataInstance(self.prefix_list_target)
-
- text_num = design_context_list[0].canvas_context.canvas_text_num
-
- self.save_data[data_index] = dict()
- for prefix in self.prefix_list_target:
- pred_token, gt_token, pred, gt = self.denormalizer.denormalize(
- prefix, text_num, predictions, design_context_list[0]
- )
- self.instance_data.rigister_att(
- text_num,
- prefix,
- pred_token,
- gt_token,
- pred,
- gt,
- )
- self.entire_data.update_prediction_data(
- data_index, self.instance_data, f"{prefix}"
- )
- self.save_data[data_index][prefix] = pred
- if hasattr(self.denormalizer.dataset, f"convert_{prefix}"):
- element_data = self.dataset.dataset.dataset[index[0]]
- text_ids = get_text_ids(element_data)
- canvas_h_scale = design_context_list[0].canvas_context.scale_box[0]
- converted_attribute_dict = self.denormalizer.convert_attributes(
- prefix,
- pred,
- element_data,
- text_ids,
- canvas_h_scale,
- )
- if converted_attribute_dict is not None:
- for prefix, v in converted_attribute_dict.items():
- self.save_data[data_index][prefix] = v
- self.entire_data.text_num[data_index] = text_num
-
- forward_time = time.time()
- if self.step % 200 == 0:
- data_show = "{}/{}/{}, forward_time: {:.3f} data {:.3f}".format(
- self.cnt,
- self.step + 1,
- self.steps,
- forward_time - start,
- (start - end),
- )
- logger.info(data_show)
-
- def show_scores(self) -> None:
- for prefix in show_classification_score_att:
- if prefix in self.prefix_list_target:
- self.entire_data.show_classification_score(
- prefix, topk=5, show_topk=[0, 2, 4]
- )
- for prefix in show_abs_errors_att:
- if prefix in self.prefix_list_target:
- self.entire_data.show_abs_erros(prefix)
- if "text_font_color" in self.prefix_list_target:
- self.entire_data.show_font_color_scores()
- for prefix in show_bigram_score_att:
- if prefix in self.prefix_list_target:
- self.entire_data.show_structure_score(prefix)
- self.entire_data.show_alpha_overlap_score()
-
- def save_prediction(self) -> None:
- file_name = os.path.join(self.save_dir, "prediction.pkl")
- with open(file_name, mode="wb") as f:
- pickle.dump(self.save_data, f)
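The evaluation loop above follows the standard PyTorch recipe: a non-shuffled `DataLoader` with batch size 1, `model.eval()`, and `torch.no_grad()` around the whole pass. A minimal sketch of that skeleton with toy tensors, not the `CrelloLoader` pipeline:

```python
# Skeleton of the no_grad evaluation pass (toy tensors, not CrelloLoader).
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(4, 2)
loader = DataLoader(TensorDataset(torch.randn(8, 4)),
                    batch_size=1, shuffle=False)

model.eval()
with torch.no_grad():       # no autograd graph is built during evaluation
    for (x,) in loader:
        pred = model(x)     # one design per step, as in eval_step
print("steps:", len(loader))
```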
diff --git a/src/typography_generation/tools/loss.py b/src/typography_generation/tools/loss.py
deleted file mode 100644
index 2960fa6..0000000
--- a/src/typography_generation/tools/loss.py
+++ /dev/null
@@ -1,194 +0,0 @@
-from typing import Dict, List, Tuple, Union
-
-import numpy as np
-import torch
-from logzero import logger
-from torch import Tensor, nn
-from torch.functional import F
-from typography_generation.config.attribute_config import (
- TextElementContextPredictionAttributeConfig,
-)
-from typography_generation.io.data_object import ModelInput
-from typography_generation.model.common import Linearx3
-
-
-class LossFunc(object):
- def __init__(
- self,
- model_name: str,
- prefix_list_target: List,
- prediction_config_element: TextElementContextPredictionAttributeConfig,
- gpu: bool,
- topk: int = 5,
- ) -> None:
- super(LossFunc, self).__init__()
- self.prefix_list_target = prefix_list_target
- self.topk = topk
- for prefix in self.prefix_list_target:
- target_prediction_config = getattr(prediction_config_element, prefix)
- setattr(
- self, f"{prefix}_ignore_label", target_prediction_config.ignore_label
- )
- setattr(self, f"{prefix}_loss_type", target_prediction_config.loss_type)
- if model_name == "canvasvae":
- self.vae_weight = 0
- self.fix_vae_weight = 0.002
- if model_name == "mfc":
- self.d = Linearx3(40, 1)
- if gpu is True:
- self.d.cuda()
- self.d_optimizer = torch.optim.AdamW(
- self.d.parameters(),
- lr=0.0002,
- betas=(0.5, 0.999),
- weight_decay=0.01,
- )
-
- def get_loss(
- self, prefix: str, pred: Tensor, gt: Tensor, loss_type: str, training: bool
- ) -> Tensor:
- if loss_type == "cre":
- loc = gt != getattr(self, f"{prefix}_ignore_label")
- pred = pred[loc]
- gt = gt[loc]
- loss = F.cross_entropy(pred, gt.long())
- elif loss_type == "l1":
- loc = gt != getattr(self, f"{prefix}_ignore_label")
- pred = pred[loc]
- gt = gt[loc]
- logger.debug(f"get loss {prefix} {loss_type} {gt}")
- loss = F.l1_loss(pred.reshape(gt.shape), gt.float())
- elif loss_type == "mfc_gan":
- fake_output = self.d(pred.reshape(gt.shape).detach())
- real_output = self.d(gt)
- loc = gt[:, :, 0:1] != getattr(self, f"{prefix}_ignore_label")
- if training is True:
- real_output = real_output[loc]
- Ldadv = (
- fake_output.mean()
- - real_output.mean() # + self.lambda_gp * gradient_penalty
- )
- self.d_optimizer.zero_grad()
- Ldadv.backward()
- self.d_optimizer.step()
- fake_output = self.d(pred.reshape(gt.shape))
- fake_output = fake_output[loc]
- Lsadv = -fake_output.mean()
- loc = gt != getattr(self, f"{prefix}_ignore_label")
- pred = pred[loc]
- gt = gt[loc]
- loss = 10 * F.mse_loss(pred.reshape(gt.shape), gt.float()) + Lsadv
- else:
- raise NotImplementedError()
- return loss
-
- def update_vae_weight(self, epoch: int, epochs: int, step: int, steps: int) -> None:
- logger.debug("update vae weight")
- self.vae_weight = (epoch + step / steps) / epochs
- logger.debug(f"vae weight value {self.vae_weight}")
-
- def vae_loss(self, mu: Tensor, logsigma: Tensor) -> Tuple:
- loss_kl = -0.5 * torch.mean(
- 1 + logsigma - mu.pow(2) - torch.exp(logsigma)
- ) # vae kl divergence
- loss_kl_weighted = loss_kl * self.vae_weight * self.fix_vae_weight
- return loss_kl_weighted, loss_kl
-
- def get_vae_loss(self, vae_items: Tuple) -> Tensor:
- loss_kl_weighted = 0
- for mu, logsigma in vae_items:
- loss, _ = self.vae_loss(mu, logsigma)
- loss_kl_weighted += loss
- return loss_kl_weighted # +self.config.VAE_L2_LOSS_WEIGHT*loss_l2
-
- def __call__(self, model_inputs: ModelInput, preds: Dict, training: bool) -> Tuple:
- total_loss = 0
- record_items = {}
-
- for prefix in self.prefix_list_target:
- pred = preds[prefix]
- gt = getattr(model_inputs, prefix)
- loss_type = getattr(self, f"{prefix}_loss_type")
- loss = self.get_loss(
- prefix,
- pred,
- gt,
- loss_type,
- training,
- )
- pred_label, gt_data = self.get_pred_gt_label(prefix, pred, gt, loss_type)
- record_items[prefix] = (pred_label, gt_data, loss.item())
- total_loss += loss
-
- if "vae_data" in preds.keys():
- logger.debug("compute vae loss")
- vae_loss = self.get_vae_loss(preds["vae_data"])
- total_loss = total_loss + vae_loss
- logger.debug(f"vae loss {vae_loss}")
- return total_loss, record_items
-
- def get_pred_gt_label(
- self, prefix: str, pred: Tensor, gt: Tensor, loss_type: str
- ) -> Tuple[List, Union[List, np.array]]:
- loc = gt != getattr(self, f"{prefix}_ignore_label")
- pred = pred[loc]
- gt = gt[loc]
- pred_label = self.get_pred_label(pred, loss_type)
- gt_label = self.get_gt_label(gt, loss_type)
- return pred_label, gt_label
-
- def get_pred_label(
- self,
- pred: Tensor,
- loss_type: str = "cre",
- ) -> List:
- if loss_type == "cre":
- pred = torch.sort(input=pred, dim=1, descending=True)[1].data.cpu().numpy()
- preds = []
- for i in range(len(pred)):
- p = []
- for k in range(min(self.topk, pred.shape[1])):
- p.append(pred[i, k])
- preds.append(p)
- elif loss_type == "l1":
- preds = []
- for i in range(len(pred)):
- p = []
- for k in range(self.topk):
- p.append(0) # dummy data
- preds.append(p)
- elif loss_type == "mfc_gan":
- preds = []
- for i in range(len(pred)):
- p = []
- for k in range(self.topk):
- p.append(0) # dummy data
- preds.append(p)
- else:
- raise NotImplementedError()
- return preds
-
- def get_gt_label(
- self,
- gt: Tensor,
- loss_type: str = "cre",
- ) -> Union[List, np.array]:
- if loss_type == "cre":
- gt_labels = gt.data.cpu().numpy()
- elif loss_type == "l1":
- gt_labels = []
- for i in range(len(gt)):
- g = []
- for k in range(self.topk):
- g.append(0) # dummy data
- gt_labels.append(g)
- elif loss_type == "mfc_gan":
- gt_labels = []
- for i in range(len(gt)):
- g = []
- for k in range(self.topk):
- g.append(0) # dummy data
- gt_labels.append(g)
- else:
- raise NotImplementedError()
- return gt_labels
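The "cre" branch of `get_loss` above masks out padded text slots with the per-attribute `ignore_label` before applying cross-entropy. A minimal sketch, assuming `ignore_label == -1`:

```python
# Sketch of the masked cross-entropy in LossFunc.get_loss,
# assuming ignore_label == -1 for padded text slots.
import torch
import torch.nn.functional as F

ignore_label = -1
pred = torch.randn(5, 10)                # 5 text elements, 10 classes
gt = torch.tensor([3, 7, -1, 1, -1])     # two padded slots

loc = gt != ignore_label                 # keep only real elements
loss = F.cross_entropy(pred[loc], gt[loc].long())
print(loss.item())
```

`F.cross_entropy` could do the same filtering via `ignore_index`, but the explicit `loc` mask is what the file reuses for the "l1" and "mfc_gan" branches.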
diff --git a/src/typography_generation/tools/prediction_recorder.py b/src/typography_generation/tools/prediction_recorder.py
deleted file mode 100644
index 9743990..0000000
--- a/src/typography_generation/tools/prediction_recorder.py
+++ /dev/null
@@ -1,141 +0,0 @@
-from typing import Dict, List
-from logzero import logger
-
-############################################################
-# Prediction Recorder
-############################################################
-
-
-class PredictionRecoder:
- def __init__(self, prefix_list_target: List, topk: int = 5):
- super(PredictionRecoder, self).__init__()
- self.prefix_list_target = prefix_list_target
- self.topk = topk
- self.show_topk = list(range(self.topk))
- self.register(prefix_list_target)
- self.epoch = 0
-
- def register(self, prefix_list_target: List) -> None:
- self.all_target_list = []
- self.cl_target_list = []
- self.loss_target_list = []
- for prefix in prefix_list_target:
- self.all_target_list.append(f"{prefix}_cl")
- self.all_target_list.append(f"{prefix}_loss")
- self.cl_target_list.append(f"{prefix}_cl")
- self.loss_target_list.append(f"{prefix}_loss")
- self.regist_cl_recorder_set(self.cl_target_list)
- self.regist_recorder_set(self.loss_target_list)
-
- def regist_recorder_set(self, target_list: List) -> None:
- for tar in target_list:
- self.regist_record_target(tar)
-
- def regist_cl_recorder_set(self, target_list: List) -> None:
- for tar in target_list:
- self.regist_cl_recorder(tar)
-
- def regist_cl_recorder(self, registration_name: str) -> None:
- self.regist_record_target(f"{registration_name}_pred")
- self.regist_record_target(f"{registration_name}_gt")
-
- def regist_record_target(self, registration_name: str) -> None:
- setattr(self, registration_name, [])
- setattr(self, f"{registration_name}_history", [])
- for k in range(len(self.show_topk)):
- setattr(self, f"{registration_name}_history_{self.show_topk[k]}", [])
-
- def __call__(self, recoder_items: Dict) -> None:
- for name, (pred, gt, loss) in recoder_items.items():
- clname = f"{name}_cl"
- lossname = f"{name}_loss"
- if clname in self.cl_target_list:
- for p, g in zip(pred, gt):
- getattr(self, f"{clname}_pred").append(p)
- getattr(self, f"{clname}_gt").append(g)
- if lossname in self.loss_target_list:
- getattr(self, lossname).append(loss)
-
- def reset(self) -> None:
- for name in self.all_target_list:
- if name in self.cl_target_list:
- setattr(self, f"{name}_pred", [])
- setattr(self, f"{name}_gt", [])
- if name in self.loss_target_list:
- setattr(self, name, [])
-
- def store_score(self) -> None:
- for name in self.all_target_list:
- if name in self.cl_target_list:
- for k in range(len(self.show_topk)):
- getattr(self, f"{name}_pred_history_{self.show_topk[k]}").append(
- getattr(self, f"{name}{self.show_topk[k]}_acc")
- )
- if name in (self.loss_target_list):
- getattr(self, f"{name}_history").append(getattr(self, f"{name}_mean"))
-
- def compute_score(self) -> Dict:
- for name in self.all_target_list:
- if name in self.cl_target_list:
- topkacc = {}
- for k in range(self.topk):
- topkacc[k] = 0
- for p, g in zip(
- getattr(self, f"{name}_pred"), getattr(self, f"{name}_gt")
- ):
- flag = 0
- for k in range(min(self.topk, len(p))):
- if p[k] == g:
- flag = 1
- topkacc[k] += flag
- for k in range(min(self.topk, len(topkacc))):
- setattr(
- self,
- f"{name}{k}_acc",
- topkacc[k] / max(len(getattr(self, f"{name}_pred")), 1),
- )
- if name in (self.loss_target_list):
- mean_loss = 0
- for l in getattr(self, f"{name}"):
- mean_loss += l
- setattr(
- self,
- f"{name}_mean",
- mean_loss / max(len(getattr(self, f"{name}")), 1),
- )
- score_dict = self.show_scores()
- return score_dict
-
- def show_scores(self) -> Dict:
- score_dict = {}
- for name in self.all_target_list:
- if name in self.cl_target_list:
- for k in range(len(self.show_topk)):
- logger.info(
- f"{name}{self.show_topk[k]}_acc:{getattr(self, f'{name}{self.show_topk[k]}_acc')}"
- )
- score_dict[name] = getattr(self, f"{name}{self.show_topk[0]}_acc")
- if name in self.loss_target_list:
- logger.info(f"{name}_mean:{getattr(self, f'{name}_mean')}")
- score_dict[name] = getattr(self, f"{name}_mean")
- return score_dict
-
- def show_history_scores(self) -> None:
- for i in range(self.epoch):
- logger.info(f"epoch {i}")
- for name in self.all_target_list:
- if name in self.cl_target_list:
- for k in range(len(self.show_topk)):
- logger.info(
- f"{name}{self.show_topk[k]}acc:{getattr(self, f'{name}_pred_history_{self.show_topk[k]}')[i]}"
- )
- if name in self.loss_target_list:
- logger.info(f"{name}_mean:{getattr(self, f'{name}_history')[i]}")
-
- def step_epoch(self) -> None:
- self.compute_score()
- self.store_score()
- self.update_epoch()
-
- def update_epoch(self) -> None:
- self.epoch += 1
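`compute_score` above implements cumulative top-k accuracy: once the ground truth appears at rank `j` of a prediction list, the hit counts toward every `k >= j`. A worked example with three ranked prediction lists:

```python
# Worked sketch of the cumulative top-k accuracy in compute_score:
# a hit at rank j counts toward every k >= j.
preds = [[4, 2, 9], [1, 0, 3], [7, 5, 6]]   # per-sample ranked label lists
gts = [2, 3, 8]
topk = 3

hits = [0.0] * topk
for p, g in zip(preds, gts):
    flag = 0
    for k in range(topk):
        if p[k] == g:
            flag = 1
        hits[k] += flag

print([h / len(gts) for h in hits])
# [0.0, 0.3333333333333333, 0.6666666666666666]
```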
diff --git a/src/typography_generation/tools/sampler.py b/src/typography_generation/tools/sampler.py
deleted file mode 100644
index 1808e32..0000000
--- a/src/typography_generation/tools/sampler.py
+++ /dev/null
@@ -1,150 +0,0 @@
-import time
-from typing import List
-
-import torch
-import torch.nn as nn
-import torch.utils.data
-from typography_generation.io.data_loader import CrelloLoader
-from typography_generation.tools.score_func import EvalDataInstance
-from typography_generation.io.data_object import (
- DataPreprocessConfig,
- DesignContext,
- FontConfig,
- ModelInput,
- PrefixListObject,
- SamplingConfig,
-)
-from logzero import logger
-
-from typography_generation.tools.train import collate_batch
-
-from typography_generation.tools.evaluator import (
- Evaluator,
-)
-
-
-############################################################
-# Sampler
-############################################################
-class Sampler(Evaluator):
- def __init__(
- self,
- model: nn.Module,
- gpu: bool,
- save_dir: str,
- dataset: CrelloLoader,
- prefix_list_object: PrefixListObject,
- sampling_config: SamplingConfig,
- batch_size: int = 1,
- num_worker: int = 2,
- show_interval: int = 100,
- dataset_division: str = "test",
- debug: bool = False,
- ) -> None:
- super().__init__(
- model,
- gpu,
- save_dir,
- dataset,
- prefix_list_object,
- batch_size=batch_size,
- num_worker=num_worker,
- show_interval=show_interval,
- dataset_division=dataset_division,
- debug=debug,
- )
- self.sampling_config = sampling_config
-
- def sample(self) -> None:
- # Data generators
- dataloader = torch.utils.data.DataLoader(
- self.dataset,
- batch_size=self.batch_size,
- shuffle=False,
- num_workers=self.num_worker,
- pin_memory=True,
- collate_fn=collate_batch,
- )
- with torch.no_grad():
- self.steps = len(dataloader)
- self.step = 0
- self.cnt = 0
- self.model.eval()
- end = time.time()
- for inputs in dataloader:
- design_context_list, model_input_batchdata, svg_id, _ = inputs
- self.sample_step(
- design_context_list, model_input_batchdata, svg_id, end
- )
- end = time.time()
- self.step += 1
-
- self.show_scores()
- self.entire_data.show_diversity_scores(self.prefix_list_target)
- self.save_prediction()
-
- def sample_step(
- self,
- design_context_list: List[DesignContext],
- model_input_batchdata: List,
- svg_id: List,
- end: float,
- ) -> None:
- _data_index = svg_id[0]
- self.entire_data.data_index_list.append(_data_index)
- for sample_id in range(self.sampling_config.sampling_num):
- data_index = f"{_data_index}_{sample_id}"
- self.sample_iter(
- design_context_list, model_input_batchdata, data_index, end
- )
-
- def sample_iter(
- self,
- design_context_list: List[DesignContext],
- model_input_batchdata: List,
- data_index: str,
- end: float,
- ) -> None:
- start = time.time()
- model_inputs = ModelInput(design_context_list, model_input_batchdata, self.gpu)
- sampler_input = {
- "model_inputs": model_inputs,
- "target_prefix_list": self.prefix_list_target,
- "sampling_param_geometry": self.sampling_config.sampling_param_geometry,
- "sampling_param_semantic": self.sampling_config.sampling_param_semantic,
- }
- predictions = self.model.sample(**sampler_input)
-
- self.instance_data = EvalDataInstance(self.prefix_list_target)
-
- text_num = design_context_list[0].canvas_context.canvas_text_num
-
- self.save_data[data_index] = dict()
- for prefix in self.prefix_list_target:
- pred_token, gt_token, pred, gt = self.denormalizer.denormalize(
- prefix, text_num, predictions, design_context_list[0]
- )
- self.instance_data.rigister_att(
- text_num,
- prefix,
- pred_token,
- gt_token,
- pred,
- gt,
- )
- self.entire_data.update_prediction_data(
- data_index, self.instance_data, f"{prefix}"
- )
- self.save_data[data_index][prefix] = pred
- self.entire_data.text_num[data_index] = text_num
-
- forward_time = time.time()
- if self.step % 200 == 0:
- data_show = "{}/{}/{}, forward_time: {:.3f} data {:.3f}".format(
- self.cnt,
- self.step + 1,
- self.steps,
- forward_time - start,
- (start - end),
- )
- logger.info(data_show)
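`sample_step` draws `sampling_num` independent predictions per design and stores each under a `"{svg_id}_{iter}"` key; `EvalDataEntire.update_sampling_data` later splits that key to regroup samples by design. A toy sketch of the keying scheme:

```python
# Toy sketch of the per-design multi-sample keying in sample_step.
sampling_num = 3
save_data = {}
for svg_id in ["design_a", "design_b"]:
    for it in range(sampling_num):
        save_data[f"{svg_id}_{it}"] = {"text_font": [it]}  # one sampled prediction
print(sorted(save_data))
# ['design_a_0', 'design_a_1', 'design_a_2', 'design_b_0', ...]
```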
diff --git a/src/typography_generation/tools/score_func.py b/src/typography_generation/tools/score_func.py
deleted file mode 100644
index f8a1709..0000000
--- a/src/typography_generation/tools/score_func.py
+++ /dev/null
@@ -1,657 +0,0 @@
-import itertools
-from typing import Dict, List, Tuple
-
-import numpy as np
-from io import TextIOWrapper
-from logzero import logger as log
-from typography_generation.tools.color_func import deltaE_ciede2000, rgb2lab
-
-
-class EvalDataInstance:
- def __init__(
- self,
- attribute_list: List,
- ) -> None:
- self.attribute_list = attribute_list
- self.target_list = ["pred_token", "gt_token", "pred", "gt"]
- self.reset()
-
- def reset(self) -> None:
- for att in self.attribute_list:
- for tar in self.target_list:
- registration_name = f"{att}_{tar}"
- setattr(self, registration_name, [])
-
- def rigister_att(
- self,
- text_num: int,
- prefix: str,
- target_pred_token: np.array,
- target_gt_token: np.array,
- target_pred: np.array,
- target_gt: np.array,
- start_index: int = 0,
- ) -> None:
- for i in range(start_index, text_num):
- registration_name = f"{prefix}_pred_token"
- getattr(self, registration_name).append(target_pred_token[i])
- registration_name = f"{prefix}_gt_token"
- getattr(self, registration_name).append(target_gt_token[i])
- registration_name = f"{prefix}_pred"
- getattr(self, registration_name).append(target_pred[i])
- registration_name = f"{prefix}_gt"
- getattr(self, registration_name).append(target_gt[i])
-
-
-class EvalDataEntire:
- def __init__(
- self, attribute_list: List, save_dir: str, save_file_prefix: str = "score"
- ) -> None:
- self.attribute_list = attribute_list
- self.text_num = {}
- self.overlap_scores = {}
- self.data_index_list = []
- self.target_list = ["pred_token", "gt_token", "pred", "gt"]
- for att in self.attribute_list:
- for tar in self.target_list:
- registration_name = f"{att}_{tar}"
- setattr(self, registration_name, {})
-
- save_file = f"{save_dir}/{save_file_prefix}.txt"
- self.f = open(save_file, "w")
-
- def update_prediction_data(
- self,
- index: str,
- instance_obj: EvalDataInstance,
- prefix: str,
- ) -> None:
- getattr(self, f"{prefix}_pred_token")[index] = getattr(
- instance_obj, f"{prefix}_pred_token"
- )
- getattr(self, f"{prefix}_gt_token")[index] = getattr(
- instance_obj, f"{prefix}_gt_token"
- )
- getattr(self, f"{prefix}_pred")[index] = getattr(instance_obj, f"{prefix}_pred")
- getattr(self, f"{prefix}_gt")[index] = getattr(instance_obj, f"{prefix}_gt")
-
- def update_sampling_data(
- self,
- index: str,
- instance_obj: EvalDataInstance,
- prefix: str,
- ) -> None:
- primary_index, sub_index = index.rsplit("_", 1) # svg ids may themselves contain "_"
- if int(sub_index) > 0:
- getattr(self, prefix)[primary_index].append(getattr(instance_obj, prefix))
- else:
- getattr(self, prefix)[primary_index] = []
- getattr(self, prefix)[primary_index].append(getattr(instance_obj, prefix))
-
- def show_classification_score(
- self, att: str, topk: int = 10, show_topk: List = [0, 5]
- ) -> None:
- log.info(f"{att}")
- self.f.write(f"{att} \n")
- compute_score(
- getattr(self, f"{att}_pred"),
- getattr(self, f"{att}_gt"),
- topk,
- show_topk,
- self.f,
- )
-
- def show_abs_erros(
- self, prefix: str, blanktype: str = "", topk: int = 10, show_topk: List = [0, 5]
- ) -> None:
- log.info(f"{prefix}")
- self.f.write(f"{prefix} \n")
- compute_abs_error_score(
- getattr(self, f"{prefix}_pred"),
- getattr(self, f"{prefix}_gt"),
- self.f,
- )
-
- def show_font_color_scores(
- self, blanktype: str = "", topk: int = 10, show_topk: List = [0, 5]
- ) -> None:
- log.info(f"font_color")
- self.f.write(f"font_color \n")
- compute_color_score(
- getattr(self, f"text_font_color_pred"),
- getattr(self, f"text_font_color_gt"),
- self.f,
- )
-
- def show_structure_score(self, att: str) -> None:
- log.info(f"{att}")
- self.f.write(f"{att} \n")
- compute_bigram_score(
- getattr(self, f"text_num"),
- getattr(self, f"{att}_pred_token"),
- getattr(self, f"{att}_gt_token"),
- self.f,
- )
-
- def show_visual_similarity_scores(self) -> None:
- registration_name = "l2error"
- dict_average_score(registration_name, getattr(self, registration_name))
- registration_name = "psnr"
- dict_average_score(registration_name, getattr(self, registration_name))
-
- def show_time_score(self) -> None:
- registration_name = "time"
- dict_average_score(registration_name, getattr(self, registration_name))
-
- def show_diversity_scores(self, attribute_list: List) -> None:
- for att in attribute_list:
- log.info(f"{att}")
- self.f.write(f"{att} \n")
- compute_label_diversity_score(
- getattr(self, "data_index_list"),
- getattr(self, f"text_num"),
- getattr(self, f"{att}_pred_token"),
- self.f,
- )
-
- def show_alpha_overlap_score(self) -> None:
- compute_alpha_overlap(
- getattr(self, "overlap_scores"),
- self.f,
- )
-
-
-def compute_abs_error_score(
- eval_list: dict,
- gt_list: dict,
- f: TextIOWrapper,
-) -> None:
- cnt_img = 0
- l1_distance = 0.0
- r_l1_distance = 0.0
- for index in eval_list.keys():
- g = gt_list[index]
- if len(g) > 0:
- e = eval_list[index]
- d = 0.0
- rd = 0.0
- for i, (pred, gt) in enumerate(zip(e, g)):
- pred = pred[0]
- _d = abs(gt - pred)
- _rd = abs(gt - pred) / max(abs(float(gt)), 1e-5)
- d += _d
- rd += _rd
- l1_distance += d / len(g)
- r_l1_distance += rd / len(g)
- cnt_img += 1
- l1_distance /= cnt_img
- log.info(f"l1_distance {l1_distance}")
- f.write(f"l1_distance {l1_distance} \n")
- r_l1_distance /= cnt_img
- log.info(f"r_l1_distance {r_l1_distance}")
- f.write(f"r_l1_distance {r_l1_distance} \n")
-
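-def _demo_compute_abs_error_score() -> None:
- # Hypothetical usage sketch, not part of the original pipeline: eval_list
- # maps an image index to ranked per-element predictions (top-1 first) and
- # gt_list maps the same index to scalar ground truths.
- import sys
- preds = {"img0": [[10.0], [12.0]]}
- gts = {"img0": [9.0, 12.0]}
- compute_abs_error_score(preds, gts, sys.stdout)
- # logs l1_distance 0.5 and r_l1_distance ~0.056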
-
-def compute_color_score(
- font_color_eval_list: dict,
- font_color_gt_list: dict,
- f: TextIOWrapper,
-) -> None:
- cnt_img = 0
- color_distance = 0.0
- for index in font_color_eval_list.keys():
- e = font_color_eval_list[index]
- g = font_color_gt_list[index]
- if len(g) > 0:
- d = 0.0
- for i, (pred, gt) in enumerate(zip(e, g)):
- pred = pred[0]
- lab_p = rgb2lab(np.array(pred).reshape(1, 1, 3).astype(np.float32))
- lab_g = rgb2lab(np.array(gt).reshape(1, 1, 3).astype(np.float32))
- _d = deltaE_ciede2000(lab_p, lab_g)
- d += _d[0][0]
- color_distance += d / len(g)
- cnt_img += 1
- color_distance /= cnt_img
- log.info(f"color_distance {color_distance}")
- f.write(f"color_distance {color_distance} \n")
-
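-def _demo_compute_color_score() -> None:
- # Hypothetical usage sketch, not part of the original pipeline. RGB values
- # are assumed to lie in [0, 1], matching what skimage's rgb2lab expects for
- # float input; identical colours give a CIEDE2000 distance of 0.
- import sys
- preds = {"img0": [[(1.0, 0.0, 0.0)]]}
- gts = {"img0": [(1.0, 0.0, 0.0)]}
- compute_color_score(preds, gts, sys.stdout)  # logs color_distance 0.0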
-
-def compute_score(
- eval_list: dict,
- gt_list: dict,
- topk: int,
- show_topk: List,
- f: TextIOWrapper,
-) -> Tuple:
- cnt_elm = 0
- cnt_img = 0
- topk_acc_elm = {}
- topk_acc_img = {}
- for k in range(topk):
- topk_acc_elm[k] = 0.0
- topk_acc_img[k] = 0.0
- for index in eval_list.keys():
- e = eval_list[index]
- g = gt_list[index]
- topk_acc_tmp = {}
- for k in range(topk):
- topk_acc_tmp[k] = 0.0
- if len(g) > 0:
- cnt_gt = 0
- for i, gt in enumerate(g):
- flag = 0
- for k in range(min(topk, len(e[i]))):
- if int(gt) == int(e[i][k]):
- flag = 1.0
- topk_acc_elm[k] += flag
- topk_acc_tmp[k] += flag
- cnt_gt += 1
- cnt_elm += 1
- for k in range(topk):
- if cnt_gt > 0:
- topk_acc_img[k] += topk_acc_tmp[k] / cnt_gt
- cnt_img += 1
- for k in range(topk):
- topk_acc_elm[k] /= cnt_elm
- topk_acc_img[k] /= cnt_img
- if k in show_topk:
- log.info(f"top{k} img_level_acc {topk_acc_img[k]}")
- f.write(f"top{k} img_level_acc {topk_acc_img[k]} \n")
- return topk_acc_elm, topk_acc_img
-
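-def _demo_compute_score() -> None:
- # Hypothetical usage sketch, not part of the original pipeline: each
- # prediction is a ranked candidate list, and top-k accuracy is cumulative,
- # so a hit at rank 1 also counts for every k >= 1.
- import sys
- preds = {"img0": [[3, 1], [2, 5]]}
- gts = {"img0": [1, 2]}
- topk_elm, topk_img = compute_score(preds, gts, 2, [0, 1], sys.stdout)
- assert topk_elm[0] == 0.5 and topk_elm[1] == 1.0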
-
-def dict_average_score(prefix: str, score_dict: dict) -> None:
- score_mean = 0
- cnt = 0
- for index in score_dict.keys():
- score_mean += score_dict[index]
- cnt += 1
- log.info("{} {}".format(prefix, score_mean / cnt))
-
-
-def compute_unigram_label_score(
- text_num: int,
- text_target_mask: np.array,
- font_pred: np.array,
- font_gt: np.array,
- ignore_labels: List[int],
-) -> float:
- cnt = 0
- correct_cnt = 0
- for i in range(text_num):
- fi = font_pred[i]
- if fi in ignore_labels or text_target_mask[i] == 0:
- continue
- else:
- cnt += 1
- for j in range(text_num):
- fj = int(font_gt[j])
- if fi == fj:
- correct_cnt += 1
- break
- if cnt > 0:
- score = float(correct_cnt) / cnt
- else:
- score = 0
- return score
-
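-def _demo_compute_unigram_label_score() -> None:
- # Hypothetical usage sketch, not part of the original pipeline: a predicted
- # label is counted as correct if it appears anywhere in the ground truth.
- pred = np.array([4, 7])
- gt = np.array([7, 7])
- mask = np.array([1, 1])
- assert compute_unigram_label_score(2, mask, pred, gt, []) == 0.5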
-
-def compute_bigram_label_score(
- pred: np.array,
- gt: np.array,
-) -> float:
- text_num = len(gt)
- text_cmb = list(itertools.combinations(list(range(text_num)), 2))
- cnt = 0
- correct_cnt = 0
- for pi, pj in text_cmb:
- fpi = pred[pi][0]
- fpj = pred[pj][0]
- cnt += 1
- for gi, gj in text_cmb:
- fgi = int(gt[gi])
- fgj = int(gt[gj])
- if (fpi == fgi) and (fpj == fgj):
- correct_cnt += 1
- break
- elif (fpi == fgj) and (fpj == fgi):
- correct_cnt += 1
- break
- if cnt > 0:
- score = float(correct_cnt) / cnt
- else:
- score = 0
- return score
-
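-def _demo_compute_bigram_label_score() -> None:
- # Hypothetical usage sketch, not part of the original pipeline: a predicted
- # pair of labels is correct if some ground-truth pair matches it in either
- # order; predictions are ranked lists, hence the [i][0] indexing inside.
- pred = [[1], [2], [3]]
- gt = [2, 1, 9]
- assert abs(compute_bigram_label_score(pred, gt) - 1 / 3) < 1e-9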
-
-def get_binary_classification_scores(
- l11cnt: int, l00cnt: int, l10cnt: int, l01cnt: int
-) -> Tuple:
- if l11cnt + l10cnt > 0:
- precision = float(l11cnt) / (l11cnt + l10cnt)
- else:
- precision = 0
-
- if l11cnt + l01cnt > 0:
- recall = float(l11cnt) / (l11cnt + l01cnt)
- else:
- recall = 0
-
- if l00cnt + l01cnt > 0:
- precision_inv = float(l00cnt) / (l00cnt + l01cnt)
- else:
- precision_inv = 0
-
- if l00cnt + l10cnt > 0:
- recall_inv = float(l00cnt) / (l00cnt + l10cnt)
- else:
- recall_inv = 0
-
- if precision + recall > 0:
- fvalue = 2 * precision * recall / (precision + recall)
- else:
- fvalue = 0
-
- if 2 * l11cnt + l01cnt + l10cnt == 0:
- _fvalue = np.nan
- else:
- _fvalue = 2 * l11cnt / (2 * l11cnt + l01cnt + l10cnt)
-
- if precision_inv + recall_inv > 0:
- fvalue_inv = 2 * precision_inv * recall_inv / (precision_inv + recall_inv)
- else:
- fvalue_inv = 0
-
- if 2 * l00cnt + l01cnt + l10cnt == 0:
- _fvalue_inv = np.nan
- else:
- _fvalue_inv = 2 * l00cnt / (2 * l00cnt + l01cnt + l10cnt)
-
- if l11cnt + l00cnt + l01cnt + l10cnt > 0:
- accuracy = float(l11cnt + l00cnt) / (l11cnt + l00cnt + l01cnt + l10cnt)
- else:
- accuracy = 0
- if l00cnt + l01cnt > 0:
- specificity = float(l00cnt) / (l00cnt + l01cnt)
- else:
- specificity = 0
- return (
- accuracy,
- specificity,
- precision,
- recall,
- fvalue,
- precision_inv,
- recall_inv,
- fvalue_inv,
- _fvalue,
- _fvalue_inv,
- )
-
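-def _demo_get_binary_classification_scores() -> None:
- # Hypothetical usage sketch, not part of the original pipeline, reading the
- # returned tuple by position: precision = l11/(l11+l10) = 0.75,
- # recall = l11/(l11+l01) = 0.6, accuracy = (l11+l00)/total = 0.7.
- scores = get_binary_classification_scores(3, 4, 1, 2)
- accuracy, specificity, precision, recall = scores[:4]
- assert (accuracy, precision, recall) == (0.7, 0.75, 0.6)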
-
-def compute_bigram_structure_score(
- pred: np.array,
- gt: np.array,
-) -> Tuple:
- text_num = len(gt)
- text_cmb = list(itertools.combinations(list(range(text_num)), 2))
- l11cnt = 0
- l00cnt = 0
- l10cnt = 0
- l01cnt = 0
- for pi, pj in text_cmb:
- fpi = pred[pi][0]
- fpj = pred[pj][0]
- fgi = int(gt[pi])
- fgj = int(gt[pj])
- if (fpi == fpj) and (fgi == fgj):
- l11cnt += 1
- # l00cnt += 1
- if (fpi != fpj) and (fgi != fgj):
- l00cnt += 1
- # l11cnt += 1
- if (fpi != fpj) and (fgi == fgj):
- l10cnt += 1
- # l01cnt += 1
- if (fpi == fpj) and (fgi != fgj):
- l01cnt += 1
- # l10cnt += 1
- scores = get_binary_classification_scores(l11cnt, l00cnt, l10cnt, l01cnt)
- return scores, (l11cnt, l00cnt, l10cnt, l01cnt)
-
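-def _demo_compute_bigram_structure_score() -> None:
- # Hypothetical usage sketch, not part of the original pipeline: the counts
- # are (both-same, both-differ, pred-differ/gt-same, pred-same/gt-differ)
- # over all element pairs.
- pred = [[1], [1], [2]]
- gt = [1, 1, 2]
- scores, counts = compute_bigram_structure_score(pred, gt)
- assert counts == (1, 2, 0, 0)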
-
-def get_structure_type(
- gt: np.array,
-) -> int:
- text_num = len(gt)
- text_cmb = list(itertools.combinations(list(range(text_num)), 2))
- consistency_num = 0
- contrast_num = 0
- for pi, pj in text_cmb:
- fgi = int(gt[pi])
- fgj = int(gt[pj])
- if fgi == fgj:
- consistency_num += 1
- else:
- contrast_num += 1
- if text_num <= 1:
- flag = 0 # uncount
- elif consistency_num == 0:
- flag = 1 # no consistency
- elif contrast_num == 0:
- flag = 2 # no contrast
- else:
- flag = 3 # others
- return flag
-
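-def _demo_get_structure_type() -> None:
- # Hypothetical usage sketch, not part of the original pipeline.
- assert get_structure_type([5]) == 0  # <=1 element: uncounted
- assert get_structure_type([1, 2]) == 1  # all pairs differ: no consistency
- assert get_structure_type([3, 3]) == 2  # all pairs match: no contrast
- assert get_structure_type([1, 1, 2]) == 3  # mixed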
-
-def compute_bigram_score(
- text_num_list: dict,
- pred_list: dict,
- gt_list: dict,
- f: TextIOWrapper,
-) -> Tuple:
- cnt = 0
- structure_accuracy_mean = 0.0
- structure_precision_mean = 0.0
- structure_recall_mean = 0.0
- structure_fvalue_mean = 0.0
- structure_specificity_mean = 0.0
- structure_precision_inv_mean = 0.0
- structure_recall_inv_mean = 0.0
- structure_fvalue_inv_mean = 0.0
- label_score_mean = 0.0
- l11cnt_all = 0
- l00cnt_all = 0
- l10cnt_all = 0
- l01cnt_all = 0
- # 1: no consistency, 2: no contrast, 3: others
- diff_case_scores = {1: 0.0, 2: 0.0, 3: 0.0}
- diff_case_counts = {1: 0, 2: 0, 3: 0}
- structure_nanmean = []
- for index in pred_list.keys():
- text_num = text_num_list[index]
- if text_num == 0:
- continue
- pred = pred_list[index]
- gt = gt_list[index]
- flag = get_structure_type(gt)
- scores, counts = compute_bigram_structure_score(pred, gt)
- (
- structure_accuracy,
- structure_specificity,
- structure_precision,
- structure_recall,
- structure_fvalue,
- structure_precision_inv,
- structure_recall_inv,
- structure_fvalue_inv,
- _structure_fvalue,
- _structure_fvalue_inv,
- ) = scores
- l11cnt, l00cnt, l10cnt, l01cnt = counts
- l11cnt_all += l11cnt
- l00cnt_all += l00cnt
- l10cnt_all += l10cnt
- l01cnt_all += l01cnt
- label_score = compute_bigram_label_score(pred, gt)
- structure_accuracy_mean += structure_accuracy
- structure_specificity_mean += structure_specificity
- structure_precision_mean += structure_precision
- structure_recall_mean += structure_recall
- structure_fvalue_mean += structure_fvalue
- structure_precision_inv_mean += structure_precision_inv
- structure_recall_inv_mean += structure_recall_inv
- structure_fvalue_inv_mean += structure_fvalue_inv
- label_score_mean += label_score
-
- if flag == 0:
- pass
- elif flag == 1: # no consistency
- diff_case_scores[1] += structure_fvalue_inv
- diff_case_counts[1] += 1
- elif flag == 2: # no contrast
- diff_case_scores[2] += structure_fvalue
- diff_case_counts[2] += 1
- elif flag == 3: # others
- diff_case_scores[3] += (structure_fvalue + structure_fvalue_inv) / 2.0
- diff_case_counts[3] += 1
- structure_nanmean.append(_structure_fvalue)
- structure_nanmean.append(_structure_fvalue_inv)
-
- cnt += 1
- structure_accuracy_mean /= cnt
- structure_specificity_mean /= cnt
- structure_precision_mean /= cnt
- structure_recall_mean /= cnt
- structure_fvalue_mean /= cnt
- structure_precision_inv_mean /= cnt
- structure_recall_inv_mean /= cnt
- structure_fvalue_inv_mean /= cnt
- label_score_mean /= cnt
- log.info("structure_accuracy {:.3f}".format(structure_accuracy_mean))
- f.write("structure_accuracy {:.3f} \n".format(structure_accuracy_mean))
-
- # log.info("label_score {:.3f}".format(label_score_mean))
- # f.write("label_score {:.3f} \n".format(label_score_mean))
- log.info("structure nanmean {:.3f}".format(np.nanmean(structure_nanmean)))
- f.write("structure nanmean {:.3f} \n".format(np.nanmean(structure_nanmean)))
- for i in range(1, 4):
- if diff_case_counts[i] > 0:
- log.info(
- "structure_case_score{} count:{} {:.3f}".format(
- i, diff_case_counts[i], diff_case_scores[i] / diff_case_counts[i]
- )
- )
- f.write(
- "structure_case_score{} count:{} {:.3f} \n".format(
- i, diff_case_counts[i], diff_case_scores[i] / diff_case_counts[i]
- )
- )
- else:
- log.info("structure_case_score{} count:{} -".format(i, diff_case_counts[i]))
- f.write(
- "structure_case_score{} count:{} - \n".format(i, diff_case_counts[i])
- )
-
- scores = get_binary_classification_scores(
- l11cnt_all, l00cnt_all, l10cnt_all, l01cnt_all
- )
- (
- structure_accuracy,
- structure_specificity,
- structure_precision,
- structure_recall,
- structure_fvalue,
- structure_precision_inv,
- structure_recall_inv,
- structure_fvalue_inv,
- _structure_fvalue,
- _structure_fvalue_inv,
- ) = scores
- return structure_accuracy_mean, label_score_mean
-
-
-def compute_label_diversity_score(
- data_index_list: List,
- text_num_list: Dict,
- pred_list: Dict,
- f: TextIOWrapper,
- sampling_num: int = 10,
-) -> None:
- def compute_label_diversity(pred_labels: List, text_num: int) -> float:
- unique_num_rate_avg = 0.0
- for k in range(text_num):
- labels = []
- for j in range(len(pred_labels)):
- l = int(pred_labels[j][k][0])
- labels.append(l)
- unique_num_rate = len(set(labels)) / float(len(pred_labels))
- unique_num_rate_avg += unique_num_rate
- unique_num_rate_avg /= text_num
- return unique_num_rate_avg
-
- label_diversity_avg = 0.0
- cnt = 0
- for index in data_index_list:
- text_num = text_num_list[f"{index}_0"]
- pred_lists = []
- for n in range(sampling_num):
- preds = pred_list[f"{index}_{n}"]
- pred_lists.append(preds)
- if text_num > 0:
- label_diversity = compute_label_diversity(pred_lists, text_num)
- label_diversity_avg += label_diversity
- cnt += 1
- label_diversity_avg /= cnt
- log.info("diversity score {:.1f}".format(label_diversity_avg * 100))
- f.write("diversity score {:.1f}\n".format(label_diversity_avg * 100))
-
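-def _demo_compute_label_diversity_score() -> None:
- # Hypothetical usage sketch, not part of the original pipeline: keys follow
- # the "<design>_<sample>" scheme produced during sampling, with sampling_num
- # samples per design index.
- import sys
- text_num = {"0_0": 1}
- preds = {"0_0": [[7]], "0_1": [[7]]}
- compute_label_diversity_score(["0"], text_num, preds, sys.stdout, sampling_num=2)
- # one unique label across two samples -> logs "diversity score 50.0"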
-
-def compute_alpha_overlap(
- overlap_scores: Dict,
- f: TextIOWrapper,
-) -> None:
- overlap_score_all = 0
- cnt_all = 0
- data_index_list = list(overlap_scores.keys())
- for index in data_index_list:
- overlap_score = overlap_scores[f"{index}"]
- if overlap_score is not None:
- overlap_score_all += overlap_score
- cnt_all += 1
- if cnt_all > 0:
- overlap_score_all = overlap_score_all / cnt_all
-
- log.info("alpha overlap score {:.2f}".format(overlap_score_all))
- f.write("alpha overlap score {:.2f}\n".format(overlap_score_all))
-
-
-def _compute_alpha_overlap(
- alpha_map_list: List,
-) -> Union[float, None]:
- overlap_score = 0
- cnt = 0
- for i in range(len(alpha_map_list)):
- alpha_i = alpha_map_list[i]
- for j in range(len(alpha_map_list)):
- if i == j:
- continue
- else:
- alpha_j = alpha_map_list[j]
- overlap = np.sum(alpha_i * alpha_j)
- if np.sum(alpha_i) > 0:
- recall = overlap / np.sum(alpha_i)
- overlap_score += recall
- cnt += 1
- if cnt > 0:
- overlap_score = overlap_score / cnt
- return overlap_score
- else:
- return None
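-
-
-def _demo_compute_alpha_overlap() -> None:
- # Hypothetical usage sketch, not part of the original pipeline: alpha maps
- # are HxW arrays in [0, 1]; the score is the mean fraction of each map that
- # is covered by the other maps.
- a = np.zeros((4, 4))
- a[:2, :] = 1.0
- b = np.zeros((4, 4))
- b[:1, :] = 1.0
- assert _compute_alpha_overlap([a, b]) == 0.75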
diff --git a/src/typography_generation/tools/structure_preserved_sampler.py b/src/typography_generation/tools/structure_preserved_sampler.py
deleted file mode 100644
index d8e5c96..0000000
--- a/src/typography_generation/tools/structure_preserved_sampler.py
+++ /dev/null
@@ -1,103 +0,0 @@
-import time
-from typing import List
-
-import torch.nn as nn
-from typography_generation.io.data_loader import CrelloLoader
-from typography_generation.tools.sampler import Sampler
-from typography_generation.tools.score_func import EvalDataInstance
-from typography_generation.io.data_object import (
- DesignContext,
- ModelInput,
- PrefixListObject,
- SamplingConfig,
-)
-from logzero import logger
-
-
-############################################################
-# Structure Preserved Sampler
-############################################################
-class StructurePreservedSampler(Sampler):
- def __init__(
- self,
- model: nn.Module,
- gpu: bool,
- save_dir: str,
- dataset: CrelloLoader,
- prefix_list_object: PrefixListObject,
- sampling_config: SamplingConfig,
- batch_size: int = 1,
- num_worker: int = 2,
- show_interval: int = 100,
- dataset_division: str = "test",
- debug: bool = False,
- ) -> None:
- super().__init__(
- model,
- gpu,
- save_dir,
- dataset,
- prefix_list_object,
- sampling_config,
- batch_size=batch_size,
- num_worker=num_worker,
- show_interval=show_interval,
- dataset_division=dataset_division,
- debug=debug,
- )
-
- def sample_iter(
- self,
- design_context_list: List[DesignContext],
- model_input_batchdata: List,
- data_index: str,
- end: float,
- ) -> None:
- start = time.time()
- model_inputs = ModelInput(design_context_list, model_input_batchdata, self.gpu)
- sampler_input = {
- "model_inputs": model_inputs,
- "dataset": self.dataset.dataset,
- "target_prefix_list": self.prefix_list_target,
- "sampling_param_geometry": self.sampling_config.sampling_param_geometry,
- "sampling_param_semantic": self.sampling_config.sampling_param_semantic,
- }
- predictions = self.model.structure_preserved_sample(**sampler_input)
-
- self.instance_data = EvalDataInstance(self.prefix_list_target)
-
- text_num = design_context_list[0].canvas_context.canvas_text_num
-
- self.save_data[data_index] = dict()
- for prefix in self.prefix_list_target:
- pred_token, gt_token, pred, gt = self.denormalizer.denormalize(
- prefix, text_num, predictions, design_context_list[0]
- )
- self.instance_data.rigister_att(
- text_num,
- prefix,
- pred_token,
- gt_token,
- pred,
- gt,
- )
- self.entire_data.update_prediction_data(
- data_index, self.instance_data, f"{prefix}"
- )
- self.save_data[data_index][prefix] = pred
- self.entire_data.text_num[data_index] = text_num
-
- forward_time = time.time()
- # if self.step % 200 == 0:
- data_show = "{}/{}/{}, forward_time: {:.3f} data {:.3f}".format(
- self.cnt,
- self.step + 1,
- self.steps,
- forward_time - start,
- (start - end),
- )
- logger.info(data_show)
diff --git a/src/typography_generation/tools/tokenizer.py b/src/typography_generation/tools/tokenizer.py
deleted file mode 100644
index 178c85f..0000000
--- a/src/typography_generation/tools/tokenizer.py
+++ /dev/null
@@ -1,95 +0,0 @@
-import pickle
-from typing import Dict, Tuple, Union
-import numpy as np
-
-default_cluster_num_dict = {
- "text_font_size": 16,
- "text_font_color": 64,
- "text_height": 16,
- "text_width": 16,
- "text_top": 64,
- "text_left": 64,
- "text_center_y": 64,
- "text_center_x": 64,
- "text_angle": 16,
- "text_letter_spacing": 16,
- "text_line_height_scale": 16,
- "canvas_aspect_ratio": 16,
-}
-
-
-class Tokenizer:
- def __init__(
- self,
- data_dir: str,
- cluster_num_dict: Union[Dict, None] = None,
- load_cluster: bool = True,
- ) -> None:
- self.prefix_list = [
- "text_font_size",
- "text_font_color",
- "text_height",
- "text_width",
- "text_top",
- "text_left",
- "text_center_y",
- "text_center_x",
- "text_angle",
- "text_letter_spacing",
- "text_line_height_scale",
- "canvas_aspect_ratio",
- ]
- self.rawdata2token = {
- "text_font_emb": "text_font",
- "text_font_size_raw": "text_font_size",
- }
- self.rawdata_list = list(self.rawdata2token.keys())
- self.rawdata_out_format = {
- "text_font_emb": "token",
- "text_font_size_raw": "raw",
- }
-
- self.prediction_token_list = [
- "text_font_emb",
- ]
- if cluster_num_dict is None:
- cluster_num_dict = default_cluster_num_dict
- if load_cluster is True:
- for prefix in self.prefix_list:
- cluster_num = cluster_num_dict[prefix]
- fn = f"{data_dir}/cluster/{prefix}_{cluster_num}.pkl"
- if prefix == "text_font_color":
- cluster = np.array(pickle.load(open(fn, "rb")))
- else:
- cluster = np.array(pickle.load(open(fn, "rb"))).flatten()
- setattr(self, f"{prefix}_cluster", cluster)
-
- def assign_label(self, val: Union[float, int], bins: np.array) -> int:
- label = int(np.argmin(np.square(bins - val)))
- return label
-
- def assign_color_label(self, val: Union[float, int], bins: np.array) -> int:
- val = np.tile(np.array(val)[np.newaxis, :], (len(bins), 1))
- d = np.square(bins - val).sum(1)
- label = int(np.argmin(d))
- return label
-
- def tokenize(self, prefix: str, val: Union[float, int]) -> int:
- if prefix == "text_font_color":
- label = self.assign_color_label(val, getattr(self, f"{prefix}_cluster"))
- else:
- label = self.assign_label(val, getattr(self, f"{prefix}_cluster"))
- return label
-
- def detokenize(self, prefix: str, label: Union[int, float]) -> Union[Tuple, float]:
- if prefix == "text_font_color":
- b, g, r = getattr(self, f"{prefix}_cluster")[int(label)]
- return (r, g, b)
- elif prefix == "text_font_size_raw":
- val = float(getattr(self, f"text_font_size_cluster")[int(label)])
- else:
- val = float(getattr(self, f"{prefix}_cluster")[int(label)])
- return val
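-
-
-def _demo_assign_label() -> None:
- # Hypothetical usage sketch, not part of the original pipeline: the fake
- # cluster centres below stand in for the ones normally loaded from disk.
- t = Tokenizer(data_dir="", load_cluster=False)
- bins = np.array([8.0, 16.0, 32.0])
- assert t.assign_label(14.0, bins) == 1  # 14 snaps to the nearest centre, 16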
diff --git a/src/typography_generation/tools/train.py b/src/typography_generation/tools/train.py
deleted file mode 100644
index c7e18b9..0000000
--- a/src/typography_generation/tools/train.py
+++ /dev/null
@@ -1,293 +0,0 @@
-import gc
-import os
-import re
-import time
-from typing import Any, List, Tuple, Union
-
-import datasets
-import torch
-import torch.nn as nn
-import torch.utils.data
-from logzero import logger
-from torch.utils.data._utils.collate import default_collate
-from torch.utils.tensorboard import SummaryWriter
-
-from typography_generation.config.attribute_config import (
- TextElementContextPredictionAttributeConfig,
-)
-from typography_generation.io.data_loader import CrelloLoader
-from typography_generation.io.data_object import (
- ModelInput,
- PrefixListObject,
-)
-from typography_generation.tools.loss import LossFunc
-from typography_generation.tools.prediction_recorder import PredictionRecoder
-from typography_generation.tools.tokenizer import Tokenizer
-
-
-############################################################
-# DataParallel_withLoss
-############################################################
-class FullModel(nn.Module):
- def __init__(self, model: nn.Module, loss: LossFunc, gpu: bool) -> None:
- super(FullModel, self).__init__()
- self.model = model
- self.loss = loss
- self.gpu = gpu
-
- def forward(
- self,
- design_context_list: List,
- model_input_batchdata: List,
- ) -> Tuple:
- model_inputs = ModelInput(design_context_list, model_input_batchdata, self.gpu)
- outputs = self.model(model_inputs)
- total_loss, record_items = self.loss(model_inputs, outputs, self.training)
- return (outputs, torch.unsqueeze(total_loss, 0), record_items)
-
- def update(self, epoch: int, epochs: int, step: int, steps: int) -> None:
- if self.model.model_name == "canvasvae":
- self.loss.update_vae_weight(epoch, epochs, step, steps)
-
-
-def collate_batch(batch: Tuple[Any, List, str]) -> Tuple:
- design_contexts_list = []
- input_batch = []
- svg_id_list = []
- index_list = []
- for design_contexts, model_input_list, svg_id, index in batch:
- design_contexts_list.append(design_contexts)
- input_batch.append(model_input_list)
- svg_id_list.append(svg_id)
- index_list.append(index)
- input_batch = default_collate(input_batch)
- return design_contexts_list, input_batch, svg_id_list, index_list
-
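-def _demo_collate_batch() -> None:
- # Hypothetical usage sketch, not part of the original pipeline: each dataset
- # item is a (design_context, model_input_list, svg_id, index) tuple; tensors
- # inside model_input_list are stacked along a new batch dimension.
- item = ("ctx", [torch.zeros(2)], "svg_0", 0)
- contexts, inputs, svg_ids, indices = collate_batch([item, item])
- assert inputs[0].shape == (2, 2) and svg_ids == ["svg_0", "svg_0"]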
-
-OPTIMIZER_DICT = {
- "adam": (torch.optim.AdamW, {"betas": (0.5, 0.999)}),
- "sgd": (torch.optim.SGD, {}),
-}
-
-
-############################################################
-# Trainer
-############################################################
-class Trainer:
- def __init__(
- self,
- model: nn.Module,
- gpu: bool,
- save_dir: str,
- dataset: CrelloLoader,
- dataset_val: CrelloLoader,
- prefix_list_object: PrefixListObject,
- prediction_config_element: TextElementContextPredictionAttributeConfig,
- epochs: int = 31,
- save_epoch: int = 5,
- batch_size: int = 32,
- num_worker: int = 2,
- learning_rate: float = 0.0002,
- weight_decay: float = 0.01,
- optimizer_option: str = "adam",
- show_interval: int = 100,
- train_only: bool = False,
- debug: bool = False,
- ) -> None:
- self.gpu = gpu
- self.save_dir = save_dir
- self.epochs = epochs
- self.save_epoch = save_epoch
- self.batch_size = batch_size
- self.num_worker = num_worker
- self.show_interval = show_interval
- self.train_only = train_only
- self.debug = debug
- self.dataset = dataset
- self.dataset_val = dataset_val
- self.prefix_list_target = prefix_list_object.target
-
- layer_regex = {
- "lr1": r"(emb.*)|(enc.*)|(lf.*)|(dec.*)|(head.*)",
- }
- self.epoch = 0
- # model.emb.emb_canvas.load_resnet_weight(data_dir)
- # model.emb.emb_element.load_resnet_weight(data_dir)
- param = [
- p
- for name, p in model.named_parameters()
- if bool(re.fullmatch(layer_regex["lr1"], name))
- ]
- param_name = [
- name
- for name, _ in model.named_parameters()
- if bool(re.fullmatch(layer_regex["lr1"], name))
- ]
- lossfunc = LossFunc(
- model.model_name,
- self.prefix_list_target,
- prediction_config_element,
- gpu,
- topk=1,
- )
- logger.info(optimizer_option)
- optimizer_func, optimizer_kwarg = OPTIMIZER_DICT[optimizer_option]
- optimizer_kwarg["lr"] = learning_rate
- optimizer_kwarg["weight_decay"] = weight_decay
- self.optimizer = optimizer_func(param, **optimizer_kwarg)
- self.fullmodel = FullModel(model, lossfunc, gpu)
- self.writer = SummaryWriter(os.path.join(save_dir, "tensorboard"))
- if gpu is True:
- logger.info("use gpu")
- logger.info(f"torch.cuda.is_available() {torch.cuda.is_available()}")
- self.fullmodel.cuda()
- logger.info("model to cuda")
-
- def train_model(self) -> None:
- # Data generators
- dataloader = torch.utils.data.DataLoader(
- self.dataset,
- batch_size=self.batch_size,
- shuffle=True,
- num_workers=self.num_worker,
- pin_memory=True,
- collate_fn=collate_batch,
- )
- self.fullmodel.train()
- self.pr_train = PredictionRecoder(self.prefix_list_target)
- self.pr_val = PredictionRecoder(self.prefix_list_target)
- self.epoch = 0
- self.iter_count_train = 0
- self.iter_count_val = 0
- for epoch in range(0, self.epochs):
- logger.info("Epoch {}/{}.".format(epoch, self.epochs))
- self.epoch = epoch
- # Training
- self.pr_train.reset()
- logger.info("training")
- self.train_epoch(dataloader)
- self.pr_train.step_epoch()
- torch.cuda.empty_cache()
- gc.collect()
- logger.info("validation")
- self.pr_val.reset()
- if self.train_only is False:
- self.val_model()
- self.pr_val.step_epoch()
- if epoch % self.save_epoch == 0:
- torch.save(
- self.fullmodel.model.state_dict(),
- os.path.join(self.save_dir, "model.pth".format(epoch)),
- )
- torch.cuda.empty_cache()
- gc.collect()
- logger.info("training finished")
- logger.info("show data")
- self.pr_train.show_history_scores()
- self.pr_val.show_history_scores()
-
- def train_epoch(self, dataloader: Any) -> None:
- self.steps = len(dataloader)
- self.fullmodel.train()
- self.step = 0
- self.cnt = 0
- end = time.time()
- for inputs in dataloader:
- logger.debug("load data")
- design_context_list, model_input_batchdata, _, _ = inputs
- logger.debug("train step")
- self.train_step(design_context_list, model_input_batchdata, end)
- end = time.time()
- self.step += 1
- self.fullmodel.update(self.epoch, self.epochs, self.step, self.steps)
- if self.debug is True:
- break
-
- def train_step(
- self,
- design_context_list: List,
- model_input_batchdata: List,
- end: float,
- ) -> None:
- start = time.time()
- logger.debug("model apply")
- _, total_loss, recoder_items = self.fullmodel(
- design_context_list, model_input_batchdata
- )
- logger.debug(f"model apply {time.time()-start}")
- logger.debug("record prediction and gt")
- self.pr_train(recoder_items)
- logger.debug(f"record {time.time()-start}")
- logger.debug("update parameters")
- total_loss = torch.mean(total_loss)
- self.optimizer.zero_grad()
- total_loss.backward()
- self.optimizer.step()
- logger.debug(f"optimize {time.time()-start}")
- forward_time = time.time()
- if self.step % self.show_interval == 0:
- if self.gpu is True:
- torch.cuda.empty_cache()
- data_show = "{}/{}/{}/{}, forward_time: {:.3f} data {:.3f}".format(
- self.epoch,
- self.cnt,
- self.step + 1,
- self.steps,
- forward_time - start,
- (start - end),
- )
- logger.info(data_show)
- data_show = "total_loss: {:.3f}".format(total_loss.item())
- logger.info(data_show)
- score_dict = self.pr_train.compute_score()
- for k, v in score_dict.items():
- self.writer.add_scalar(f"train/{k}", v, self.iter_count_train)
- self.iter_count_train += 1
-
- def val_model(self) -> None:
- # Data generators
- dataloader = torch.utils.data.DataLoader(
- self.dataset_val,
- batch_size=self.batch_size,
- shuffle=False,
- num_workers=self.num_worker,
- pin_memory=True,
- collate_fn=collate_batch,
- )
- with torch.no_grad():
- self.steps = len(dataloader)
- self.step = 0
- self.cnt = 0
- self.fullmodel.eval()
- end = time.time()
- for inputs in dataloader:
- design_context_list, model_input_batchdata, _, _ = inputs
- self.val_step(design_context_list, model_input_batchdata, end)
- end = time.time()
- self.step += 1
-
- def val_step(
- self,
- design_context_list: List,
- model_input_batchdata: List,
- end: float,
- ) -> None:
- start = time.time()
- _, _, recoder_items = self.fullmodel(design_context_list, model_input_batchdata)
- self.pr_val(recoder_items)
- forward_time = time.time()
- if self.step % 40 == 0:
- data_show = "{}/{}/{}, forward_time: {:.3f} data {:.3f}".format(
- self.cnt,
- self.step + 1,
- self.steps,
- forward_time - start,
- (start - end),
- )
- logger.info(data_show)
- score_dict = self.pr_val.compute_score()
- for k, v in score_dict.items():
- self.writer.add_scalar(f"val/{k}", v, self.iter_count_val)
- self.iter_count_val += 1
- torch.cuda.empty_cache()
- gc.collect()
diff --git a/src/typography_generation/visualization/__init__.py b/src/typography_generation/visualization/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/src/typography_generation/visualization/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/src/typography_generation/visualization/renderer.py b/src/typography_generation/visualization/renderer.py
deleted file mode 100644
index 5a7d8ee..0000000
--- a/src/typography_generation/visualization/renderer.py
+++ /dev/null
@@ -1,183 +0,0 @@
-from typing import Any, List, Tuple
-import skia
-import os
-import pickle
-import numpy as np
-
-from typography_generation.visualization.renderer_util import (
- get_color_map,
- get_skia_font,
- get_text_actual_height,
- get_text_actual_width,
- get_text_alpha,
- get_texts,
-)
-
-
-class TextRenderer:
- def __init__(
- self,
- data_dir: str,
- fontlabel2fontname: Any,
- ) -> None:
- self.fontmgr = skia.FontMgr()
- fn = os.path.join(data_dir, "font2ttf.pkl")
- _font2ttf = pickle.load(open(fn, "rb"))
- font2ttf = {}
- for key in _font2ttf.keys():
- tmp = _font2ttf[key].split("/data/dataset/crello/")[1]
- fn = os.path.join(data_dir, tmp)
- font2ttf[key] = fn
- self.font2ttf = font2ttf
- fn = os.path.join(data_dir, "fonttype2fontid_fix.pkl")
- fonttype2fontid = pickle.load(open(fn, "rb"))
-
- self.fontid2fonttype = {}
- for k, v in fonttype2fontid.items():
- self.fontid2fonttype[v] = k
- self.fontlabel2fontname = fontlabel2fontname
-
- def draw_texts(
- self,
- element_data: dict,
- text_ids: List,
- bg: np.array,
- scaleinfo: Tuple,
- ) -> np.array:
- H, W = bg.shape[0], bg.shape[1]
- h_rate, w_rate = scaleinfo
- canvas = bg.copy()
- for text_id in text_ids:
- font_label = element_data["font"][text_id]
- font_name = self.fontlabel2fontname(int(font_label))
- font_name = font_name.replace(" ", "_")
- texts = get_texts(element_data, text_id)
- font, _ = get_skia_font(
- self.font2ttf,
- self.fontmgr,
- element_data,
- text_id,
- font_name,
- h_rate,
- )
- text_alpha = get_text_alpha(
- element_data,
- text_id,
- texts,
- font,
- H,
- W,
- w_rate,
- )
- text_rgb_map = get_color_map(element_data, text_id, H, W)
- canvas = canvas * (1 - text_alpha) + text_alpha * text_rgb_map
- return canvas
-
- def get_text_alpha_list(
- self,
- element_data: dict,
- text_ids: List,
- image_size: Tuple[int, int],
- scaleinfo: Tuple[float, float],
- ) -> List:
- H, W = image_size
- h_rate, w_rate = scaleinfo
- text_alpha_list = []
- for text_id in text_ids:
- font_label = element_data["font"][text_id]
- font_name = self.fontlabel2fontname(int(font_label))
- font_name = font_name.replace(" ", "_")
- texts = get_texts(element_data, text_id)
- font, _ = get_skia_font(
- self.font2ttf,
- self.fontmgr,
- element_data,
- text_id,
- font_name,
- h_rate,
- )
- text_alpha = get_text_alpha(
- element_data,
- text_id,
- texts,
- font,
- H,
- W,
- w_rate,
- )
- text_alpha_list.append(text_alpha)
- return text_alpha_list
-
- def get_text_actual_height_list(
- self, element_data: dict, text_ids: List, scaleinfo: Tuple[float, float]
- ) -> List:
- text_actual_height_list = []
- h_rate, _ = scaleinfo
- for text_id in text_ids:
- font_label = element_data["font"][text_id]
- font_name = self.fontlabel2fontname(int(font_label))
- font_name = font_name.replace(" ", "_")
- font, _ = get_skia_font(
- self.font2ttf,
- self.fontmgr,
- element_data,
- text_id,
- font_name,
- h_rate,
- )
- text_actual_height = get_text_actual_height(font)
- text_actual_height_list.append(text_actual_height)
- return text_actual_height_list
-
- def compute_and_set_text_actual_width(
- self,
- element_data: dict,
- text_ids: List,
- scaleinfo: Tuple,
- ) -> None:
- h_rate, w_rate = scaleinfo
- text_actual_width = {}
- for text_id in text_ids:
- font_label = element_data["font"][text_id]
- font_name = self.fontlabel2fontname(int(font_label))
- font_name = font_name.replace(" ", "_")
- texts = get_texts(element_data, text_id)
- font, _ = get_skia_font(
- self.font2ttf,
- self.fontmgr,
- element_data,
- text_id,
- font_name,
- h_rate,
- )
- _text_actual_width = get_text_actual_width(
- element_data, text_id, texts, font, w_rate
- )
- text_actual_width[text_id] = _text_actual_width
- element_data["text_actual_width"] = text_actual_width
-
- def compute_and_set_text_center(
- self, element_data: dict, text_ids: List, image_size: Tuple[int, int]
- ) -> None:
- H, W = image_size
- text_center_x = {}
- text_center_y = {}
- for text_id in text_ids:
- left = element_data["left"][text_id] * W
- w = element_data["width"][text_id] * W
- textAlign = element_data["text_align"][text_id]
- right = left + w
- actual_w = element_data["text_actual_width"][text_id]
- if textAlign == 1:
- _text_center_x = (left + right) / 2.0
- elif textAlign == 3:
- _text_center_x = right - actual_w / 2.0
- elif textAlign == 2:
- _text_center_x = left + actual_w / 2.0
- else:  # unknown align code: fall back to the box centre
- _text_center_x = (left + right) / 2.0
- text_height = element_data["height"][text_id] * H
- top = element_data["top"][text_id] * H
- _text_center_y = top + text_height / 2.0
- text_center_y[text_id] = _text_center_y
- text_center_x[text_id] = _text_center_x
- element_data["text_center_y"] = text_center_y
- element_data["text_center_x"] = text_center_x
diff --git a/src/typography_generation/visualization/renderer_util.py b/src/typography_generation/visualization/renderer_util.py
deleted file mode 100644
index a4e6504..0000000
--- a/src/typography_generation/visualization/renderer_util.py
+++ /dev/null
@@ -1,225 +0,0 @@
-import html
-import math
-import os
-from typing import Any, Dict, List, Tuple, Union
-
-import numpy as np
-import skia
-
-
-def capitalize_text(text: str, capitalize: int) -> str:
- text_tmp = ""
- for character in text:
- if capitalize == 1:
- character = character.upper()
- text_tmp += character
- return text_tmp
-
-
-def text_align(
- textAlign: int, left: float, center: float, right: float, text_alpha_width: float
-) -> float:
- if textAlign == 1:
- x = center - text_alpha_width / 2.0
- elif textAlign == 3:
- x = right - text_alpha_width
- elif textAlign == 2:
- x = left
- else:  # unknown align code: fall back to left alignment
- x = left
- return x
-
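-def _demo_text_align() -> None:
- # Hypothetical usage sketch, not part of the original pipeline: as
- # implemented above, align code 1 centres the text, 2 is left, 3 is right.
- assert text_align(1, left=0.0, center=50.0, right=100.0, text_alpha_width=20.0) == 40.0
- assert text_align(3, 0.0, 50.0, 100.0, 20.0) == 80.0
- assert text_align(2, 0.0, 50.0, 100.0, 20.0) == 0.0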
-
-def get_text_location_info(font: skia.Font, text_tmp: str) -> Tuple:
- glyphs = font.textToGlyphs(text_tmp)
- positions = font.getPos(glyphs)
- rects = font.getBounds(glyphs)
- return glyphs, positions, rects
-
-
-def compute_text_alpha_width(
- positions: List, rects: List, letterSpacing: float
-) -> Union[float, Any]:
- twidth = positions[-1].x() + rects[-1].right()
- if letterSpacing is not None:
- twidth += letterSpacing * (len(rects) - 1)
- return twidth
-
-
-def add_letter_margin(x: float, letterSpacing: float) -> float:
- if letterSpacing is not None:
- x = x + letterSpacing
- return x
-
-
-def get_text_actual_height(
- font: skia.Font,
-) -> float:
- ascender = -1 * font.getMetrics().fAscent
- descender = font.getMetrics().fDescent
- leading = font.getMetrics().fLeading
- text_height = ascender + descender + leading
- return text_height
-
-
-def get_text_alpha(
- element_data: Any,
- text_index: int,
- texts: List,
- font: skia.Font,
- H: int,
- W: int,
- w_rate: float,
-) -> np.array:
- center_y = element_data["text_center_y"][text_index] * H
- center_x = element_data["text_center_x"][text_index] * W
- text_width = get_text_actual_width(element_data, text_index, texts, font, w_rate)
- left = center_x - text_width / 2.0
- right = center_x + text_width / 2.0
- ascender = -1 * font.getMetrics().fAscent
- descender = font.getMetrics().fDescent
- leading = font.getMetrics().fLeading
- line_height_scale = element_data["line_height"][text_index]
- line_height = (ascender + descender + leading) * line_height_scale
- surface = skia.Surface(W, H)
- canvas = surface.getCanvas()
- fill_paint = skia.Paint(
- AntiAlias=True,
- Color=skia.ColorSetRGB(255, 0, 0),
- Style=skia.Paint.kFill_Style,
- )
- fill_paint.setBlendMode(skia.BlendMode.kSrcOver)
- y = center_y - line_height * len(texts) / 2.0
- for text in texts:
- text = html.unescape(text)
- text = capitalize_text(text, element_data["capitalize"][text_index])
- _, positions, rects = get_text_location_info(font, text)
- if len(positions) == 0:
- continue
- text_alpha_width = compute_text_alpha_width(
- positions, rects, element_data["letter_spacing"][text_index] * w_rate
- )
- # print(text,element_data["letter_spacing"][text_index],element_data["capitalize"][text_index])
- angle = float(element_data["angle"][text_index]) * 180 / math.pi
- x = text_align(
- element_data["text_align"][text_index],
- left,
- center_x,
- right,
- text_alpha_width,
- )
- canvas.rotate(angle, center_x, center_y)
- for i, character in enumerate(text):
- ydp = np.round(y + positions[i].y() + ascender)
- xdp = np.round(x + positions[i].x())
- textblob = skia.TextBlob(character, font)
- canvas.drawTextBlob(textblob, xdp, ydp, fill_paint)
- x = add_letter_margin(
- x, element_data["letter_spacing"][text_index] * w_rate
- )
- canvas.rotate(-1 * angle, center_x, center_y)  # undo the rotation about the same pivot
- y += line_height
- text_alpha = surface.makeImageSnapshot().toarray()[:, :, 0]
- text_alpha = text_alpha / 255.0
- text_alpha = np.tile(text_alpha[:, :, np.newaxis], (1, 1, 3))
- return np.minimum(text_alpha, np.zeros_like(text_alpha) + 1)
-
-
-def get_text_actual_width(
- element_data: Any,
- text_index: int,
- texts: List,
- font: skia.Font,
- w_rate: float,
-) -> np.array:
- text_alpha_width = 0.0
- for text in texts:
- text = html.unescape(text)
- text = capitalize_text(text, element_data["capitalize"][text_index])
- _, positions, rects = get_text_location_info(font, text)
- if len(positions) == 0:
- continue
- _text_alpha_width = compute_text_alpha_width(
- positions, rects, element_data["letter_spacing"][text_index] * w_rate
- )
- text_alpha_width = max(text_alpha_width, _text_alpha_width)
- return text_alpha_width
-
-
-def font_name_fix(font_name: str) -> str:
- if font_name == "Exo_2":
- font_name = "Exo\_2"
- if font_name == "Press_Start_2P":
- font_name = "Press_Start\_2P"
- if font_name == "quattrocento":
- font_name = "Quattrocento"
- if font_name == "yellowtail":
- font_name = "Yellowtail"
- if font_name == "sunday":
- font_name = "Sunday"
- if font_name == "bebas_neue":
- font_name = "Bebas_Neue"
- if font_name == "Brusher":
- font_name = "Brusher_Regular"
- if font_name == "Amatic_Sc":
- font_name = "Amatic_SC"
- if font_name == "Pt_Sans":
- font_name = "PT_Sans"
- if font_name == "Old_Standard_Tt":
- font_name = "Old_Standard_TT"
- if font_name == "Eb_Garamond":
- font_name = "EB_Garamond"
- if font_name == "Gfs_Didot":
- font_name = "GFS_Didot"
- if font_name == "Im_Fell":
- font_name = "IM_Fell"
- if font_name == "Im_Fell_Dw_Pica_Sc":
- font_name = "IM_Fell_DW_Pica_SC"
- if font_name == "Marcellus_Sc":
- font_name = "Marcellus_SC"
- return font_name
-
-
-def get_skia_font(
- font2ttf: dict,
- fontmgr: skia.FontMgr,
- element_data: Dict,
- targetid: int,
- font_name: str,
- scale_h: float,
- font_scale: float = 1.0,
-) -> Tuple:
- font_name = font_name_fix(font_name)
- if font_name in font2ttf:
- ttf = font2ttf[font_name]
- ft = fontmgr.makeFromFile(ttf, 0)
- font = skia.Font(
- ft, element_data["font_size"][targetid] * scale_h, font_scale, 1e-20
- )
-
- return font, font_name
- else:
- # unknown font: fall back to the default typeface so callers always
- # receive a usable skia.Font
- ft = fontmgr.makeFromFile("", 0)
- font = skia.Font(
- ft, element_data["font_size"][targetid] * scale_h, font_scale, 1e-20
- )
- return font, None
-
-
-def get_color_map(element_data: Any, targetid: int, H: int, W: int) -> np.array:
- B, G, R = element_data["color"][targetid]
- rgb_map = np.zeros((H, W, 3), dtype=np.uint8)
- rgb_map[:, :, 0] = B
- rgb_map[:, :, 1] = G
- rgb_map[:, :, 2] = R
- return rgb_map
-
-
-def get_texts(element_data: Dict, target_id: int) -> List:
- text = html.unescape(element_data["text"][target_id])
- _texts = text.split(os.linesep)
- texts = []
- for t in _texts:
- if t == "":
- pass
- else:
- texts.append(t)
- return texts
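-
-
-def _demo_get_texts() -> None:
- # Hypothetical usage sketch, not part of the original pipeline: the raw text
- # is HTML-unescaped and split into non-empty lines.
- element_data = {"text": ["Hello&amp;World" + os.linesep + os.linesep + "Bye"]}
- assert get_texts(element_data, 0) == ["Hello&World", "Bye"]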
diff --git a/src/typography_generation/visualization/visualizer.py b/src/typography_generation/visualization/visualizer.py
deleted file mode 100644
index 1ef6d6c..0000000
--- a/src/typography_generation/visualization/visualizer.py
+++ /dev/null
@@ -1,190 +0,0 @@
-import copy
-from typing import Dict, List, Tuple
-
-import numpy as np
-from typography_generation.tools.denormalizer import Denormalizer
-from typography_generation.tools.tokenizer import Tokenizer
-from typography_generation.visualization.renderer import TextRenderer
-
-crelloattstr2pkgattstr = {
- "text_font": "font",
- "text_font_color": "color",
- "text_align_type": "text_align",
- "text_capitalize": "capitalize",
- "text_font_size": "font_size",
- "text_font_size_raw": "font_size",
- "text_angle": "angle",
- "text_letter_spacing": "letter_spacing",
- "text_line_height_scale": "line_height",
- "text_center_y": "text_center_y",
- "text_center_x": "text_center_x",
-}
-
-
-def get_text_ids(element_data: Dict) -> List:
- text_ids = []
- for k in range(len(element_data["text"])):
- if element_data["text"][k] == "":
- pass
- else:
- text_ids.append(k)
- return text_ids
-
-
-def replace_style_data_by_prediction(
- prediction: Dict, element_data: Dict, text_ids: List
-) -> Dict:
- element_data = copy.deepcopy(element_data)
- for prefix_pred, prefix_vec in crelloattstr2pkgattstr.items():
- if prefix_pred in prediction.keys():
- for i, t in enumerate(text_ids):
- element_data[prefix_vec][t] = prediction[prefix_pred][i][0]
- return element_data
-
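-def _demo_replace_style_data_by_prediction() -> None:
- # Hypothetical usage sketch, not part of the original pipeline: predictions
- # are ranked lists, so the top-1 candidate ([i][0]) replaces the stored
- # attribute; the input dict is left untouched thanks to the deepcopy.
- element_data = {"font": [1, 2], "text": ["a", "b"]}
- prediction = {"text_font": [[9], [9]]}
- out = replace_style_data_by_prediction(prediction, element_data, [0, 1])
- assert out["font"] == [9, 9] and element_data["font"] == [1, 2]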
-
-def get_ordered_text_ids(element_data: Dict, order_list: List) -> List:
- text_ids = []
- for i in order_list:
- if element_data["text"][i] == "":
- pass
- else:
- text_ids.append(i)
- return text_ids
-
-
-def visualize_prediction(
- renderer: TextRenderer,
- element_data: Dict,
- prediction: Dict,
- bg_img: np.array,
-) -> np.array:
- order_list = element_data["order_list"]
- scaleinfo = element_data["scale_box"]
- text_ids = get_ordered_text_ids(element_data, order_list)
- element_data = replace_style_data_by_prediction(prediction, element_data, text_ids)
- img = renderer.draw_texts(element_data, text_ids, np.array(bg_img), scaleinfo)
- return img
-
-
-def get_predicted_alphamaps(
- renderer: TextRenderer,
- element_data: Dict,
- prediction: Dict,
- image_size: Tuple[int, int],
- order_list: List = None,
-) -> List:
- if order_list is None:
- order_list = element_data["order_list"]
- scaleinfo = element_data["scale_box"]
- text_ids = get_ordered_text_ids(element_data, order_list)
- element_data = replace_style_data_by_prediction(prediction, element_data, text_ids)
- alpha_list = renderer.get_text_alpha_list(
- element_data, text_ids, image_size, scaleinfo
- )
- return alpha_list
-
-
-def get_element_alphamaps(
- renderer: TextRenderer,
- element_data: Dict,
-) -> List:
- order_list = element_data["order_list"]
- scaleinfo = element_data["scale_box"]
- image_size = element_data["canvas_bg_size"]
- text_ids = get_ordered_text_ids(element_data, order_list)
- alpha_list = renderer.get_text_alpha_list(
- element_data, text_ids, image_size, scaleinfo
- )
- return alpha_list
-
-
-def visualize_data(
- renderer: TextRenderer,
- element_data: Dict,
- bg_img: np.array,
-) -> np.array:
- text_ids = get_text_ids(element_data)
- scaleinfo = element_data["scale_box"]
- img = renderer.draw_texts(element_data, text_ids, np.array(bg_img), scaleinfo)
- return img
-
-
-def tokenize(
- _element_data: Dict,
- tokenizer: Tokenizer,
- denormalizer: Denormalizer,
- text_ids: List,
- bg_img: np.array,
- scaleinfo: Tuple,
-) -> Dict:
- element_data = copy.deepcopy(_element_data)
- h, w = bg_img.size[1], bg_img.size[0]
- for prefix_pred, prefix_vec in crelloattstr2pkgattstr.items():
- for i, t in enumerate(text_ids):
- data_info = {
- "element_data": _element_data,
- "text_index": t,
- "img_size": (h, w),
- "scaleinfo": scaleinfo,
- "text_actual_width": _element_data["text_actual_width"][t],
- "text": None,
- }
- data = getattr(denormalizer.dataset, f"get_{prefix_pred}")(**data_info)
- if prefix_pred in denormalizer.dataset.tokenizer.prefix_list:
- data = tokenizer.tokenize(prefix_pred, data)
- data = tokenizer.detokenize(prefix_pred, data)
- data = denormalizer.denormalize_elm(
- prefix_pred, data, h, w, scaleinfo[0], scaleinfo[1]
- )
- element_data[prefix_vec][t] = data
- return element_data
-
-
-def visualize_tokenization(
- renderer: TextRenderer,
- tokenizer: Tokenizer,
- denormalizer: Denormalizer,
- element_data: Dict,
- bg_img: np.array,
-) -> np.array:
- text_ids = get_text_ids(element_data)
- scaleinfo = element_data["scale_box"]
- element_data = tokenize(
- element_data, tokenizer, denormalizer, text_ids, bg_img, scaleinfo
- )
- img = renderer.draw_texts(element_data, text_ids, np.array(bg_img), scaleinfo)
- return img
-
-
-def get_text_coords(element_data: Dict, text_index: int, img_size: Tuple) -> Tuple:
- h, w = img_size
- top = int(element_data["top"][text_index] * h)
- left = int(element_data["left"][text_index] * w)
- height = int(element_data["height"][text_index] * h)
- width = int(element_data["width"][text_index] * w)
- return top, left, top + height, left + width
-
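-def _demo_get_text_coords() -> None:
- # Hypothetical usage sketch, not part of the original pipeline: stored
- # geometry is normalised, so it is scaled by the image size to pixel
- # coordinates (top, left, bottom, right).
- element_data = {"top": [0.5], "left": [0.25], "height": [0.25], "width": [0.5]}
- assert get_text_coords(element_data, 0, (100, 200)) == (50, 50, 75, 150)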
-
-def colorize_text(
- element_data: Dict,
- canvas: np.array,
- text_index: int,
- color: Tuple = (255, 0, 0),
- w: float = 0.5,
-) -> np.array:
- order_list = element_data["order_list"]
- text_ids = get_ordered_text_ids(element_data, order_list)
- y0, x0, y1, x1 = get_text_coords(
- element_data, text_ids[text_index], canvas.shape[:2]
- )
- tmp = canvas.copy()
- tmp[y0:y1, x0:x1, :] = np.array(color)
- canvas[y0:y1, x0:x1, :] = (
- w * canvas[y0:y1, x0:x1, :] + (1 - w) * tmp[y0:y1, x0:x1, :]
- )
- return canvas
diff --git a/tests/__init__.py b/tests/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/tests/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/tests/conftest.py b/tests/conftest.py
deleted file mode 100644
index 4254d4d..0000000
--- a/tests/conftest.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import pytest
-import torch
-
-from typography_generation.config.config_args_util import (
- get_global_config,
- get_model_config_input,
- get_prefix_lists,
-)
-from typography_generation.config.default import get_font_config
-from typography_generation.io.build_dataset import (
- build_test_dataset,
- build_train_dataset,
-)
-from typography_generation.model.model import create_model
-from typography_generation.tools.train import collate_batch
-
-
-@pytest.fixture
-def bartconfig():
- data_dir = "data"
- config_name = "bart"
- bartconfig = get_global_config(data_dir, config_name)
- return bartconfig
-
-
-@pytest.fixture
-def bartconfigdataset_test(bartconfig):
- data_dir = "data"
- prefix_list_object = get_prefix_lists(bartconfig)
- font_config = get_font_config(bartconfig)
-
- bartconfigdataset_test = build_test_dataset(
- data_dir,
- prefix_list_object,
- font_config,
- debug=True,
- )
- return bartconfigdataset_test
-
-
-@pytest.fixture
-def bartconfigdataset(bartconfig):
- data_dir = "data"
- prefix_list_object = get_prefix_lists(bartconfig)
- font_config = get_font_config(bartconfig)
-
- bartconfigdataset, _ = build_train_dataset(
- data_dir,
- prefix_list_object,
- font_config,
- debug=True,
- )
- return bartconfigdataset
-
-
-@pytest.fixture
-def bartconfigdataset_val(bartconfig):
- data_dir = "data"
- prefix_list_object = get_prefix_lists(bartconfig)
- font_config = get_font_config(bartconfig)
-
- _, bartconfigdataset_val = build_train_dataset(
- data_dir,
- prefix_list_object,
- font_config,
- debug=True,
- )
- return bartconfigdataset_val
-
-
-@pytest.fixture
-def bartconfigdataloader(bartconfigdataset):
- bartconfigdataloader = torch.utils.data.DataLoader(
- bartconfigdataset,
- batch_size=2,
- shuffle=False,
- num_workers=1,
- pin_memory=True,
- collate_fn=collate_batch,
- )
- return bartconfigdataloader
-
-
-@pytest.fixture
-def bartmodel(bartconfig):
- model_name, model_kwargs = get_model_config_input(bartconfig)
-
- bartmodel = create_model(
- model_name,
- **model_kwargs,
- )
- return bartmodel
diff --git a/tests/io/__init__.py b/tests/io/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/tests/io/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/tests/io/test_data_loader.py b/tests/io/test_data_loader.py
deleted file mode 100644
index 3967126..0000000
--- a/tests/io/test_data_loader.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import pytest
-import torch
-from typography_generation.__main__ import (
- get_global_config,
- get_prefix_lists,
-)
-from typography_generation.config.default import (
- get_font_config,
-)
-from typography_generation.io.build_dataset import build_test_dataset
-
-from typography_generation.tools.train import collate_batch
-
-params = {"normal 1": ("data", "bart", 0), "normal 2": ("data", "bart", 1)}
-
-
-@pytest.mark.parametrize("data_dir, config_name, index", list(params.values()))
-def test_get_item(data_dir: str, config_name: str, index: int) -> None:
- config = get_global_config(data_dir, config_name)
- prefix_list_object = get_prefix_lists(config)
- font_config = get_font_config(config)
-
- dataset = build_test_dataset(
- data_dir,
- prefix_list_object,
- font_config,
- debug=True,
- )
- dataset.__getitem__(index)
-
-
-@pytest.mark.parametrize("data_dir, config_name", [["data", "bart"]])
-def test_dataloader_iteration(data_dir: str, config_name: str) -> None:
- config = get_global_config(data_dir, config_name)
- prefix_list_object = get_prefix_lists(config)
- font_config = get_font_config(config)
-
- dataset = build_test_dataset(
- data_dir,
- prefix_list_object,
- font_config,
- debug=True,
- )
-
- dataloader = torch.utils.data.DataLoader(
- dataset,
- batch_size=1,
- shuffle=False,
- num_workers=1,
- pin_memory=True,
- collate_fn=collate_batch,
- )
- batches = iter(dataloader)
- next(batches)
- next(batches)
diff --git a/tests/model/__init__.py b/tests/model/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/tests/model/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/tests/model/test_model.py b/tests/model/test_model.py
deleted file mode 100644
index 70d3298..0000000
--- a/tests/model/test_model.py
+++ /dev/null
@@ -1,49 +0,0 @@
-import pytest
-import torch
-from typography_generation.__main__ import (
- get_global_config,
- get_model_config_input,
- get_prefix_lists,
-)
-from typography_generation.config.default import (
- get_font_config,
-)
-from typography_generation.io.build_dataset import build_test_dataset
-from typography_generation.io.data_object import ModelInput
-from typography_generation.model.model import create_model
-from typography_generation.tools.train import collate_batch
-
-
-@pytest.mark.parametrize(
- "data_dir, config_name",
- [
- ["data", "bart"],
- # ["crello", "mlp"],
- # ["crello", "canvasvae"],
- # ["crello", "mfc"],
- ],
-)
-def test_model(data_dir: str, config_name: str) -> None:
- config = get_global_config(data_dir, config_name)
- model_name, model_kwargs = get_model_config_input(config)
-
- model = create_model(
- model_name,
- **model_kwargs,
- )
- prefix_list_object = get_prefix_lists(config)
- font_config = get_font_config(config)
- dataset = build_test_dataset(
- data_dir,
- prefix_list_object,
- font_config,
- debug=True,
- )
-
- dataloader = torch.utils.data.DataLoader(
- dataset,
- batch_size=2,
- shuffle=False,
- num_workers=1,
- pin_memory=True,
- collate_fn=collate_batch,
- )
- design_context_list, model_input_batchdata, _, _ = next(iter(dataloader))
- model_inputs = ModelInput(design_context_list, model_input_batchdata, gpu=False)
- model(model_inputs)
diff --git a/tests/test_main.py b/tests/test_main.py
deleted file mode 100644
index 47fda27..0000000
--- a/tests/test_main.py
+++ /dev/null
@@ -1,106 +0,0 @@
-import datasets
-import pytest
-from logzero import logger
-import os
-from typography_generation.__main__ import (
- get_global_config,
- get_model_config_input,
- get_prefix_lists,
-)
-from typography_generation.config.default import (
- get_datapreprocess_config,
- get_font_config,
- get_model_input_prefix_list,
-)
-from typography_generation.io.build_dataset import build_train_dataset
-from typography_generation.io.data_loader import CrelloLoader
-from typography_generation.model.model import create_model
-from typography_generation.tools.tokenizer import Tokenizer
-from typography_generation.tools.train import Trainer
-
-
-@pytest.mark.parametrize("data_dir, config_name", [["crello", "bart"]])
-def test_model_creation(data_dir: str, config_name: str) -> None:
- config = get_global_config(data_dir, config_name)
- model_name, model_kwargs = get_model_config_input(config)
-
- create_model(
- model_name,
- **model_kwargs,
- )
-
-
-@pytest.mark.parametrize("data_dir, config_name", [["crello", "bart"]])
-def test_loader_creation(data_dir: str, config_name: str) -> None:
- config = get_global_config(data_dir, config_name)
- prefix_list_model_input = get_model_input_prefix_list(config)
- font_config = get_font_config(config)
- datapreprocess_config = get_datapreprocess_config(config)
- hugging_dataset = datasets.load_from_disk(
- os.path.join(data_dir, "extended_dataset", "map_features.hf")
- )
- tokenizer = Tokenizer(data_dir)
- CrelloLoader(
- data_dir,
- tokenizer,
- hugging_dataset,
- prefix_list_model_input,
- font_config,
- datapreprocess_config,
- train=True,
- dataset_division="train",
- )
-
-
-@pytest.mark.parametrize(
- "data_dir, save_dir, config_name",
- [["crello", "job", "bart"]],
-)
-def test_trainer_creation(data_dir: str, save_dir: str, config_name: str) -> None:
- logger.info(config_name)
- config = get_global_config(data_dir, config_name)
- model_name, model_kwargs = get_model_config_input(config)
-
- model = create_model(
- model_name,
- **model_kwargs,
- )
- prefix_list_object = get_prefix_lists(config)
- prediction_config_element = config.text_element_prediction_attribute_config
- font_config = get_font_config(config)
- datapreprocess_config = get_datapreprocess_config(config)
- optimizer_option = config.train_config.optimizer
-
- dataset, dataset_val = build_train_dataset(
- data_dir,
- prefix_list_object,
- font_config,
- datapreprocess_config,
- debug=True,
- )
-
- gpu = False
- Trainer(
- model,
- gpu,
- save_dir,
- dataset,
- dataset_val,
- prefix_list_object,
- prediction_config_element,
- optimizer_option=optimizer_option,
- )
-
-
-@pytest.mark.parametrize(
- "data_dir, config_name",
- [["crello", "bart"]],
-)
-def test_model_config(data_dir: str, config_name: str) -> None:
- logger.info(f"config_name {config_name}")
- config = get_global_config(data_dir, config_name)
- model_name, model_kwargs = get_model_config_input(config)
- create_model(
- model_name,
- **model_kwargs,
- )
diff --git a/tests/tool/__init__.py b/tests/tool/__init__.py
deleted file mode 100644
index 8b13789..0000000
--- a/tests/tool/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/tests/tool/test_eval.py b/tests/tool/test_eval.py
deleted file mode 100644
index fcf9252..0000000
--- a/tests/tool/test_eval.py
+++ /dev/null
@@ -1,95 +0,0 @@
-import pytest
-from typography_generation.__main__ import (
- get_global_config,
- get_model_config_input,
- get_prefix_lists,
-)
-from typography_generation.config.default import (
- get_font_config,
-)
-from typography_generation.io.build_dataset import build_test_dataset
-from typography_generation.model.model import create_model
-from typography_generation.tools.evaluator import Evaluator
-
-
-def test_eval_iter(bartconfig, bartconfigdataset_test, bartmodel) -> None:
- prefix_list_object = get_prefix_lists(bartconfig)
-
- gpu = False
- save_dir = "job"
- evaluator = Evaluator(
- bartmodel,
- gpu,
- save_dir,
- bartconfigdataset_test,
- prefix_list_object,
- debug=True,
- )
- evaluator.eval_model()
-
-
-@pytest.mark.parametrize(
- "data_dir, save_dir, config_name, elementembeddingflagconfigname, canvasembeddingflagconfigname, elementpredictionflagconfigname",
- [
- [
- "data",
- "job",
- "bart",
- "text_element_embedding_flag_config/wofontsize",
- "canvas_embedding_flag_config/canvas_detail_given",
- "text_element_prediction_flag_config/rawfontsize",
- ],
- ],
-)
-def test_flag_config_eval(
- data_dir: str,
- save_dir: str,
- config_name: str,
- elementembeddingflagconfigname: str,
- canvasembeddingflagconfigname: str,
- elementpredictionflagconfigname,
-) -> None:
- global_config_input = {}
- global_config_input["data_dir"] = data_dir
- global_config_input["model_name"] = config_name
- global_config_input["test_config_name"] = "test_config"
- global_config_input["model_config_name"] = "model_config"
- global_config_input[
- "elementembeddingflag_config_name"
- ] = elementembeddingflagconfigname
- global_config_input[
- "canvasembeddingflag_config_name"
- ] = canvasembeddingflagconfigname
- global_config_input[
- "elementpredictionflag_config_name"
- ] = elementpredictionflagconfigname
-
- config = get_global_config(**global_config_input)
- model_name, model_kwargs = get_model_config_input(config)
-
- model = create_model(
- model_name,
- **model_kwargs,
- )
- # import torch
- # model.load_state_dict(torch.load("job/weight.pth", map_location="cpu"))
- prefix_list_object = get_prefix_lists(config)
- font_config = get_font_config(config)
-
- dataset = build_test_dataset(
- data_dir,
- prefix_list_object,
- font_config,
- debug=True,
- )
-
- gpu = False
- evaluator = Evaluator(
- model,
- gpu,
- save_dir,
- dataset,
- prefix_list_object,
- debug=True,
- )
- evaluator.eval_model()
diff --git a/tests/tool/test_sample.py b/tests/tool/test_sample.py
deleted file mode 100644
index 75240c6..0000000
--- a/tests/tool/test_sample.py
+++ /dev/null
@@ -1,90 +0,0 @@
-import pytest
-from typography_generation.__main__ import (
- get_global_config,
- get_model_config_input,
- get_prefix_lists,
- get_sampling_config,
-)
-from typography_generation.config.default import (
- get_font_config,
-)
-from typography_generation.io.build_dataset import build_test_dataset
-from typography_generation.model.model import create_model
-from typography_generation.tools.sampler import Sampler
-
-
-@pytest.mark.parametrize(
- "data_dir, save_dir, config_name, test_config, model_config",
- [
- ["data", "job", "bart", "test_config/topp07", "model_config"],
- ],
-)
-def test_sample_iter(
- bartconfigdataset_test,
- bartmodel,
- data_dir: str,
- save_dir: str,
- config_name: str,
- test_config: str,
- model_config: str,
-) -> None:
- config = get_global_config(
- data_dir,
- config_name,
- test_config_name=test_config,
- model_config_name=model_config,
- )
- prefix_list_object = get_prefix_lists(config)
- sampling_config = get_sampling_config(config)
-
- gpu = False
- sampler = Sampler(
- bartmodel,
- gpu,
- save_dir,
- bartconfigdataset_test,
- prefix_list_object,
- sampling_config,
- debug=True,
- )
- sampler.sample()
-
-
-@pytest.mark.parametrize(
- "data_dir,save_dir,config_name",
- [
- ["data", "job", "canvasvae"],
- ],
-)
-def test_sample_canvasvae(data_dir: str, save_dir: str, config_name: str) -> None:
- config = get_global_config(data_dir, config_name)
- model_name, model_kwargs = get_model_config_input(config)
-
- model = create_model(
- model_name,
- **model_kwargs,
- )
- prefix_list_object = get_prefix_lists(config)
- font_config = get_font_config(config)
- sampling_config = get_sampling_config(config)
-
- dataset = build_test_dataset(
- data_dir,
- prefix_list_object,
- font_config,
- debug=True,
- )
-
- gpu = False
- sampler = Sampler(
- model,
- gpu,
- save_dir,
- dataset,
- prefix_list_object,
- sampling_config,
- debug=True,
- )
- sampler.sample()
diff --git a/tests/tool/test_structure_preserved_sampler.py b/tests/tool/test_structure_preserved_sampler.py
deleted file mode 100644
index 0f4a042..0000000
--- a/tests/tool/test_structure_preserved_sampler.py
+++ /dev/null
@@ -1,25 +0,0 @@
-from typography_generation.__main__ import (
- get_prefix_lists,
- get_sampling_config,
-)
-from typography_generation.tools.structure_preserved_sampler import (
- StructurePreservedSampler,
-)
-
-
-def test_sample_iter(bartconfig, bartconfigdataset_test, bartmodel) -> None:
- prefix_list_object = get_prefix_lists(bartconfig)
- sampling_config = get_sampling_config(bartconfig)
-
- gpu = False
- save_dir = "job"
- sampler = StructurePreservedSampler(
- bartmodel,
- gpu,
- save_dir,
- bartconfigdataset_test,
- prefix_list_object,
- sampling_config,
- debug=True,
- )
- sampler.sample()
diff --git a/tests/tool/test_train.py b/tests/tool/test_train.py
deleted file mode 100644
index 5734b28..0000000
--- a/tests/tool/test_train.py
+++ /dev/null
@@ -1,158 +0,0 @@
-import pytest
-from typography_generation.__main__ import (
- get_global_config,
- get_model_config_input,
- get_prefix_lists,
-)
-from typography_generation.config.default import (
- get_font_config,
-)
-from typography_generation.io.build_dataset import build_train_dataset
-from typography_generation.io.data_object import ModelInput
-from typography_generation.model.model import create_model
-from typography_generation.tools.loss import LossFunc
-from typography_generation.tools.train import Trainer
-
-
-def test_compute_loss(bartconfig, bartconfigdataloader, bartmodel) -> None:
- prefix_list_object = get_prefix_lists(bartconfig)
-
- tmp = bartconfigdataloader.__iter__()
- design_context_list, model_input_batchdata, _, _ = next(tmp)
- model_inputs = ModelInput(design_context_list, model_input_batchdata, gpu=False)
- outputs = bartmodel(model_inputs)
- prediction_config_element = bartconfig.text_element_prediction_attribute_config
- loss = LossFunc(
- bartmodel.model_name,
- prefix_list_object.target,
- prediction_config_element,
- gpu=False,
- )
- loss(model_inputs, outputs, training=True)
-
-
-def test_train_iter(
- bartconfig, bartconfigdataset, bartconfigdataset_val, bartmodel
-) -> None:
- save_dir = "job"
- optimizer_option = bartconfig.train_config.optimizer
- prefix_list_object = get_prefix_lists(bartconfig)
- prediction_config_element = bartconfig.text_element_prediction_attribute_config
- trainer = Trainer(
- bartmodel,
- False,
- save_dir,
- bartconfigdataset,
- bartconfigdataset_val,
- prefix_list_object,
- prediction_config_element,
- optimizer_option=optimizer_option,
- debug=True,
- epochs=1,
- )
- trainer.train_model()
-
-
-@pytest.mark.parametrize(
- "data_dir, config_name",
- [["data", "mfc"], ["data", "canvasvae"]],
-)
-def test_train_config(data_dir: str, config_name: str) -> None:
- config = get_global_config(data_dir, config_name)
- model_name, model_kwargs = get_model_config_input(config)
-
- model = create_model(
- model_name,
- **model_kwargs,
- )
- font_config = get_font_config(config)
- optimizer_option = config.train_config.optimizer
- prefix_list_object = get_prefix_lists(config)
- prediction_config_element = config.text_element_prediction_attribute_config
-
- dataset, dataset_val = build_train_dataset(
- data_dir,
- prefix_list_object,
- font_config,
- debug=True,
- )
-
- save_dir = "job"
- trainer = Trainer(
- model,
- False,
- save_dir,
- dataset,
- dataset_val,
- prefix_list_object,
- prediction_config_element,
- optimizer_option=optimizer_option,
- debug=True,
- epochs=1,
- )
- trainer.train_model()
-
-
-@pytest.mark.parametrize(
- "data_dir, save_dir, config_name, elementembeddingflagconfigname, canvasembeddingflagconfigname, elementpredictionflagconfigname",
- [
- [
- "data",
- "job",
- "bart",
- "text_element_embedding_flag_config/wofontsize",
- "canvas_embedding_flag_config/canvas_detail_given",
- "text_element_prediction_flag_config/rawfontsize",
- ],
- ],
-)
-def test_flag_config_train(
- data_dir: str,
- save_dir: str,
- config_name: str,
- elementembeddingflagconfigname: str,
- canvasembeddingflagconfigname: str,
- elementpredictionflagconfigname: str,
-) -> None:
- global_config_input = {
- "data_dir": data_dir,
- "model_name": config_name,
- "test_config_name": "test_config",
- "model_config_name": "model_config",
- "elementembeddingflag_config_name": elementembeddingflagconfigname,
- "canvasembeddingflag_config_name": canvasembeddingflagconfigname,
- "elementpredictionflag_config_name": elementpredictionflagconfigname,
- }
-
- config = get_global_config(**global_config_input)
- model_name, model_kwargs = get_model_config_input(config)
-
- model = create_model(
- model_name,
- **model_kwargs,
- )
- prefix_list_object = get_prefix_lists(config)
- font_config = get_font_config(config)
- prediction_config_element = config.text_element_prediction_attribute_config
- optimizer_option = config.train_config.optimizer
-
- dataset, dataset_val = build_train_dataset(
- data_dir,
- prefix_list_object,
- font_config,
- debug=True,
- )
-
- trainer = Trainer(
- model,
- False,
- save_dir,
- dataset,
- dataset_val,
- prefix_list_object,
- prediction_config_element,
- optimizer_option=optimizer_option,
- debug=True,
- epochs=1,
- )
- trainer.train_model()