import random
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage as scp

from bin.OCR_dataloader_iam import Batch


class Preprocessor:
    def __init__(self,
                 img_size: Tuple[int, int],
                 padding: int = 0,
                 dynamic_width: bool = False,
                 data_augmentation: bool = False,
                 line_mode: bool = False) -> None:
        # dynamic width is only supported when data augmentation is off
        assert not (dynamic_width and data_augmentation)
        # when padding is on, dynamic width must be enabled
        assert not (padding > 0 and not dynamic_width)

        self.img_size = img_size
        self.padding = padding
        self.dynamic_width = dynamic_width
        self.data_augmentation = data_augmentation
        self.line_mode = line_mode

    @staticmethod
    def _truncate_label(text: str, max_text_len: int) -> str:
        """Truncate a label so that ctc_loss can process it.

        ctc_loss can't compute a loss if it cannot find a mapping between the
        text label and the input labels. Repeated letters cost double because
        a blank symbol must be inserted between them. If a label that is too
        long is provided, ctc_loss returns an infinite gradient.
        """
        cost = 0
        for i in range(len(text)):
            if i != 0 and text[i] == text[i - 1]:
                cost += 2
            else:
                cost += 1
            if cost > max_text_len:
                return text[:i]
        return text

    def _simulate_text_line(self, batch: Batch) -> Batch:
        """Create an image of a text line by pasting multiple word images into one image."""
        default_word_sep = 30
        default_num_words = 5

        # go over all batch elements
        res_imgs = []
        res_gt_texts = []
        for i in range(batch.batch_size):
            # number of words to put into the current line
            num_words = random.randint(1, 8) if self.data_augmentation else default_num_words

            # concatenate ground-truth texts
            curr_gt = ' '.join([batch.gt_texts[(i + j) % batch.batch_size] for j in range(num_words)])
            res_gt_texts.append(curr_gt)

            # put selected word images into a list, compute the target image size
            sel_imgs = []
            word_seps = [0]
            h = 0
            w = 0
            for j in range(num_words):
                curr_sel_img = batch.imgs[(i + j) % batch.batch_size]
                curr_word_sep = random.randint(20, 50) if self.data_augmentation else default_word_sep
                h = max(h, curr_sel_img.shape[0])
                w += curr_sel_img.shape[1]
                sel_imgs.append(curr_sel_img)
                if j + 1 < num_words:
                    w += curr_word_sep
                    word_seps.append(curr_word_sep)

            # paste all selected word images into the target image
            target = np.ones([h, w], np.uint8) * 255
            x = 0
            for curr_sel_img, curr_word_sep in zip(sel_imgs, word_seps):
                x += curr_word_sep
                y = (h - curr_sel_img.shape[0]) // 2
                target[y:y + curr_sel_img.shape[0], x:x + curr_sel_img.shape[1]] = curr_sel_img
                x += curr_sel_img.shape[1]

            # put the image of the line into the result
            res_imgs.append(target)

        return Batch(res_imgs, res_gt_texts, batch.batch_size)

    def process_img(self, img: np.ndarray) -> np.ndarray:
        """Resize to target size, apply data augmentation."""

        # there are damaged files in the IAM dataset - just use a black image instead
        if img is None:
            img = np.zeros(self.img_size[::-1])

        # data augmentation
        img = img.astype(float)
        if self.data_augmentation:
            # photometric data augmentation
            if random.random() < 0.25:
                def rand_odd():
                    return random.randint(1, 3) * 2 + 1
                # was: cv2.GaussianBlur(img, (rand_odd(), rand_odd()), 0)
                img = scp.gaussian_filter(img, (rand_odd(), rand_odd()))
            if random.random() < 0.25:
                img = scp.grey_dilation(img, structure=np.ones((3, 3)))
            if random.random() < 0.25:
                img = scp.grey_erosion(img, structure=np.ones((3, 3)))

            # geometric data augmentation
            wt, ht = self.img_size
            h, w = img.shape
            f = min(wt / w, ht / h)
            fx = f * np.random.uniform(0.75, 1.05)
            fy = f * np.random.uniform(0.75, 1.05)

            # random position around center
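            # (txc, tyc) would center the scaled image exactly; freedom_x and
            # freedom_y bound the random shift so the scaled image stays inside
            # the target canvas (e.g. for wt=256, w=200, fx=1.0 the horizontal
            # shift is at most +/-28 px)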
            txc = (wt - w * fx) / 2
            tyc = (ht - h * fy) / 2
            freedom_x = max((wt - fx * w) / 2, 0)
            freedom_y = max((ht - fy * h) / 2, 0)
            tx = txc + np.random.uniform(-freedom_x, freedom_x)
            ty = tyc + np.random.uniform(-freedom_y, freedom_y)

            # map image into target image: M maps input (row, col) coordinates
            # to output coordinates; scipy.ndimage.affine_transform expects the
            # inverse (output -> input) mapping, hence the inversion
            M = np.float32([[fy, 0, ty], [0, fx, tx], [0, 0, 1]])
            M = np.linalg.inv(M)
            # was: cv2.warpAffine(img, M, dsize=self.img_size, dst=target, borderMode=cv2.BORDER_TRANSPARENT);
            # constant white fill stands in for the prefilled white dst of the cv2 version
            img = scp.affine_transform(img, M, output_shape=self.img_size[::-1], mode="constant", cval=255)

            # photometric data augmentation
            if random.random() < 0.5:
                img = img * (0.25 + random.random() * 0.75)
            if random.random() < 0.25:
                img = np.clip(img + (np.random.random(img.shape) - 0.5) * random.randint(1, 25), 0, 255)
            if random.random() < 0.1:
                img = 255 - img

        # no data augmentation
        else:
            if self.dynamic_width:
                ht = self.img_size[1]
                h, w = img.shape
                f = ht / h
                wt = int(f * w + self.padding)
                wt = wt + (4 - wt) % 4  # round width up to a multiple of 4
                tx = (wt - w * f) / 2
                ty = 0
            else:
                wt, ht = self.img_size
                h, w = img.shape
                f = min(wt / w, ht / h)
                tx = (wt - w * f) / 2
                ty = (ht - h * f) / 2

            # map image into target image (same inverse-mapping convention as above)
            M = np.float32([[f, 0, ty], [0, f, tx], [0, 0, 1]])
            M = np.linalg.inv(M)
            # was: cv2.warpAffine(img, M, dsize=(wt, ht), dst=target, borderMode=cv2.BORDER_TRANSPARENT)
            img = scp.affine_transform(img, M, output_shape=(ht, wt), mode="constant", cval=255)

        # transpose for TF
        img = np.transpose(img)

        # convert to range [-0.5, 0.5]
        img = img / 255 - 0.5
        return img

    def process_batch(self, batch: Batch) -> Batch:
        if self.line_mode:
            batch = self._simulate_text_line(batch)
        res_imgs = [self.process_img(img) for img in batch.imgs]
        # images are transposed, so shape[0] is the width
        max_text_len = res_imgs[0].shape[0] // 4
        res_gt_texts = [self._truncate_label(gt_text, max_text_len) for gt_text in batch.gt_texts]
        return Batch(res_imgs, res_gt_texts, batch.batch_size)


def main():
    img = plt.imread("../data/test.png")
    # average the RGB channels into a grayscale image in [0, 255]
    img = (img[:, :, 0] + img[:, :, 1] + img[:, :, 2]) * 255 // 3
    img_aug = Preprocessor((256, 32), data_augmentation=True).process_img(img)
    plt.subplot(121)
    plt.imshow(img, cmap='gray')
    plt.subplot(122)
    plt.imshow(np.transpose(img_aug) + 0.5, cmap='gray', vmin=0, vmax=1)
    plt.show()


if __name__ == '__main__':
    main()
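
# --- Usage sketch (illustrative addition, not part of the original pipeline).
# A minimal example of batch preprocessing in line mode, assuming Batch is the
# simple (imgs, gt_texts, batch_size) container used above; the synthetic word
# images and texts are made up for the demo. Call demo_line_mode() instead of
# main() to try it.
def demo_line_mode():
    rng = np.random.default_rng(42)
    # three fake grayscale "word" images of varying width on a light background
    imgs = [255 - rng.integers(0, 100, size=(32, wdt)).astype(np.uint8)
            for wdt in (60, 90, 120)]
    batch = Batch(imgs, ['to', 'be', 'or'], len(imgs))

    pre = Preprocessor((256, 32), line_mode=True)
    processed = pre.process_batch(batch)

    # processed images are transposed to (width, height) and scaled to
    # [-0.5, 0.5]; the labels are concatenated line texts, truncated so that
    # ctc_loss can map them onto the input sequence
    for im, gt in zip(processed.imgs, processed.gt_texts):
        print(im.shape, im.min(), im.max(), repr(gt))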