Skip to content

Commit

Permalink
Add crop with different aspect ratio in dataset preparing
Browse files Browse the repository at this point in the history
  • Loading branch information
Anghellia committed Sep 16, 2024
1 parent 007b6ec commit ac4e04b
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
32 changes: 31 additions & 1 deletion image_datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,49 @@ def c_crop(image):
bottom = (height + new_size) / 2
return image.crop((left, top, right, bottom))

def crop_to_aspect_ratio(image, ratio="16:9"):
    """Center-crop *image* to the aspect ratio described by *ratio*.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``crop(box)`` method taking ``(left, upper, right, lower)``.
        ratio: Target aspect ratio as a ``"W:H"`` string of positive
            integers, e.g. ``"16:9"``, ``"4:3"``, ``"1:1"`` or ``"3:2"``.

    Returns:
        Whatever ``image.crop`` returns for the centered crop box. If the
        image already has the requested ratio the box covers the full image.

    Raises:
        KeyError: If *ratio* is not a ``"W:H"`` string of positive integers.
            ``KeyError`` is kept (rather than ``ValueError``) so callers
            written against the earlier fixed lookup table keep working.
    """
    try:
        w_text, h_text = ratio.split(":")
        target_w, target_h = int(w_text), int(h_text)
    except (AttributeError, ValueError):
        raise KeyError(ratio) from None
    if target_w <= 0 or target_h <= 0:
        raise KeyError(ratio)

    width, height = image.size

    # Compare width/height against target_w/target_h by cross-multiplying:
    # pure integer math, so there is no float rounding to worry about.
    if width * target_h > height * target_w:
        # Image is too wide: keep the full height and trim left/right.
        new_width = height * target_w // target_h
        offset = (width - new_width) // 2
        crop_box = (offset, 0, offset + new_width, height)
    else:
        # Image is too tall (or already matches): keep the full width and
        # trim top/bottom.
        new_height = width * target_h // target_w
        offset = (height - new_height) // 2
        crop_box = (0, offset, width, offset + new_height)

    return image.crop(crop_box)


class CustomImageDataset(Dataset):
def __init__(self, img_dir, img_size=512, caption_type='json'):
def __init__(self, img_dir, img_size=512, caption_type='json', random_ratio=False):
self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i]
self.images.sort()
self.img_size = img_size
self.caption_type = caption_type
self.random_ratio = random_ratio

def __len__(self):
return len(self.images)

def __getitem__(self, idx):
try:
img = Image.open(self.images[idx]).convert('RGB')
if self.random_ratio:
ratio = random.choice(["16:9", "default", "1:1", "4:3"])
if ratio != "default":
img = crop_to_aspect_ratio(img, ratio)
img = image_resize(img, self.img_size)
w, h = img.size
new_w = (w // 32) * 32
Expand Down
1 change: 1 addition & 0 deletions train_configs/test_lora.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ data_config:
num_workers: 4
img_size: 512
img_dir: images/
random_ratio: true # randomly center-crop each image to 16:9, 4:3 or 1:1 (or leave it uncropped) before resizing
report_to: wandb
train_batch_size: 1
output_dir: lora/
Expand Down

0 comments on commit ac4e04b

Please sign in to comment.