| 1 | +import logging |
| 2 | +import os |
| 3 | +import shutil |
| 4 | +from concurrent.futures import ThreadPoolExecutor |
| 5 | +from contextlib import contextmanager |
| 6 | +from dataclasses import dataclass |
| 7 | +from queue import Queue, Full |
| 8 | +from threading import Thread, Event |
| 9 | +from typing import List, Iterable, Iterator, Tuple |
| 10 | + |
| 11 | +from hbutils.system import TemporaryDirectory |
| 12 | +from hfutils.index import hf_tar_list_files, hf_tar_file_download |
| 13 | +from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError |
| 14 | +from tqdm import tqdm |
| 15 | + |
| 16 | + |
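| | +# One file of a resource: the tar archive inside the data repo and the member path within it. |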
| 17 | +@dataclass |
| 18 | +class DataLocation: |
| 19 | + tar_file: str |
| 20 | + filename: str |
| 21 | + |
| 22 | + |
| 23 | +def _n_path(path): |
| 24 | + """ |
| 25 | + Normalize a file path. |
| 26 | + |
| 27 | + :param path: The file path to normalize. |
| 28 | + :type path: str |
| 29 | + :return: The normalized file path. |
| 30 | + :rtype: str |
| 31 | + """ |
| 32 | + return os.path.normpath(os.path.join('/', path)) |
| 33 | + |
| 34 | + |
| 35 | +class InvalidResourceDataError(Exception): |
| 36 | + pass |
| 37 | + |
| 38 | + |
| 39 | +class ResourceNotFoundError(InvalidResourceDataError): |
| 40 | + pass |
| 41 | + |
| 42 | + |
| 43 | +class FileUnrecognizableError(Exception): |
| 44 | + pass |
| 45 | + |
| 46 | + |
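| | +# Generic pool of resources packed into tar archives inside a Hugging Face dataset repo. |
| | +# A second repo (defaulting to the data repo itself) holds the hfutils index files that |
| | +# allow listing and downloading single archive members without fetching whole tars. |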
| 47 | +class GenericDataPool: |
| 48 | + def __init__(self, data_repo_id: str, data_revision: str = 'main', |
| 49 | + idx_repo_id: str = None, idx_revision: str = 'main'): |
| 50 | + self.data_repo_id = data_repo_id |
| 51 | + self.data_revision = data_revision |
| 52 | + |
| 53 | + self.idx_repo_id = idx_repo_id or data_repo_id |
| 54 | + self.idx_revision = idx_revision |
| 55 | + |
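| | +        # cache of per-archive indices: normalized tar path -> {resource_id: [member files]} |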
| 56 | + self._tar_infos = {} |
| 57 | + |
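| | +    # Hook: map an archive member file to its resource id; implementations should |
| | +    # raise FileUnrecognizableError for files that belong to no resource. |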
| 58 | +    def _file_to_resource_id(self, tar_file: str, filename: str): |
| 59 | + raise NotImplementedError |
| 60 | + |
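| | +    # List the archive's members via the index repo and group them by resource id; |
| | +    # results are cached per archive unless force=True. |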
| 61 | + def _make_tar_info(self, tar_file: str, force: bool = False): |
| 62 | + key = _n_path(tar_file) |
| 63 | + if force or key not in self._tar_infos: |
| 64 | + data = {} |
| 65 | + for file in hf_tar_list_files( |
| 66 | + repo_id=self.data_repo_id, |
| 67 | + repo_type='dataset', |
| 68 | + archive_in_repo=tar_file, |
| 69 | + revision=self.data_revision, |
| 70 | + |
| 71 | + idx_repo_id=self.idx_repo_id, |
| 72 | + idx_repo_type='dataset', |
| 73 | + idx_revision=self.idx_revision, |
| 74 | + ): |
| 75 | + try: |
| 76 | + resource_id = self._file_to_resource_id(tar_file, file) |
| 77 | + except FileUnrecognizableError: |
| 78 | + continue |
| 79 | + if resource_id not in data: |
| 80 | + data[resource_id] = [] |
| 81 | + data[resource_id].append(file) |
| 82 | + self._tar_infos[key] = data |
| 83 | + |
| 84 | + return self._tar_infos[key] |
| 85 | + |
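| | +    # Hook: list the archives inside the data repo that may contain the given resource. |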
| 86 | + def _request_possible_archives(self, resource_id) -> List[str]: |
| 87 | + raise NotImplementedError |
| 88 | + |
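| | +    # Resolve a resource id to concrete file locations by probing the candidate archives. |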
| 89 | + def _request_resource_by_id(self, resource_id) -> List[DataLocation]: |
| 90 | + for archive_file in self._request_possible_archives(resource_id): |
| 91 | + try: |
| 92 | + info = self._make_tar_info(archive_file, force=False) |
| 93 | + except (EntryNotFoundError, RepositoryNotFoundError): |
| 94 | +                # no index information for this archive, skip it |
| 95 | + continue |
| 96 | + |
| 97 | + if resource_id in info: |
| 98 | + return [ |
| 99 | + DataLocation(tar_file=archive_file, filename=file) |
| 100 | + for file in info[resource_id] |
| 101 | + ] |
| 102 | +            else: |
| 103 | +                return [] |
| | + |
| | +        # reached when none of the candidate archives could be loaded |
| | +        return [] |
| 104 | + |
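| | +    # Download every file of the resource into a temporary directory and yield its path; |
| | +    # the directory is deleted when the context exits. |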
| 105 | + @contextmanager |
| 106 | +    def _mock_resource(self, resource_id) -> Iterator[str]: |
| 107 | + with TemporaryDirectory() as td: |
| 108 | + for location in self._request_resource_by_id(resource_id): |
| 109 | + dst_filename = os.path.join(td, os.path.basename(location.filename)) |
| 110 | + hf_tar_file_download( |
| 111 | + repo_id=self.data_repo_id, |
| 112 | + repo_type='dataset', |
| 113 | + archive_in_repo=location.tar_file, |
| 114 | + file_in_archive=location.filename, |
| 115 | + local_file=dst_filename, |
| 116 | + revision=self.data_revision, |
| 117 | + |
| 118 | + idx_repo_id=self.idx_repo_id, |
| 119 | + idx_repo_type='dataset', |
| 120 | + idx_revision=self.idx_revision, |
| 121 | + ) |
| 122 | + yield td |
| 123 | + |
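| | +    # Download the given resources concurrently and copy all of their files into dst_dir. |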
| 124 | + def batch_download_to_directory(self, resource_ids, dst_dir: str, max_workers: int = 12): |
| 125 | + pg_res = tqdm(resource_ids, desc='Batch Downloading') |
| 126 | + pg_downloaded = tqdm(desc='Files Downloaded') |
| 127 | + |
| 128 | + def _func(resource_id): |
| 129 | + try: |
| 130 | + with self._mock_resource(resource_id) as td: |
| 131 | + copied = False |
| 132 | + for root, dirs, files in os.walk(td): |
| 133 | + for file in files: |
| 134 | + src_file = os.path.abspath(os.path.join(root, file)) |
| 135 | + dst_file = os.path.join(dst_dir, os.path.relpath(src_file, td)) |
| 136 | + if os.path.dirname(dst_file): |
| 137 | + os.makedirs(os.path.dirname(dst_file), exist_ok=True) |
| 138 | + shutil.copyfile(src_file, dst_file) |
| 139 | + pg_downloaded.update() |
| 140 | + copied = True |
| 141 | + |
| 142 | + if not copied: |
| 143 | + logging.warning(f'No files found for resource {resource_id!r}.') |
| 144 | + |
| 145 | + except Exception as err: |
| 146 | + logging.error(f'Error occurred when downloading resource {resource_id!r} - {err!r}') |
| 147 | + finally: |
| 148 | + pg_res.update() |
| 149 | + |
| 150 | + tp = ThreadPoolExecutor(max_workers=max_workers) |
| 151 | + for rid in resource_ids: |
| 152 | + tp.submit(_func, rid) |
| 153 | + |
| 154 | + tp.shutdown(wait=True) |
| 155 | + |
| 156 | + def retrieve_resource_data(self, resource_id): |
| 157 | + raise NotImplementedError |
| 158 | + |
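| | +    # Spawn producer threads that retrieve resources and push them into a bounded queue. |
| | +    # Returns the queue plus a stop Event: producers set it once every id has been handled, |
| | +    # and consumers may set it themselves to cancel remaining work (see the sketch at the end). |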
| 159 | + def batch_retrieve_resource_queue(self, resource_ids, max_workers: int = 12) -> Tuple[Queue, Event]: |
| 160 | + pg = tqdm(resource_ids, desc='Batch Retrieving') |
| 161 | + queue = Queue(maxsize=max_workers * 3) |
| 162 | + is_stopped = Event() |
| 163 | + |
| 164 | + def _func(resource_id): |
| 165 | + if is_stopped.is_set(): |
| 166 | + return |
| 167 | + |
| 168 | + try: |
| 169 | + try: |
| 170 | + data = self.retrieve_resource_data(resource_id) |
| 171 | + except ResourceNotFoundError: |
| 172 | + logging.warning(f'Resource {resource_id!r} not found.') |
| 173 | + return |
| 174 | + except InvalidResourceDataError as err: |
| 175 | + logging.warning(f'Resource {resource_id!r} is invalid - {err}.') |
| 176 | + return |
| 177 | + finally: |
| 178 | + pg.update() |
| 179 | + |
| 180 | + while True: |
| 181 | + try: |
| 182 | + queue.put(data, block=True, timeout=1.0) |
| 183 | + except Full: |
| 184 | + if is_stopped.is_set(): |
| 185 | + break |
| 186 | + continue |
| 187 | + else: |
| 188 | + break |
| 189 | + |
| 190 | + except Exception as err: |
| 191 | + logging.error(f'Error occurred when retrieving resource {resource_id!r} - {err!r}') |
| 192 | + |
| 193 | +        def _producer(): |
| 194 | + tp = ThreadPoolExecutor(max_workers=max_workers) |
| 195 | + for rid in resource_ids: |
| 196 | + if is_stopped.is_set(): |
| 197 | + break |
| 198 | + tp.submit(_func, rid) |
| 199 | + |
| 200 | + tp.shutdown(wait=True) |
| 201 | + is_stopped.set() |
| 202 | + |
| 203 | +        t_producer = Thread(target=_producer) |
| 204 | +        t_producer.start() |
| 205 | + return queue, is_stopped |
| 206 | + |
| 207 | + |
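| | +# Split an id string into 3-character groups, aligned from the right, |
| | +# e.g. '1234567' -> ['1', '234', '567']. |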
| 208 | +def id_modulo_cut(id_text: str): |
| 209 | + id_text = id_text[::-1] |
| 210 | + data = [] |
| 211 | + for i in range(0, len(id_text), 3): |
| 212 | + data.append(id_text[i:i + 3][::-1]) |
| 213 | + return data[::-1] |
| 214 | + |
| 215 | + |
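| | +# Pool for integer (auto-increment) ids, sharded by the last base_level digits of the id: |
| | +# with the defaults (base_level=3, base_dir='images'), id 1234567 lives in 'images/0567.tar'. |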
| 216 | +class IncrementIDDataPool(GenericDataPool): |
| 217 | + def __init__(self, data_repo_id: str, data_revision: str = 'main', |
| 218 | + idx_repo_id: str = None, idx_revision: str = 'main', |
| 219 | + base_level: int = 3, base_dir: str = 'images'): |
| 220 | + GenericDataPool.__init__(self, data_repo_id, data_revision, idx_repo_id, idx_revision) |
| 221 | + self.base_level = base_level |
| 222 | + self.base_dir = base_dir |
| 223 | + |
| 224 | + def _file_to_resource_id(self, tar_file: str, filename: str): |
| 225 | + try: |
| 226 | + body, _ = os.path.splitext(os.path.basename(filename)) |
| 227 | + return int(body) |
| 228 | + except (ValueError, TypeError): |
| 229 | + raise FileUnrecognizableError |
| 230 | + |
| 231 | + def _request_possible_archives(self, resource_id) -> Iterable[str]: |
| 232 | + modulo = resource_id % (10 ** self.base_level) |
| 233 | +        modulo_str = str(modulo).zfill(self.base_level) |
| 236 | + |
| 237 | + modulo_segments = id_modulo_cut(modulo_str) |
| 238 | + modulo_segments[-1] = f'0{modulo_segments[-1]}' |
| 239 | + return [f'{self.base_dir}/{"/".join(modulo_segments)}.tar'] |
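Usage note: batch_retrieve_resource_queue returns a bounded queue together with a stop Event; producers set the event once every id has been processed, so consumers should poll until the event is set and the queue has drained. A minimal consumer sketch follows, assuming the module above is importable; MyPool, the repo id, the id range, and process are hypothetical stand-ins, since retrieve_resource_data is left abstract here.

    import os
    from queue import Empty

    class MyPool(IncrementIDDataPool):  # hypothetical concrete subclass
        def retrieve_resource_data(self, resource_id):
            # hypothetical: materialize the files and decode them into a data object
            with self._mock_resource(resource_id) as td:
                return resource_id, os.listdir(td)

    def process(item):  # hypothetical consumer logic
        print(item)

    pool = MyPool('user/some-dataset')  # hypothetical repo id
    queue, is_stopped = pool.batch_retrieve_resource_queue(range(100), max_workers=4)
    while True:
        try:
            process(queue.get(block=True, timeout=1.0))
        except Empty:
            if is_stopped.is_set() and queue.empty():
                break  # producers finished and the queue has drained

Setting is_stopped from the consumer side cancels ids that have not started yet, and producers blocked on a full queue notice the event within their one-second put timeout.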