This repository has been archived by the owner on Oct 31, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 15
/
input_pipeline_laion.py
281 lines (230 loc) · 8.23 KB
/
input_pipeline_laion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/google/flax/tree/main/examples/imagenet
import tensorflow as tf
import tensorflow_text as tftx
from absl import logging
import functools
from utils.transform_util import (
decode_and_random_crop,
normalize_image,
)
from utils import logging_util
# Parsing schema for one serialized LAION record; every field is a scalar.
# "jpg" holds the image binary and "key" the image id as a string -- both are
# renamed/converted by parse_laion_example. "txt" is the caption text that gets
# tokenized. "height"/"width" are presumably the stored image dimensions
# (TODO confirm; they are not read anywhere in this module).
feature_description = dict(
    txt=tf.io.FixedLenFeature([], tf.string),
    jpg=tf.io.FixedLenFeature([], tf.string),
    height=tf.io.FixedLenFeature([], tf.int64),
    width=tf.io.FixedLenFeature([], tf.int64),
    key=tf.io.FixedLenFeature([], tf.string),
)
def parse_laion_example(example_proto):
    """Parse one serialized LAION record into a dict of tensors.

    Fields are read according to ``feature_description``. ``jpg`` is renamed
    to ``image``, and the string ``key`` is replaced by an int64 ``image_id``.
    All other fields pass through unchanged.
    """
    parsed = tf.io.parse_single_example(example_proto, feature_description)
    encoded_image = parsed.pop("jpg")
    raw_key = parsed.pop("key")
    parsed["image"] = encoded_image
    # The id as stored in the dataset; ids are not contiguous because some
    # downloads failed.
    parsed["image_id"] = tf.strings.to_number(raw_key, out_type=tf.int64)
    return parsed
def tfds_preprocess_text(txt, tokenizer, cls_token, aug_txt):
    """Tokenize text and pad/truncate it to ``aug_txt.max_len`` tokens.

    Reference:
    https://github.com/google-research/big_vision/blob/main/big_vision/pp/proj/flaxformer/bert_ops.py

    Args:
      txt: string tensor passed to ``tokenizer.tokenize``.
        NOTE(review): only row 0 of the padded result is kept below. This
        presumably assumes one text per call -- TODO confirm.
      tokenizer: object exposing ``tokenize`` (a ``tftx.BertTokenizer`` in this
        module).
      cls_token: vocab id of ``[CLS]`` to prepend, or ``None`` to skip it.
      aug_txt: text config. ``aug_txt.max_len`` is the total output length,
        including the ``[CLS]`` slot when one is prepended.

    Returns:
      ``(padded_token_ids, is_valid)``: 1-D token ids and the matching mask from
      ``tftx.pad_model_inputs``. A prepended ``[CLS]`` position is marked valid.
    """
    token_ids = tokenizer.tokenize(txt)
    # Decide "has [CLS]" once, with an explicit None test. The reserved slot and
    # the prepend below then always agree, even if the [CLS] id happens to be 0.
    has_cls = cls_token is not None
    max_len = aug_txt.max_len - (1 if has_cls else 0)
    padded_token_ids, is_valid = tftx.pad_model_inputs(token_ids, max_len)
    padded_token_ids, is_valid = padded_token_ids[0], is_valid[0]
    if has_cls:
        # Prepend [CLS]. The extra entries take the dtypes of the tensors they
        # are concatenated with, so tf.concat cannot fail on a dtype mismatch.
        cls_ids = tf.constant([cls_token], dtype=padded_token_ids.dtype)
        padded_token_ids = tf.concat([cls_ids, padded_token_ids], axis=0)
        is_valid = tf.concat([tf.ones([1], dtype=is_valid.dtype), is_valid], axis=0)
    return padded_token_ids, is_valid
def get_txt_tokenize_func(aug_txt):
    """Build the text tokenization function selected by the text config.

    Args:
      aug_txt: text config. Reads ``tokenizer`` (only ``"tf_bert"`` is
        supported) and ``cls_token`` (whether to prepend ``[CLS]``). It is also
        forwarded to ``tfds_preprocess_text``, which reads ``max_len``.

    Returns:
      ``(tokenize_func, vocab_size)``:
      - ``tokenize_func`` maps a text tensor to ``(token_ids, is_valid)``.
      - ``vocab_size`` is the wordpiece vocabulary size, including ``[UNK]``.

    Raises:
      NotImplementedError: if ``aug_txt.tokenizer`` is not ``"tf_bert"``.
      ValueError: if ``aug_txt.cls_token`` is set but the vocab has no
        ``[CLS]`` entry.
    """
    if aug_txt.tokenizer != "tf_bert":
        raise NotImplementedError(
            f"Unsupported text tokenizer: {aug_txt.tokenizer!r}"
        )

    # vocab file: gs://vit_models/lit/LiT-B16B.txt. It should be the same as vocab.txt in:
    # https://storage.googleapis.com/bert_models/2019_05_30/wwm_uncased_L-24_H-1024_A-16.zip
    # md5sum: 64800d5d8528ce344256daf115d4965e
    # vocab_size: 30523, (30522+1, including unknown [UNK])
    vocab_file = "./vocab/vocab_bert_base.txt"
    tokenizer = tftx.BertTokenizer(
        vocab_file, lower_case=True, token_out_type=tf.int32
    )

    cls_token = None
    if aug_txt.cls_token:
        # Explicit UTF-8: the BERT vocab has non-ASCII entries, and the default
        # encoding depends on the locale. Text mode already turns "\r\n" into
        # "\n", so the index found here is the [CLS] line number.
        with open(vocab_file, encoding="utf-8") as f:
            vocab = f.read().split("\n")
        try:
            cls_token = vocab.index("[CLS]")
        except ValueError as err:
            raise ValueError(
                f"'[CLS]' token not found in vocab file {vocab_file}"
            ) from err

    tokenize_func = functools.partial(
        tfds_preprocess_text,
        tokenizer=tokenizer,
        cls_token=cls_token,
        aug_txt=aug_txt,
    )
    # NOTE: relies on the private `_wordpiece_tokenizer` attribute of
    # tftx.BertTokenizer; the count includes the unknown token.
    vocab_size = tokenizer._wordpiece_tokenizer.vocab_size().numpy()
    return tokenize_func, vocab_size
def decode_example(example, image_size, aug, tokenize_func):
    """Turn a raw example dict into model inputs.

    Args:
      example: dict with ``"txt"`` and ``"image"`` entries. ``"image"`` may be
        ``None`` for text-only examples.
      image_size: output image size, forwarded to ``preprocess_image``.
      aug: augmentation config, forwarded to ``preprocess_image``.
      tokenize_func: text tokenizer, forwarded to ``preprocess_text``.

    Returns:
      Dict with ``image`` (``None`` when the input had no image), ``txt`` and
      ``txt_is_valid``.
    """
    # Text first: tokenize into ids plus a validity mask.
    txt, txt_is_valid = preprocess_text(example["txt"], tokenize_func=tokenize_func)

    # Then the image, unless this is a text-only example.
    raw_image = example["image"]
    if raw_image is None:
        image = None
    else:
        image = preprocess_image(raw_image, image_size=image_size, aug=aug)

    return dict(image=image, txt=txt, txt_is_valid=txt_is_valid)
def preprocess_text(txt, tokenize_func):
    """Apply ``tokenize_func`` to ``txt`` and return its result unchanged."""
    return tokenize_func(txt)
def preprocess_image(image_bytes, dtype=tf.float32, image_size=None, aug=None):
    """Decode, randomly crop, optionally flip, and normalize one image.

    Args:
      image_bytes: `Tensor` holding an image binary of arbitrary size.
      dtype: dtype of the returned image.
      image_size: side length of the square output image.
      aug: augmentation config. Reads ``crop_ver`` (selects the crop function),
        ``area_range``, ``aspect_ratio_range`` and ``flip``.

    Returns:
      The preprocessed image `Tensor` of shape ``[image_size, image_size, 3]``.
    """
    random_crop = decode_and_random_crop[aug.crop_ver]
    cropped = random_crop(
        image_bytes,
        image_size,
        area_range=aug.area_range,
        aspect_ratio_range=aug.aspect_ratio_range,
    )
    # Pin the static shape so later ops see a fixed [H, W, 3] image.
    image = tf.reshape(cropped, [image_size, image_size, 3])
    image = tf.image.random_flip_left_right(image) if aug.flip else image
    return tf.image.convert_image_dtype(normalize_image(image), dtype=dtype)
def create_split(
    dataset_path,
    batch_size,
    data_layout,
    train,
    image_size=None,
    cache=False,
    seed=0,
    cfg=None,
):
    """Create this host's shard of the LAION training data as a `tf.data.Dataset`.

    The ``*.tfrecord`` files under ``dataset_path`` are sorted and split into
    ``data_layout.num_shards`` equal contiguous chunks; this host reads chunk
    ``data_layout.shard_id``. Trailing files that do not fill a whole chunk are
    read by no shard, and a warning is logged when that happens.

    Args:
      dataset_path: directory containing the ``*.tfrecord`` files.
      batch_size: local (per-host) batch size returned by the data pipeline.
      data_layout: the partitioner data layout; reads ``shard_id`` and
        ``num_shards``.
      train: must be True; only the training split is implemented.
      image_size: side length of the square output images.
      cache: whether to cache the parsed records.
      seed: shuffle seed.
      cfg: config. Reads ``cfg.aug`` (image/text augmentation) and
        ``cfg.model.model_txt.vocab_size``.

    Returns:
      A repeated, shuffled, batched `tf.data.Dataset` of dicts with ``image``,
      ``txt`` and ``txt_is_valid``.

    Raises:
      NotImplementedError: if ``train`` is False.
      ValueError: if there are fewer tfrecord files than shards (including none
        at all), or if the tokenizer vocab size does not match the model config.
    """
    if not train:
        # Only the training split exists; fail before any config access or I/O.
        raise NotImplementedError("create_split only supports the LAION train split.")

    aug = cfg.aug
    shard_id = data_layout.shard_id
    num_shards = data_layout.num_shards
    logging.info(f"laion data path{dataset_path}")
    filenames = sorted(tf.io.gfile.glob(dataset_path + "/*.tfrecord"))
    train_records = len(filenames)
    if train_records < num_shards:
        # Otherwise split_size would be 0 and every shard would silently
        # receive no files at all.
        raise ValueError(
            f"Found {train_records} tfrecord file(s) under {dataset_path!r}, "
            f"fewer than the {num_shards} data shards."
        )
    split_size = train_records // num_shards
    start = shard_id * split_size
    split = "train[{}:{}]".format(start, start + split_size)
    filenames = filenames[start : start + split_size]
    num_unused = train_records - split_size * num_shards
    # ----------------------------------------
    # NOTE(review): logging_util presumably limits/staggers per-host logging
    # -- TODO confirm.
    logging_util.verbose_on()
    logging_util.sync_and_delay()
    logging.info("Split: {} / {}".format(split, train_records))
    if num_unused:
        logging.warning(
            "%d trailing tfrecord file(s) are not assigned to any shard.", num_unused
        )
    logging_util.verbose_off()
    # ----------------------------------------

    ds = tf.data.TFRecordDataset(filenames).map(parse_laion_example)
    # Silently drop elements whose reading/parsing raises, instead of aborting.
    ds = ds.apply(tf.data.experimental.ignore_errors(log_warning=False))
    options = tf.data.Options()
    options.threading.private_threadpool_size = 48
    ds = ds.with_options(options)
    if cache:
        ds = ds.cache()
    # Training split: loop forever and shuffle with a 16-batch buffer.
    ds = ds.repeat()
    ds = ds.shuffle(16 * batch_size, seed=seed)

    # Create the tokenizer and make sure it agrees with the text model.
    tokenize_func, vocab_size = get_txt_tokenize_func(aug.txt)
    expected_vocab_size = cfg.model.model_txt.vocab_size
    if vocab_size != expected_vocab_size:
        raise ValueError(
            f"Tokenizer vocab size ({vocab_size}) does not match "
            f"cfg.model.model_txt.vocab_size ({expected_vocab_size})."
        )
    decode_fn = functools.partial(
        decode_example, image_size=image_size, aug=aug, tokenize_func=tokenize_func
    )
    ds = ds.map(decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.prefetch(10)
    return ds
def create_tags_split(
    tags,
    batch_size,
    train,
    image_size=None,
    cache=False,
    seed=0,
    cfg=None,
):
    """Create a batched `tf.data.Dataset` of tokenized, text-only tag examples.

    Each element of ``tags`` becomes an example with ``image=None``, so
    ``decode_example`` skips image decoding and only tokenizes the text.

    Args:
      tags: the tag strings, in any form accepted by
        ``tf.data.Dataset.from_tensor_slices`` (presumably a list/array of
        strings -- TODO confirm with callers).
      batch_size: local (per-host) batch size returned by the data pipeline.
      train: if True, the dataset is repeated forever and shuffled; otherwise
        it is a single unshuffled pass.
      image_size: forwarded to ``decode_example``; unused here because every
        example is text-only.
      cache: whether to cache the examples.
      seed: shuffle seed.
      cfg: config. Reads ``cfg.aug`` (``cfg.aug.txt`` for the tokenizer) and
        ``cfg.model.model_txt.vocab_size``.

    Returns:
      A batched `tf.data.Dataset` of dicts with ``image`` (None), ``txt`` and
      ``txt_is_valid``. The last partial batch is dropped.

    Raises:
      ValueError: if the tokenizer vocab size does not match the model config.
    """
    aug = cfg.aug
    ds = tf.data.Dataset.from_tensor_slices(tags)
    ds = ds.map(lambda x: {"txt": x, "image": None})
    logging.info("Creating dataset from tags.")
    options = tf.data.Options()
    options.threading.private_threadpool_size = 48
    ds = ds.with_options(options)
    if cache:
        ds = ds.cache()
    if train:
        ds = ds.repeat()
        ds = ds.shuffle(16 * batch_size, seed=seed)

    # Create the tokenizer and make sure it agrees with the text model.
    tokenize_func, vocab_size = get_txt_tokenize_func(aug.txt)
    expected_vocab_size = cfg.model.model_txt.vocab_size
    if vocab_size != expected_vocab_size:
        raise ValueError(
            f"Tokenizer vocab size ({vocab_size}) does not match "
            f"cfg.model.model_txt.vocab_size ({expected_vocab_size})."
        )
    decode_fn = functools.partial(
        decode_example, image_size=image_size, aug=aug, tokenize_func=tokenize_func
    )
    ds = ds.map(decode_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds