commit 2e253c0 (0 parents)
Showing 31 changed files with 2,425 additions and 0 deletions.
@@ -0,0 +1,3 @@
__pycache__
*.pyc
.DS_Store
@@ -0,0 +1,44 @@
# Chujian 楚简

<div align="center">
<a href="https://huggingface.co/datasets/chen-yingfa/CHUBS">Dataset</a> | Paper (Upcoming)
</div>

<br>

This repository contains the official code for the paper [(Upcoming link)](https://arxiv.org/abs/).

Chu bamboo slips (CBS, Chinese: 楚简, pronounced *chujian*) are a script used during the Spring and Autumn period of Ancient China, roughly 2,000 years ago. Their study holds great value for understanding the history and culture of Ancient China. We scraped, processed, annotated, and released the first large-scale dataset of CBS characters, named CHUBS, containing over 100K annotated CBS characters. Additionally, we propose a novel multi-modal multi-granularity tokenizer tailored for handling the large number of out-of-vocabulary characters in CBS (characters with no modern Chinese equivalent).
## Data

All our datasets are provided at <https://huggingface.co/datasets/chen-yingfa/CHUBS>.

The release consists of two parts:

1. The main dataset (CHUBS).
2. A small part-of-speech (POS) tagging dataset of CBS text.

The file structure is as follows.

```
- glyphs.zip
- pos-tagging-data/
    - dev_examples.json
    - dev_examples_subchars.json
    - test_examples.json
    - test_exampels_subchars.json
    - train_examples.json
    - train_examples_subchars.json
```

CHUBS is contained in the `glyphs.zip` file, while the POS tagging data is under the `pos-tagging-data` directory. The files with the `_subchars` suffix contain the same examples as the corresponding original files, except that each character is split into its sub-character components (see the paper for more details).
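For a quick sanity check after downloading, the following minimal sketch lists the contents of `glyphs.zip` and loads one of the POS tagging splits. The local directory name `CHUBS/` and the path layout inside the archive are assumptions for illustration; the precise JSON schema of the POS tagging files is documented on the dataset page rather than here.

```python
# Minimal inspection sketch; assumes the dataset was downloaded to ./CHUBS.
import json
import zipfile
from collections import Counter

# glyphs.zip packs the character images, grouped by glyph label
# (the exact path layout inside the archive is an assumption here).
with zipfile.ZipFile("CHUBS/glyphs.zip") as zf:
    image_names = [n for n in zf.namelist() if not n.endswith("/")]
    top_dirs = Counter(name.split("/")[0] for name in image_names)
    print(f"{len(image_names)} images under {len(top_dirs)} top-level directories")

# Load one POS tagging split and report its size.
with open("CHUBS/pos-tagging-data/train_examples.json", encoding="utf8") as f:
    train_examples = json.load(f)
print(type(train_examples), len(train_examples))
```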
## Experiments

This repo contains three sets of experiments, each under its own directory:

1. [Character Recognition](char-recognition/README.md)
2. [Sub-Character Component Recognition](subchar-recognition/README.md)
3. [POS Tagging](pos-tagging/README.md)
@@ -0,0 +1 @@
result
@@ -0,0 +1,32 @@
# Character Recognition

This directory implements the character recognition task for CBS (Chu bamboo slip) characters. Given an image of a CBS character, the objective is to classify it into one of a pre-specified set of labels. This is a standard image classification task.

## Data Processing

This task operates on the [CHUBS data](https://huggingface.co/datasets/chen-yingfa/CHUBS) (the `glyphs.zip` archive). Each example consists of an image of a CBS character and its corresponding label. The label of each image is the name of the directory that the image file belongs to.

Before running the training code, pre-process the data by running, for example, `python merge_classes.py -k 3 --src_dir path/to/data --dst_dir ./data`, where `--src_dir` specifies the directory containing the raw data and `--dst_dir` specifies the directory in which to save the processed data.

This script performs the following steps.

1. Merge similar classes. For instance, some variants of the same character are labeled differently, but for image classification we want to treat them as a single class.
2. Remove classes with fewer than $k$ examples. In the paper we used $k$ = 2, 3, 10, and 20.

This generates JSON files containing the examples used for training, validating, and testing the character recognizer, saved to the directory specified through `--dst_dir`.
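For reference, each split file that `merge_classes.py` (shown later in this commit) writes out, namely `train.json`, `dev.json`, and `test.json`, is a JSON object mapping a glyph label to the list of its image paths. A rough inspection sketch, assuming the example command above (`-k 3 --dst_dir ./data`) so that the outputs land in `./data/glyphs_k-3`:

```python
# Inspect the processed splits produced by merge_classes.py.
# Assumes it was run with: python merge_classes.py -k 3 --dst_dir ./data
import json
from pathlib import Path

data_dir = Path("./data/glyphs_k-3")
for split in ["train", "dev", "test"]:
    # Each split file maps a glyph label to a list of image file paths.
    glyph_to_files = json.load(open(data_dir / f"{split}.json", encoding="utf8"))
    num_images = sum(len(files) for files in glyph_to_files.values())
    print(f"{split}: {len(glyph_to_files)} classes, {num_images} images")
```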
## Training

To train the ViT model, execute the following command.

```shell
python train.py --model vit --model_name vit_base_patch16_224_in21k --data_dir path/to/processed/data --device cuda
```

To train a ResNet-50, use:

```shell
python train.py --model resnet --model_name resnet50 --data_dir path/to/processed/data --device cuda
```

For more options, see the `args.py` file or execute `python train.py -h`.
@@ -0,0 +1,28 @@
from argparse import Namespace, ArgumentParser


def parse_args() -> Namespace:
    """Parse command-line arguments for the character recognition experiments."""
    p = ArgumentParser()
    # Optimization hyperparameters.
    p.add_argument("--lr", type=float, default=0.0001)
    p.add_argument("--lr_gamma", type=float, default=0.8)
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--num_epochs", type=int, default=16)
    p.add_argument("--mode", default="train_test")
    # Input/output locations (processed data from merge_classes.py).
    p.add_argument(
        "--data_dir",
        default="../data/glyphs_k-10",
    )
    p.add_argument(
        "--output_dir",
        default="./result/glyphs_k-10",
    )
    # Model selection. Note: argparse's type=bool treats any non-empty
    # string as True, so --pretrained is effectively always truthy when set.
    p.add_argument("--pretrained", type=bool, default=True)
    p.add_argument(
        "--model_name",
        default="vit_base_patch16_224_in21k",
        choices=["vit_base_patch16_224_in21k", "resnet50"],
    )
    p.add_argument("--model", default="vit", choices=["vit", "resnet"])
    # Logging and hardware.
    p.add_argument("--log_interval", type=int, default=10)
    p.add_argument("--device", type=str, default="cuda")
    return p.parse_args()
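`train.py` itself is not among the files shown in this excerpt; the sketch below is only a hypothetical illustration of how the parsed arguments might be consumed.

```python
# Hypothetical usage sketch; train.py is not shown in this excerpt.
from args import parse_args

if __name__ == "__main__":
    args = parse_args()
    # A training script would select the backbone via --model / --model_name,
    # read processed data from --data_dir, and write results to --output_dir.
    print(args.model, args.model_name, args.data_dir, args.output_dir)
```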
@@ -0,0 +1,72 @@
import json
from pathlib import Path
import random

from torch.utils.data import Dataset
from PIL import Image


class ChujianDataset(Dataset):
    """
    Replacement for ImageFolder that supports empty subdirs.

    Args:
        glyph_to_files_path (str or Path): Path to a JSON file mapping each
            glyph label to a list of image file paths.
        transform (callable, optional):
            A function/transform that takes in a PIL image.
        shuffle (bool, optional): Whether to shuffle the dataset.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples.
    """

    def __init__(
        self,
        glyph_to_files_path: Path,
        transform=None,
        shuffle: bool = False,
    ):
        super().__init__()
        self.glyph_to_files_path = Path(glyph_to_files_path)
        self.transform = transform
        self.shuffle = shuffle

        # Load the glyph -> image paths mapping and build the list of
        # (image path, class index) examples.
        self.glyph_to_files = json.load(
            open(self.glyph_to_files_path, 'r', encoding='utf8'))
        self.imgs = []
        self.classes = []
        self.class_to_idx = {}
        for glyph, files in self.glyph_to_files.items():
            self.classes.append(glyph)
            cls_idx = len(self.classes) - 1
            self.class_to_idx[glyph] = cls_idx

            image_paths = sorted(files)

            # # Always pick 1000 images from each class.
            # class_size = max(100, len(image_paths))
            # image_paths = random.choices(image_paths, k=class_size)

            for image_path in image_paths:
                self.imgs.append((image_path, cls_idx))
        if shuffle:
            random.shuffle(self.imgs)

        # # Duplicate all images to make the dataset balanced.
        # for idx in range(len(self.imgs)):
        #     self.imgs.append(self.imgs[idx])

    def __getitem__(self, idx: int) -> tuple:
        """
        Return (image, class_index).
        """
        image, label = self.imgs[idx]
        image = Image.open(image)
        if self.transform:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.imgs)
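Below is a short sketch of how `ChujianDataset` might be wrapped in a PyTorch `DataLoader`. The transform pipeline and the path `data/glyphs_k-10/train.json` are illustrative assumptions, not fixed by this repository.

```python
# Sketch: feeding ChujianDataset to a DataLoader (paths and transforms assumed).
from torch.utils.data import DataLoader
from torchvision import transforms

from dataset import ChujianDataset

transform = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB")),  # guard against non-RGB images
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
train_set = ChujianDataset("data/glyphs_k-10/train.json", transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)

images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # e.g. [64, 3, 224, 224] and [64]
```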
@@ -0,0 +1,204 @@
from argparse import ArgumentParser
from pathlib import Path
import json
import random
from collections import defaultdict
from typing import Dict, List, Tuple


def dict_size(d) -> int:
    return sum(len(v) for v in d.values())


def dump_json(data, file):
    json.dump(
        data,
        open(file, 'w', encoding='utf8'),
        ensure_ascii=False,
        indent=4,
    )


def merged_glyphs(glyphs: List[str]) -> Dict[str, List[str]]:
    '''
    Return {new_name: [old_names]}
    '''
    # Map new glyph name to old glyph names
    new_to_old_name = defaultdict(list)
    # Discard all glyphs containing these chars (after preprocessing the label)
    DISCARD_CHARS = [
        '?',
        '□', '■',
        '○', '●',
        '△', '▲',
        '☆', '★',
        '◇', '◆',
    ]

    for glyph in glyphs:
        orig = glyph
        # Normalize the glyph label
        RM_STRS = [
            '=', 'None'
        ]
        for c in RM_STRS:
            glyph = glyph.replace(c, '')

        # Replace bracket variants with plain parentheses
        for c in ['(', '〈', '[']:
            glyph = glyph.replace(c, '(')
        for c in [')', '〉', ']']:
            glyph = glyph.replace(c, ')')

        if glyph == '':
            continue

        # Resolve labels ending with a parenthesized annotation
        if glyph[-1] == ')':
            for i in range(len(glyph) - 2, -1, -1):
                if glyph[i] == '(':
                    # "(*)"
                    if glyph[i] == '(':
                        if glyph[i+1:-1] == '○':
                            glyph = glyph[:i]
                        else:
                            glyph = glyph[i+1:-1]
                    else:
                        # "*}(*)"
                        if glyph[i-1] == '}':
                            glyph = glyph[i+1:-1]
                        # "A(*)" -> "A"
                        else:
                            glyph = glyph[0]
                    break
            else:
                glyph = glyph[:-1]
        # "A→B"
        if '→' in glyph:
            glyph = glyph.split('→')[1]
        if glyph == '𬨭':
            glyph = '將'
        if glyph == '𫵖':
            glyph = '尸示'

        if any(c in glyph for c in DISCARD_CHARS):
            # if '○' in glyph:
            #     print(orig)
            #     print(glyph)
            #     exit()
            continue
        new_to_old_name[glyph].append(orig)
    return new_to_old_name


def get_glyph_to_files(src_dir: Path) -> Dict[str, List[str]]:
    glyph_to_files = {}
    for src_glyph_dir in src_dir.iterdir():
        if not src_glyph_dir.is_dir():
            continue
        name = src_glyph_dir.name
        glyph_to_files[name] = []
        for file in src_glyph_dir.iterdir():
            glyph_to_files[name].append(str(file))
    return glyph_to_files


def split_data(glyph_to_files: dict) -> Tuple[dict, dict, dict]:
    # Split into train, dev, and test sets.
    # For each class, split by a 90:5:5 ratio.
    print("Splitting...")
    train_images = {}
    dev_images = {}
    test_images = {}
    for glyph, image_files in glyph_to_files.items():
        random.seed(0)
        random.shuffle(image_files)
        # Floored to make sure test and dev have at least one example.
        split_idx = [
            int(len(image_files) * 0.9),
            int(len(image_files) * 0.95),
        ]
        train_images[glyph] = image_files[: split_idx[0]]
        dev_images[glyph] = image_files[split_idx[0]:split_idx[1]]
        test_images[glyph] = image_files[split_idx[1]:]
    return train_images, dev_images, test_images


def split_and_dump(glyph_to_files: dict, dst_dir: Path):
    train_images, dev_images, test_images = split_data(glyph_to_files)

    print("Train images:", dict_size(train_images))
    print("Dev images:", dict_size(dev_images))
    print("Test images:", dict_size(test_images))

    dump_json(train_images, dst_dir / "train.json")
    dump_json(dev_images, dst_dir / "dev.json")
    dump_json(test_images, dst_dir / "test.json")


def merge_and_dump(src_dir: Path, dst_dir: Path, k: int):
    print('Getting glyph to files...')
    orig_glyph_to_files = get_glyph_to_files(src_dir)
    print('Merging glyphs...')
    print(f'Before: {len(orig_glyph_to_files)} glyphs')
    glyphs = list(orig_glyph_to_files.keys())
    new_to_old_name = merged_glyphs(glyphs)
    print(f'After: {len(new_to_old_name)} glyphs')

    # Build the new glyph -> files mapping
    glyph_to_files = {}
    for new_name, old_names in new_to_old_name.items():
        glyph_to_files[new_name] = []
        for old_name in old_names:
            glyph_to_files[new_name].extend(orig_glyph_to_files[old_name])

    num_examples = sum(len(files) for files in glyph_to_files.values())
    print(f'Found {num_examples} examples')

    # Remove the glyphs with fewer than k samples
    print(f'Removing glyphs with less than {k} samples...')
    glyph_to_files = {
        glyph: files
        for glyph, files in glyph_to_files.items() if len(files) >= k}
    print(f'After: {len(glyph_to_files)} glyphs')
    num_examples = sum(len(files) for files in glyph_to_files.values())
    print(f'Found {num_examples} examples')

    glyph_to_cnt = {k: len(v) for k, v in glyph_to_files.items()}

    # Sort by descending count
    merged_sorted = sorted(
        glyph_to_cnt.items(), key=lambda x: x[1], reverse=True)
    glyph_to_cnt = {k: v for k, v in merged_sorted}

    glyph_to_cnt_file = dst_dir / 'glyph_to_count_sorted.json'
    print(f'Dumping to {glyph_to_cnt_file}')
    dump_json(glyph_to_cnt, glyph_to_cnt_file)

    merged_to_orig_file = dst_dir / 'new_to_orig_name.json'
    print(f'Dumping to {merged_to_orig_file}')
    dump_json(new_to_old_name, merged_to_orig_file)

    glyph_to_files_file = dst_dir / "glyph_to_files.json"
    print(f'Dumping to {glyph_to_files_file}')
    dump_json(glyph_to_files, glyph_to_files_file)

    # Split and dump
    split_and_dump(glyph_to_files, dst_dir)


def main():
    p = ArgumentParser()
    p.add_argument('-k', type=int, default=0)
    p.add_argument('--src_dir', type=str, default='../data')
    p.add_argument('--dst_dir', type=str, default='./data')
    args = p.parse_args()
    # Wrap the string arguments in Path so the `/` operator works.
    src_dir = Path(args.src_dir) / 'glyphs'
    print(f'========= k = {args.k} =========')
    dst_dir = Path(args.dst_dir) / f'glyphs_k-{args.k}'
    dst_dir.mkdir(exist_ok=True, parents=True)
    merge_and_dump(src_dir, dst_dir, args.k)


if __name__ == '__main__':
    main()