|
| 1 | +from PIL import Image, ImageEnhance, ImageOps |
| 2 | +import numpy as np |
| 3 | +import random |
| 4 | + |
| 5 | + |
| 6 | +class ImageNetPolicy(object): |
| 7 | + """ Randomly choose one of the best 24 Sub-policies on ImageNet. |
| 8 | +
|
| 9 | + Example: |
| 10 | + >>> policy = ImageNetPolicy() |
| 11 | + >>> transformed = policy(image) |
| 12 | +
|
| 13 | + Example as a PyTorch Transform: |
| 14 | + >>> transform=transforms.Compose([ |
| 15 | + >>> transforms.Resize(256), |
| 16 | + >>> ImageNetPolicy(), |
| 17 | + >>> transforms.ToTensor()]) |
| 18 | + """ |
| 19 | + def __init__(self, fillcolor=(128, 128, 128)): |
| 20 | + self.policies = [ |
| 21 | + SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), |
| 22 | + SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), |
| 23 | + SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), |
| 24 | + SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), |
| 25 | + SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), |
| 26 | + |
| 27 | + SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), |
| 28 | + SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), |
| 29 | + SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), |
| 30 | + SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), |
| 31 | + SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), |
| 32 | + |
| 33 | + SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), |
| 34 | + SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), |
| 35 | + SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), |
| 36 | + SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), |
| 37 | + SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), |
| 38 | + |
| 39 | + SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), |
| 40 | + SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), |
| 41 | + SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), |
| 42 | + SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), |
| 43 | + SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), |
| 44 | + |
| 45 | + SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), |
| 46 | + SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), |
| 47 | + SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), |
| 48 | + SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), |
| 49 | + SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor) |
| 50 | + ] |
| 51 | + |
| 52 | + |
| 53 | + def __call__(self, img): |
| 54 | + policy_idx = random.randint(0, len(self.policies) - 1) |
| 55 | + return self.policies[policy_idx](img) |
| 56 | + |
| 57 | + def __repr__(self): |
| 58 | + return "AutoAugment ImageNet Policy" |
| 59 | + |
| 60 | + |
| 61 | +class CIFAR10Policy(object): |
| 62 | + """ Randomly choose one of the best 25 Sub-policies on CIFAR10. |
| 63 | + """ |
| 64 | + def __init__(self, fillcolor=(128, 128, 128)): |
| 65 | + self.policies = [ |
| 66 | + SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), |
| 67 | + SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), |
| 68 | + SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), |
| 69 | + SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), |
| 70 | + SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), |
| 71 | + |
| 72 | + SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), |
| 73 | + SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), |
| 74 | + SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), |
| 75 | + SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), |
| 76 | + SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), |
| 77 | + |
| 78 | + SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), |
| 79 | + SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), |
| 80 | + SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), |
| 81 | + SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), |
| 82 | + SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), |
| 83 | + |
| 84 | + SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), |
| 85 | + SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor), |
| 86 | + SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), |
| 87 | + SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), |
| 88 | + SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), |
| 89 | + |
| 90 | + SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), |
| 91 | + SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), |
| 92 | + SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), |
| 93 | + SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), |
| 94 | + SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) |
| 95 | + ] |
| 96 | + |
| 97 | + |
| 98 | + def __call__(self, img): |
| 99 | + policy_idx = random.randint(0, len(self.policies) - 1) |
| 100 | + return self.policies[policy_idx](img) |
| 101 | + |
| 102 | + def __repr__(self): |
| 103 | + return "AutoAugment CIFAR10 Policy" |
| 104 | + |
| 105 | + |
| 106 | +class SVHNPolicy(object): |
| 107 | + """ Randomly choose one of the best 25 Sub-policies on SVHN. |
| 108 | +
|
| 109 | + Example: |
| 110 | + >>> policy = SVHNPolicy() |
| 111 | + >>> transformed = policy(image) |
| 112 | +
|
| 113 | + Example as a PyTorch Transform: |
| 114 | + >>> transform=transforms.Compose([ |
| 115 | + >>> transforms.Resize(256), |
| 116 | + >>> SVHNPolicy(), |
| 117 | + >>> transforms.ToTensor()]) |
| 118 | + """ |
| 119 | + def __init__(self, fillcolor=(128, 128, 128)): |
| 120 | + self.policies = [ |
| 121 | + SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), |
| 122 | + SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), |
| 123 | + SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), |
| 124 | + SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), |
| 125 | + SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), |
| 126 | + |
| 127 | + SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), |
| 128 | + SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), |
| 129 | + SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), |
| 130 | + SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), |
| 131 | + SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), |
| 132 | + |
| 133 | + SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), |
| 134 | + SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), |
| 135 | + SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), |
| 136 | + SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), |
| 137 | + SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), |
| 138 | + |
| 139 | + SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), |
| 140 | + SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), |
| 141 | + SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), |
| 142 | + SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), |
| 143 | + SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), |
| 144 | + |
| 145 | + SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), |
| 146 | + SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), |
| 147 | + SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), |
| 148 | + SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), |
| 149 | + SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) |
| 150 | + ] |
| 151 | + |
| 152 | + |
| 153 | + def __call__(self, img): |
| 154 | + policy_idx = random.randint(0, len(self.policies) - 1) |
| 155 | + return self.policies[policy_idx](img) |
| 156 | + |
| 157 | + def __repr__(self): |
| 158 | + return "AutoAugment SVHN Policy" |
| 159 | + |
| 160 | + |
| 161 | +class SubPolicy(object): |
| 162 | + def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): |
| 163 | + ranges = { |
| 164 | + "shearX": np.linspace(0, 0.3, 10), |
| 165 | + "shearY": np.linspace(0, 0.3, 10), |
| 166 | + "translateX": np.linspace(0, 150 / 331, 10), |
| 167 | + "translateY": np.linspace(0, 150 / 331, 10), |
| 168 | + "rotate": np.linspace(0, 30, 10), |
| 169 | + "color": np.linspace(0.0, 0.9, 10), |
| 170 | + "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), |
| 171 | + "solarize": np.linspace(256, 0, 10), |
| 172 | + "contrast": np.linspace(0.0, 0.9, 10), |
| 173 | + "sharpness": np.linspace(0.0, 0.9, 10), |
| 174 | + "brightness": np.linspace(0.0, 0.9, 10), |
| 175 | + "autocontrast": [0] * 10, |
| 176 | + "equalize": [0] * 10, |
| 177 | + "invert": [0] * 10 |
| 178 | + } |
| 179 | + |
| 180 | + # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand |
| 181 | + def rotate_with_fill(img, magnitude): |
| 182 | + rot = img.convert("RGBA").rotate(magnitude) |
| 183 | + return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) |
| 184 | + |
| 185 | + func = { |
| 186 | + "shearX": lambda img, magnitude: img.transform( |
| 187 | + img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), |
| 188 | + Image.BICUBIC, fillcolor=fillcolor), |
| 189 | + "shearY": lambda img, magnitude: img.transform( |
| 190 | + img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), |
| 191 | + Image.BICUBIC, fillcolor=fillcolor), |
| 192 | + "translateX": lambda img, magnitude: img.transform( |
| 193 | + img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), |
| 194 | + fillcolor=fillcolor), |
| 195 | + "translateY": lambda img, magnitude: img.transform( |
| 196 | + img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), |
| 197 | + fillcolor=fillcolor), |
| 198 | + "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), |
| 199 | + # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), |
| 200 | + "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), |
| 201 | + "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), |
| 202 | + "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), |
| 203 | + "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( |
| 204 | + 1 + magnitude * random.choice([-1, 1])), |
| 205 | + "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( |
| 206 | + 1 + magnitude * random.choice([-1, 1])), |
| 207 | + "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( |
| 208 | + 1 + magnitude * random.choice([-1, 1])), |
| 209 | + "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), |
| 210 | + "equalize": lambda img, magnitude: ImageOps.equalize(img), |
| 211 | + "invert": lambda img, magnitude: ImageOps.invert(img) |
| 212 | + } |
| 213 | + |
| 214 | + # self.name = "{}_{:.2f}_and_{}_{:.2f}".format( |
| 215 | + # operation1, ranges[operation1][magnitude_idx1], |
| 216 | + # operation2, ranges[operation2][magnitude_idx2]) |
| 217 | + self.p1 = p1 |
| 218 | + self.operation1 = func[operation1] |
| 219 | + self.magnitude1 = ranges[operation1][magnitude_idx1] |
| 220 | + self.p2 = p2 |
| 221 | + self.operation2 = func[operation2] |
| 222 | + self.magnitude2 = ranges[operation2][magnitude_idx2] |
| 223 | + |
| 224 | + |
| 225 | + def __call__(self, img): |
| 226 | + if random.random() < self.p1: img = self.operation1(img, self.magnitude1) |
| 227 | + if random.random() < self.p2: img = self.operation2(img, self.magnitude2) |
| 228 | + return img |
0 commit comments