Commit c36bf91

dev(narugo): first raw version, ci skip
1 parent e2a84a9 commit c36bf91

15 files changed: +558 -0 lines changed

.github/workflows/badge.yml

Lines changed: 1 addition & 0 deletions

```diff
@@ -8,6 +8,7 @@ jobs:
   update-badges:
     name: Update Badges
     runs-on: ubuntu-latest
+    if: ${{ !contains(github.event.head_commit.message, 'ci skip') && !contains(github.event.head_commit.message, 'test skip') }}
     strategy:
       matrix:
         python-version: [ 3.8 ]
```

.github/workflows/release_test.yml

Lines changed: 1 addition & 0 deletions

```diff
@@ -8,6 +8,7 @@ jobs:
   source_release:
     name: Try package the source
     runs-on: ${{ matrix.os }}
+    if: ${{ !contains(github.event.head_commit.message, 'ci skip') && !contains(github.event.head_commit.message, 'test skip') }}
     strategy:
       fail-fast: false
       matrix:
```

README.md

Lines changed: 68 additions & 0 deletions

```diff
@@ -22,3 +22,71 @@ Swiftly get tons of images from indexed tars on Huggingface
 
 (Still WIP ...)
 
+## Installation
+
+```shell
+git clone https://github.com/deepghs/cheesechaser.git
+cd cheesechaser
+pip install -r requirements.txt
+```
+
+## How this library works
+
+This library is based on mirror datasets hosted on HuggingFace.
+
+## Batch Download Images
+
+* Danbooru
+
+```python
+from cheesechaser.datapool import DanbooruStableDataPool
+
+pool = DanbooruStableDataPool()
+
+# download danbooru #2010000-2010300 to directory /data/exp2
+pool.batch_download_to_directory(
+    resource_ids=range(2010000, 2010300),
+    dst_dir='/data/exp2',
+    max_workers=12,
+)
+```
+
+* Konachan
+
+```python
+from cheesechaser.datapool import KonachanDataPool
+
+pool = KonachanDataPool()
+
+# download konachan #210000-210300 to directory /data/exp2
+pool.batch_download_to_directory(
+    resource_ids=range(210000, 210300),
+    dst_dir='/data/exp2',
+    max_workers=12,
+)
+```
+
+* Civitai (this mirror repository on HF is private for now; you have to use an HF token from an authorized account)
+
+```python
+from cheesechaser.datapool import CivitaiDataPool
+
+pool = CivitaiDataPool()
+
+# download civitai #7810000-7810300 to directory /data/exp2
+# each resource should contain one image and one JSON metadata file
+pool.batch_download_to_directory(
+    resource_ids=range(7810000, 7810300),
+    dst_dir='/data/exp2',
+    max_workers=12,
+)
+```
+
+More supported pools:
+
+* `RealbooruDataPool`
+* `ThreedbooruDataPool`
+* `FancapsDataPool`
+* `BangumiBaseDataPool`
+* `AnimePicturesDataPool`
```

cheesechaser/datapool/__init__.py

Lines changed: 11 additions & 0 deletions (new file)

```python
from .anime_pictures import AnimePicturesDataPool
from .bangumibase import BangumiBaseDataPool
from .base import DataLocation, InvalidResourceDataError, FileUnrecognizableError, GenericDataPool, IncrementIDDataPool, \
    ResourceNotFoundError
from .civitai import CivitaiDataPool
from .danbooru import DanbooruDataPool, DanbooruStableDataPool
from .fancaps import FancapsDataPool
from .image import ImageData, ImageJsonAttachedData, ImageOnlyDataPool, ImageJsonAttachedDataPool
from .konachan import KonachanDataPool
from .realbooru import RealbooruDataPool
from .threedbooru import ThreedbooruDataPool
```

cheesechaser/datapool/anime_pictures.py

Lines changed: 14 additions & 0 deletions (new file)

```python
from .image import ImageOnlyDataPool

_ANIME_PICTURES_REPO = 'deepghs/anime_pictures_full'


class AnimePicturesDataPool(ImageOnlyDataPool):
    def __init__(self, revision: str = 'main'):
        ImageOnlyDataPool.__init__(
            self,
            data_repo_id=_ANIME_PICTURES_REPO,
            data_revision=revision,
            idx_repo_id=_ANIME_PICTURES_REPO,
            idx_revision=revision,
        )
```
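
Each pool in this commit is this same thin constructor over one mirror repo, passed as both the data source and the index source. `GenericDataPool` (in `base.py` below) also accepts a separate index repo; a hedged sketch with hypothetical repo names:

```python
from cheesechaser.datapool import IncrementIDDataPool

# hypothetical repo names for illustration only: the tars could live in one
# dataset repo and the pre-built tar indices in another; the bundled pools
# simply point both at the same mirror repo
pool = IncrementIDDataPool(
    data_repo_id='deepghs/some_mirror_full',
    idx_repo_id='deepghs/some_mirror_index',
    base_level=3,
    base_dir='images',
)
```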

cheesechaser/datapool/bangumibase.py

Lines changed: 14 additions & 0 deletions (new file)

```python
from .image import ImageOnlyDataPool

_BANGUMIBASE_REPO = 'deepghs/bangumibase_full'


class BangumiBaseDataPool(ImageOnlyDataPool):
    def __init__(self, revision: str = 'main'):
        ImageOnlyDataPool.__init__(
            self,
            data_repo_id=_BANGUMIBASE_REPO,
            data_revision=revision,
            idx_repo_id=_BANGUMIBASE_REPO,
            idx_revision=revision,
        )
```

cheesechaser/datapool/base.py

Lines changed: 239 additions & 0 deletions (new file)

```python
import logging
import os
import shutil
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass
from queue import Queue, Full
from threading import Thread, Event
from typing import List, Iterable, ContextManager, Tuple

from hbutils.system import TemporaryDirectory
from hfutils.index import hf_tar_list_files, hf_tar_file_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
from tqdm import tqdm


@dataclass
class DataLocation:
    tar_file: str
    filename: str


def _n_path(path):
    """
    Normalize a file path.

    :param path: The file path to normalize.
    :type path: str
    :return: The normalized file path.
    :rtype: str
    """
    return os.path.normpath(os.path.join('/', path))


class InvalidResourceDataError(Exception):
    pass


class ResourceNotFoundError(InvalidResourceDataError):
    pass


class FileUnrecognizableError(Exception):
    pass


class GenericDataPool:
    def __init__(self, data_repo_id: str, data_revision: str = 'main',
                 idx_repo_id: str = None, idx_revision: str = 'main'):
        self.data_repo_id = data_repo_id
        self.data_revision = data_revision

        self.idx_repo_id = idx_repo_id or data_repo_id
        self.idx_revision = idx_revision

        self._tar_infos = {}

    def _file_to_resource_id(self, tar_file: str, body: str):
        raise NotImplementedError

    def _make_tar_info(self, tar_file: str, force: bool = False):
        # build (and cache) a mapping of resource id -> files inside this tar archive
        key = _n_path(tar_file)
        if force or key not in self._tar_infos:
            data = {}
            for file in hf_tar_list_files(
                    repo_id=self.data_repo_id,
                    repo_type='dataset',
                    archive_in_repo=tar_file,
                    revision=self.data_revision,

                    idx_repo_id=self.idx_repo_id,
                    idx_repo_type='dataset',
                    idx_revision=self.idx_revision,
            ):
                try:
                    resource_id = self._file_to_resource_id(tar_file, file)
                except FileUnrecognizableError:
                    continue
                if resource_id not in data:
                    data[resource_id] = []
                data[resource_id].append(file)
            self._tar_infos[key] = data

        return self._tar_infos[key]

    def _request_possible_archives(self, resource_id) -> List[str]:
        raise NotImplementedError

    def _request_resource_by_id(self, resource_id) -> List[DataLocation]:
        for archive_file in self._request_possible_archives(resource_id):
            try:
                info = self._make_tar_info(archive_file, force=False)
            except (EntryNotFoundError, RepositoryNotFoundError):
                # no information found, skipped
                continue

            if resource_id in info:
                return [
                    DataLocation(tar_file=archive_file, filename=file)
                    for file in info[resource_id]
                ]
        else:
            return []

    @contextmanager
    def _mock_resource(self, resource_id) -> ContextManager[str]:
        # download every file of the resource into a temporary directory
        with TemporaryDirectory() as td:
            for location in self._request_resource_by_id(resource_id):
                dst_filename = os.path.join(td, os.path.basename(location.filename))
                hf_tar_file_download(
                    repo_id=self.data_repo_id,
                    repo_type='dataset',
                    archive_in_repo=location.tar_file,
                    file_in_archive=location.filename,
                    local_file=dst_filename,
                    revision=self.data_revision,

                    idx_repo_id=self.idx_repo_id,
                    idx_repo_type='dataset',
                    idx_revision=self.idx_revision,
                )
            yield td

    def batch_download_to_directory(self, resource_ids, dst_dir: str, max_workers: int = 12):
        pg_res = tqdm(resource_ids, desc='Batch Downloading')
        pg_downloaded = tqdm(desc='Files Downloaded')

        def _func(resource_id):
            try:
                with self._mock_resource(resource_id) as td:
                    copied = False
                    for root, dirs, files in os.walk(td):
                        for file in files:
                            src_file = os.path.abspath(os.path.join(root, file))
                            dst_file = os.path.join(dst_dir, os.path.relpath(src_file, td))
                            if os.path.dirname(dst_file):
                                os.makedirs(os.path.dirname(dst_file), exist_ok=True)
                            shutil.copyfile(src_file, dst_file)
                            pg_downloaded.update()
                            copied = True

                    if not copied:
                        logging.warning(f'No files found for resource {resource_id!r}.')

            except Exception as err:
                logging.error(f'Error occurred when downloading resource {resource_id!r} - {err!r}')
            finally:
                pg_res.update()

        tp = ThreadPoolExecutor(max_workers=max_workers)
        for rid in resource_ids:
            tp.submit(_func, rid)

        tp.shutdown(wait=True)

    def retrieve_resource_data(self, resource_id):
        raise NotImplementedError

    def batch_retrieve_resource_queue(self, resource_ids, max_workers: int = 12) -> Tuple[Queue, Event]:
        # producer side of a producer/consumer pipeline: worker threads push
        # retrieved resources into a bounded queue; is_stopped is set once all
        # workers are done (and may be set by the consumer to cancel early)
        pg = tqdm(resource_ids, desc='Batch Retrieving')
        queue = Queue(maxsize=max_workers * 3)
        is_stopped = Event()

        def _func(resource_id):
            if is_stopped.is_set():
                return

            try:
                try:
                    data = self.retrieve_resource_data(resource_id)
                except ResourceNotFoundError:
                    logging.warning(f'Resource {resource_id!r} not found.')
                    return
                except InvalidResourceDataError as err:
                    logging.warning(f'Resource {resource_id!r} is invalid - {err}.')
                    return
                finally:
                    pg.update()

                while True:
                    try:
                        queue.put(data, block=True, timeout=1.0)
                    except Full:
                        if is_stopped.is_set():
                            break
                        continue
                    else:
                        break

            except Exception as err:
                logging.error(f'Error occurred when retrieving resource {resource_id!r} - {err!r}')

        def _productor():
            tp = ThreadPoolExecutor(max_workers=max_workers)
            for rid in resource_ids:
                if is_stopped.is_set():
                    break
                tp.submit(_func, rid)

            tp.shutdown(wait=True)
            is_stopped.set()

        t_productor = Thread(target=_productor)
        t_productor.start()
        return queue, is_stopped


def id_modulo_cut(id_text: str):
    # split an id string into 3-digit groups from the right, e.g. '0123' -> ['0', '123']
    id_text = id_text[::-1]
    data = []
    for i in range(0, len(id_text), 3):
        data.append(id_text[i:i + 3][::-1])
    return data[::-1]


class IncrementIDDataPool(GenericDataPool):
    def __init__(self, data_repo_id: str, data_revision: str = 'main',
                 idx_repo_id: str = None, idx_revision: str = 'main',
                 base_level: int = 3, base_dir: str = 'images'):
        GenericDataPool.__init__(self, data_repo_id, data_revision, idx_repo_id, idx_revision)
        self.base_level = base_level
        self.base_dir = base_dir

    def _file_to_resource_id(self, tar_file: str, filename: str):
        try:
            body, _ = os.path.splitext(os.path.basename(filename))
            return int(body)
        except (ValueError, TypeError):
            raise FileUnrecognizableError

    def _request_possible_archives(self, resource_id) -> Iterable[str]:
        # map a numeric resource id to the tar shard that should contain it,
        # based on the last `base_level` digits of the id
        modulo = resource_id % (10 ** self.base_level)
        modulo_str = str(modulo)
        if len(modulo_str) < self.base_level:
            modulo_str = '0' * (self.base_level - len(modulo_str)) + modulo_str

        modulo_segments = id_modulo_cut(modulo_str)
        modulo_segments[-1] = f'0{modulo_segments[-1]}'
        return [f'{self.base_dir}/{"/".join(modulo_segments)}.tar']
```
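
`batch_retrieve_resource_queue` hands back the bounded queue plus the stop `Event`. A minimal consumer sketch, assuming a pool whose `retrieve_resource_data` is implemented (the image pools from `cheesechaser/datapool/image.py`, which this commit adds but which is not shown in this excerpt):

```python
from queue import Empty

from cheesechaser.datapool import KonachanDataPool

pool = KonachanDataPool()
queue, is_stopped = pool.batch_retrieve_resource_queue(range(210000, 210050))

while True:
    try:
        data = queue.get(block=True, timeout=1.0)
    except Empty:
        # the producer thread only sets is_stopped after all workers have
        # finished, so an empty queue at that point means the stream is drained
        if is_stopped.is_set():
            break
        continue
    print(data)  # process the retrieved resource here
```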

cheesechaser/datapool/civitai.py

Lines changed: 15 additions & 0 deletions (new file)

```python
from .image import ImageJsonAttachedDataPool

_CIVITAI_REPO = 'deepghs/civitai_full'


class CivitaiDataPool(ImageJsonAttachedDataPool):
    def __init__(self, revision: str = 'main'):
        ImageJsonAttachedDataPool.__init__(
            self,
            data_repo_id=_CIVITAI_REPO,
            data_revision=revision,
            idx_repo_id=_CIVITAI_REPO,
            idx_revision=revision,
            base_level=4,
        )
```
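
For intuition on how `base_level` picks the tar shard: a standalone re-trace of `IncrementIDDataPool._request_possible_archives` from `base.py` above (a sketch, not the library API; `zfill` replaces the manual zero-padding but is equivalent, and the outputs were checked by hand):

```python
def id_modulo_cut(id_text: str):
    # same helper as in cheesechaser/datapool/base.py above
    id_text = id_text[::-1]
    data = []
    for i in range(0, len(id_text), 3):
        data.append(id_text[i:i + 3][::-1])
    return data[::-1]


def archive_path(resource_id: int, base_level: int = 3, base_dir: str = 'images'):
    # mirrors IncrementIDDataPool._request_possible_archives
    modulo_str = str(resource_id % (10 ** base_level)).zfill(base_level)
    segments = id_modulo_cut(modulo_str)
    segments[-1] = f'0{segments[-1]}'
    return f'{base_dir}/{"/".join(segments)}.tar'


print(archive_path(2010042))                 # images/0042.tar   (default base_level=3)
print(archive_path(7810123, base_level=4))   # images/0/0123.tar (Civitai's base_level=4)
```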
