Commit 2e253c0: first commit
chen-yingfa committed Aug 14, 2024 (0 parents)
Showing 31 changed files with 2,425 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -0,0 +1,3 @@
__pycache__
*.pyc
.DS_Store
44 changes: 44 additions & 0 deletions README.md
@@ -0,0 +1,44 @@
# Chujian 楚简

<div align="center">
<a href="https://huggingface.co/datasets/chen-yingfa/CHUBS">Dataset</a> | Paper (Upcoming)
</div>

<br>

This repository contains the official code for the paper [(Upcoming link)](https://arxiv.org/abs/).

Chu bamboo slips (CBS, Chinese: 楚简, pronounced *chujian*) carry a script used during the Spring and Autumn period of Ancient China, roughly 2,000 years ago. Studying them holds great value for understanding the history and culture of Ancient China. We scraped, processed, annotated, and released the first large-scale dataset of CBS characters, named CHUBS, with over 100K annotated CBS characters. Additionally, we propose a novel multi-modal multi-granularity tokenizer tailored for handling the large number of out-of-vocabulary characters in CBS (characters that have no modern Chinese equivalent).

## Data

All our datasets are provided at <https://huggingface.co/datasets/chen-yingfa/CHUBS>

It contains the following two parts.

1. The main dataset (CHUBS)
2. A small part-of-speech (POS) tagging dataset of CBS text.

The file structure is as follows.

```
- glyphs.zip
- pos-tagging-data/
- dev_examples.json
- dev_examples_subchars.json
- test_examples.json
- test_exampels_subchars.json
- train_examples.json
- train_examples_subchars.json
```

CHUBS is contained in the `glyphs.zip` file, while the POS tagging data is under the `pos-tagging-data` directory. The files with the `_subchars` suffix contain the same data examples as the original files, but with each character split into its sub-character components (see the paper for more details).
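
For reference, here is a minimal sketch of downloading the data and inspecting one POS tagging split. The use of `huggingface_hub` and the assumption that each split file holds a list of examples are ours; see the paper for the exact schema.

```python
import json
from pathlib import Path

from huggingface_hub import snapshot_download

# Download a local copy of the dataset repository.
data_dir = Path(snapshot_download(repo_id="chen-yingfa/CHUBS", repo_type="dataset"))

# Load the POS tagging training split and peek at the first example.
train_file = data_dir / "pos-tagging-data" / "train_examples.json"
examples = json.load(open(train_file, encoding="utf8"))
print(len(examples), "training examples")
print(examples[0])
```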

## Experiments

This repo contains three sets of experiments, each under its own directory.

1. [Character Recognition](char-recognition/README.md)
2. [Sub-Character Component Recognition](subchar-recognition/README.md)
3. [POS Tagging](pos-tagging/README.md)

1 change: 1 addition & 0 deletions char-recognition/.gitignore
@@ -0,0 +1 @@
result
32 changes: 32 additions & 0 deletions char-recognition/README.md
@@ -0,0 +1,32 @@
# Character Recognition

This directory implements the character recognition task for CBS (Chu Bamboo Slip) characters. Given an image of a CBS character, the objective is to classify it into one of a pre-specified set of labels. This is a standard image classification task.

## Data Processing

This task operates on the [CHUBS data](https://huggingface.co/datasets/chen-yingfa/CHUBS) (the `glyphs.zip` file). Each example consists of an image of a CBS character and its corresponding label. The label of each image is the name of the directory that the image file belongs to.

Before running this code, we pre-process the data by running, for example, `python merge_classes.py -k 3 --src_dir path/to/data --dst_dir ./data`, where `--src_dir` specifies the directory containing the raw data and `--dst_dir` specifies the directory to save the processed data.

This script performs the following steps.

1. Merge similar classes. For instance, some variants of the same character are labeled differently, but for image classification we want to treat them as a single class.
2. Remove classes with fewer than $k$ examples. In the paper, we used $k = 2, 3, 10, 20$.

This generates JSON files containing the examples used for training, validating, and testing the character recognizer, saved to the directory specified by `--dst_dir`.
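
As a quick sanity check of the output format, each split file maps a glyph label to the list of image paths assigned to that split. A minimal sketch (the `glyphs_k-3` directory name assumes the `-k 3` example above):

```python
import json

# train.json / dev.json / test.json: {glyph_label: [image_path, ...]}
train = json.load(open("./data/glyphs_k-3/train.json", encoding="utf8"))
print(len(train), "classes")
label, paths = next(iter(train.items()))
print(label, len(paths), paths[:2])
```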

## Training

To train the ViT model, execute the following command.

```shell
python train.py --model vit --model_name vit_base_patch16_224_in21k --data_dir path/to/processed/data --device cuda
```

To train a ResNet-50, use:

```shell
python train.py --model resnet --model_name resnet50 --data_dir path/to/processed/data --device cuda
```

For more options, see the `args.py` file or run `python train.py -h`.
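
The model names above follow the `timm` naming convention, so the backbones are presumably loaded through `timm` (this is an assumption; check `train.py` for the actual loader). A quick way to confirm the names are available in your installed version:

```python
import timm

# Print pretrained models matching the names used in args.py.
print(timm.list_models("vit_base_patch16_224*", pretrained=True))
print(timm.list_models("resnet50", pretrained=True))
```
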
28 changes: 28 additions & 0 deletions char-recognition/args.py
@@ -0,0 +1,28 @@
from argparse import Namespace, ArgumentParser


def parse_args() -> Namespace:
    p = ArgumentParser()
    p.add_argument("--lr", type=float, default=0.0001)
    p.add_argument("--lr_gamma", type=float, default=0.8)
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--num_epochs", type=int, default=16)
    p.add_argument("--mode", default="train_test")
    p.add_argument(
        "--data_dir",
        default="../data/glyphs_k-10",
    )
    p.add_argument(
        "--output_dir",
        default="./result/glyphs_k-10",
    )
    p.add_argument("--pretrained", type=bool, default=True)
    p.add_argument(
        "--model_name",
        default="vit_base_patch16_224_in21k",
        choices=["vit_base_patch16_224_in21k", "resnet50"],
    )
    p.add_argument("--model", default="vit", choices=["vit", "resnet"])
    p.add_argument("--log_interval", type=int, default=10)
    p.add_argument("--device", type=str, default="cuda")
    return p.parse_args()
72 changes: 72 additions & 0 deletions char-recognition/dataset.py
@@ -0,0 +1,72 @@
import json
from pathlib import Path
import random

from torch.utils.data import Dataset
from PIL import Image


class ChujianDataset(Dataset):
    """
    Replacement for ImageFolder that supports empty subdirs.

    Args:
        glyph_to_files_path (Path): JSON file mapping each glyph (class name)
            to a list of image file paths.
        transform (callable, optional):
            A function/transform that takes in a PIL image.
        shuffle (bool, optional): Whether to shuffle the dataset.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples.
    """

    def __init__(
        self,
        glyph_to_files_path: Path,
        transform=None,
        shuffle: bool = False,
    ):
        super().__init__()
        self.glyph_to_files_path = Path(glyph_to_files_path)
        self.transform = transform
        self.shuffle = shuffle

        # Loop through the glyph-to-files mapping and collect all classes
        # and image paths.
        self.glyph_to_files = json.load(
            open(self.glyph_to_files_path, 'r', encoding='utf8'))
        self.imgs = []
        self.classes = []
        self.class_to_idx = {}
        for glyph, files in self.glyph_to_files.items():
            self.classes.append(glyph)
            cls_idx = len(self.classes) - 1
            self.class_to_idx[glyph] = cls_idx

            image_paths = sorted(files)

            # # Always pick 1000 images from each class.
            # class_size = max(100, len(image_paths))
            # image_paths = random.choices(image_paths, k=class_size)

            for image_path in image_paths:
                self.imgs.append((image_path, cls_idx))
        if shuffle:
            random.shuffle(self.imgs)

        # # Duplicate all images to make the dataset balanced.
        # for idx in range(len(self.imgs)):
        #     self.imgs.append(self.imgs[idx])

    def __getitem__(self, idx: int) -> tuple:
        """
        Return (image, class_index)
        """
        image, label = self.imgs[idx]
        image = Image.open(image)
        if self.transform:
            image = self.transform(image)
        return image, label

    def __len__(self):
        return len(self.imgs)
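
A minimal usage sketch for `ChujianDataset` (the transform pipeline, the 224x224 input size, and the data path are assumptions chosen to match the defaults in `args.py`):

```python
from torch.utils.data import DataLoader
from torchvision import transforms

from dataset import ChujianDataset

# Ensure 3 channels, resize to the model input size, and convert to tensors.
transform = transforms.Compose([
    transforms.Lambda(lambda im: im.convert("RGB")),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_set = ChujianDataset(
    "../data/glyphs_k-10/train.json", transform=transform, shuffle=True)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)

images, labels = next(iter(train_loader))
print(images.shape, labels[:8])
```
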
204 changes: 204 additions & 0 deletions char-recognition/merge_classes.py
@@ -0,0 +1,204 @@
from argparse import ArgumentParser
from pathlib import Path
import json
import random
from collections import defaultdict
from typing import Dict, List, Tuple


def dict_size(d) -> int:
    return sum(len(v) for v in d.values())


def dump_json(data, file):
    json.dump(
        data,
        open(file, 'w', encoding='utf8'),
        ensure_ascii=False,
        indent=4,
    )


def merged_glyphs(glyphs: List[str]) -> defaultdict[str, list]:
    '''
    Return {new_name: [old_names]}
    '''
    # Map new glyph name to old glyph name
    new_to_old_name = defaultdict(list)
    # Discard all glyphs containing these chars (after preprocessing label)
    DISCARD_CHARS = [
        '?',
        '□', '■',
        '○', '●',
        '△', '▲',
        '☆', '★',
        '◇', '◆',
        '□',
    ]

    for glyph in glyphs:
        orig = glyph
        # Normalize the glyph label
        RM_STRS = [
            '=', 'None'
        ]
        for c in RM_STRS:
            glyph = glyph.replace(c, '')

        # Replace brackets
        for c in ['(', '〈', '[']:
            glyph = glyph.replace(c, '(')
        for c in [')', '〉', ']']:
            glyph = glyph.replace(c, ')')

        if glyph == '':
            continue

        if glyph[-1] == ')':
            for i in range(len(glyph) - 2, -1, -1):
                if glyph[i] == '(':
                    # "(*)"
                    if glyph[i] == '(':
                        if glyph[i+1:-1] == '○':
                            glyph = glyph[:i]
                        else:
                            glyph = glyph[i+1:-1]
                    else:
                        # "*}(*)"
                        if glyph[i-1] == '}':
                            glyph = glyph[i+1:-1]
                        # "A(*)" -> "A"
                        else:
                            glyph = glyph[0]
                    break
                else:
                    glyph = glyph[:-1]
        # "A→B"
        if '→' in glyph:
            glyph = glyph.split('→')[1]
        if glyph == '𬨭':
            glyph = '將'
        if glyph == '𫵖':
            glyph = '尸示'

        if any(c in glyph for c in DISCARD_CHARS):
            # if '○' in glyph:
            #     print(orig)
            #     print(glyph)
            #     exit()
            continue
        new_to_old_name[glyph].append(orig)
    return new_to_old_name


def get_glyph_to_files(src_dir: Path) -> Dict[str, List[str]]:
    glyph_to_files = {}
    for src_glyph_dir in src_dir.iterdir():
        if not src_glyph_dir.is_dir():
            continue
        name = src_glyph_dir.name
        glyph_to_files[name] = []
        for file in src_glyph_dir.iterdir():
            glyph_to_files[name].append(str(file))
    return glyph_to_files


def split_data(glyph_to_files: dict) -> Tuple[dict, dict, dict]:
    # Split into train, dev, and test sets.
    # For each class, split by a 90/5/5 ratio.
    print("Splitting...")
    train_images = {}
    dev_images = {}
    test_images = {}
    for glyph, image_files in glyph_to_files.items():
        random.seed(0)
        random.shuffle(image_files)
        # Floored split indices.
        split_idx = [
            int(len(image_files) * 0.9),
            int(len(image_files) * 0.95),
        ]
        train_images[glyph] = image_files[: split_idx[0]]
        dev_images[glyph] = image_files[split_idx[0]:split_idx[1]]
        test_images[glyph] = image_files[split_idx[1]:]
    return train_images, dev_images, test_images


def split_and_dump(glyph_to_files: dict, dst_dir: Path):
    train_images, dev_images, test_images = split_data(glyph_to_files)

    print("Train images:", dict_size(train_images))
    print("Dev images:", dict_size(dev_images))
    print("Test images:", dict_size(test_images))

    dump_json(train_images, dst_dir / "train.json")
    dump_json(dev_images, dst_dir / "dev.json")
    dump_json(test_images, dst_dir / "test.json")


def merge_and_dump(src_dir: Path, dst_dir: Path, k: int):
    print('Getting glyph to files...')
    orig_glyph_to_files = get_glyph_to_files(src_dir)
    print('Merging glyphs...')
    print(f'Before: {len(orig_glyph_to_files)} glyphs')
    glyphs = list(orig_glyph_to_files.keys())
    new_to_old_name = merged_glyphs(glyphs)
    print(f'After: {len(new_to_old_name)} glyphs')

    # Building new glyph to files
    glyph_to_files = {}
    for new_name, old_names in new_to_old_name.items():
        glyph_to_files[new_name] = []
        for old_name in old_names:
            glyph_to_files[new_name].extend(orig_glyph_to_files[old_name])

    num_examples = sum(len(files) for files in glyph_to_files.values())
    print(f'Found {num_examples} examples')

    # Remove the glyphs with less than k samples
    print(f'Removing glyphs with less than {k} samples...')
    glyph_to_files = {
        glyph: files
        for glyph, files in glyph_to_files.items() if len(files) >= k}
    print(f'After: {len(glyph_to_files)} glyphs')
    num_examples = sum(len(files) for files in glyph_to_files.values())
    print(f'Found {num_examples} examples')

    glyph_to_cnt = {k: len(v) for k, v in glyph_to_files.items()}

    # Sort by descending count
    merged_sorted = sorted(
        glyph_to_cnt.items(), key=lambda x: x[1], reverse=True)
    glyph_to_cnt = {k: v for k, v in merged_sorted}

    glyph_to_cnt_file = dst_dir / 'glyph_to_count_sorted.json'
    print(f'Dumping to {glyph_to_cnt_file}')
    dump_json(glyph_to_cnt, glyph_to_cnt_file)

    merged_to_orig_file = dst_dir / 'new_to_orig_name.json'
    print(f'Dumping to {merged_to_orig_file}')
    dump_json(new_to_old_name, merged_to_orig_file)

    glyph_to_files_file = dst_dir / "glyph_to_files.json"
    print(f'Dumping to {glyph_to_files_file}')
    dump_json(glyph_to_files, glyph_to_files_file)

    # Split and dump
    split_and_dump(glyph_to_files, dst_dir)


def main():
    p = ArgumentParser()
    p.add_argument('-k', type=int, default=0)
    p.add_argument('--src_dir', type=str, default='../data')
    p.add_argument('--dst_dir', type=str, default='./data')
    args = p.parse_args()
    # Wrap the string arguments in Path so the `/` operator works.
    src_dir = Path(args.src_dir) / 'glyphs'
    print(f'========= k = {args.k} =========')
    dst_dir = Path(args.dst_dir) / f'glyphs_k-{args.k}'
    dst_dir.mkdir(exist_ok=True, parents=True)
    merge_and_dump(src_dir, dst_dir, args.k)


if __name__ == '__main__':
    main()