You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long. 7.5KB

  1. import random
  2. from typing import Tuple
  3. import matplotlib.pyplot as plt
  4. import numpy as np
  5. import scipy.ndimage as scp
  6. from bin.OCR_dataloader_iam import Batch
  7. class Preprocessor:
  8. def __init__(self,
  9. img_size: Tuple[int, int],
  10. padding: int = 0,
  11. dynamic_width: bool = False,
  12. data_augmentation: bool = False,
  13. line_mode: bool = False) -> None:
  14. # dynamic width only supported when no data augmentation happens
  15. assert not (dynamic_width and data_augmentation)
  16. # when padding is on, we need dynamic width enabled
  17. assert not (padding > 0 and not dynamic_width)
  18. self.img_size = img_size
  19. self.padding = padding
  20. self.dynamic_width = dynamic_width
  21. self.data_augmentation = data_augmentation
  22. self.line_mode = line_mode
  23. @staticmethod
  24. def _truncate_label(text: str, max_text_len: int) -> str:
  25. """
  26. Function ctc_loss can't compute loss if it cannot find a mapping between text label and input
  27. labels. Repeat letters cost double because of the blank symbol needing to be inserted.
  28. If a too-long label is provided, ctc_loss returns an infinite gradient.
  29. """
  30. cost = 0
  31. for i in range(len(text)):
  32. if i != 0 and text[i] == text[i - 1]:
  33. cost += 2
  34. else:
  35. cost += 1
  36. if cost > max_text_len:
  37. return text[:i]
  38. return text
  39. def _simulate_text_line(self, batch: Batch) -> Batch:
  40. """Create image of a text line by pasting multiple word images into an image."""
  41. default_word_sep = 30
  42. default_num_words = 5
  43. # go over all batch elements
  44. res_imgs = []
  45. res_gt_texts = []
  46. for i in range(batch.batch_size):
  47. # number of words to put into current line
  48. num_words = random.randint(1, 8) if self.data_augmentation else default_num_words
  49. # concat ground truth texts
  50. curr_gt = ' '.join([batch.gt_texts[(i + j) % batch.batch_size] for j in range(num_words)])
  51. res_gt_texts.append(curr_gt)
  52. # put selected word images into list, compute target image size
  53. sel_imgs = []
  54. word_seps = [0]
  55. h = 0
  56. w = 0
  57. for j in range(num_words):
  58. curr_sel_img = batch.imgs[(i + j) % batch.batch_size]
  59. curr_word_sep = random.randint(20, 50) if self.data_augmentation else default_word_sep
  60. h = max(h, curr_sel_img.shape[0])
  61. w += curr_sel_img.shape[1]
  62. sel_imgs.append(curr_sel_img)
  63. if j + 1 < num_words:
  64. w += curr_word_sep
  65. word_seps.append(curr_word_sep)
  66. # put all selected word images into target image
  67. target = np.ones([h, w], np.uint8) * 255
  68. x = 0
  69. for curr_sel_img, curr_word_sep in zip(sel_imgs, word_seps):
  70. x += curr_word_sep
  71. y = (h - curr_sel_img.shape[0]) // 2
  72. target[y:y + curr_sel_img.shape[0]:, x:x + curr_sel_img.shape[1]] = curr_sel_img
  73. x += curr_sel_img.shape[1]
  74. # put image of line into result
  75. res_imgs.append(target)
  76. return Batch(res_imgs, res_gt_texts, batch.batch_size)
  77. def process_img(self, img: np.ndarray) -> np.ndarray:
  78. """Resize to target size, apply data augmentation."""
  79. # there are damaged files in IAM dataset - just use black image instead
  80. if img is None:
  81. img = np.zeros(self.img_size[::-1])
  82. # data augmentation
  83. img = img.astype(np.float)
  84. if self.data_augmentation:
  85. # photometric data augmentation
  86. if random.random() < 0.25:
  87. def rand_odd():
  88. return random.randint(1, 3) * 2 + 1
  89. img = scp.gaussian_filter(img, (rand_odd(), rand_odd()), 0) #cv2.GaussianBlur(img, (rand_odd(), rand_odd()), 0)
  90. if random.random() < 0.25:
  91. img = scp.grey_dilation(img, structure=np.ones((3, 3)))
  92. if random.random() < 0.25:
  93. img = scp.grey_erosion(img, structure=np.ones((3, 3)))
  94. # geometric data augmentation
  95. wt, ht = self.img_size[0], self.img_size[1]
  96. h, w = img.shape[0], img.shape[1]
  97. f = min(wt / w, ht / h)
  98. fx = f * np.random.uniform(0.75, 1.05)
  99. fy = f * np.random.uniform(0.75, 1.05)
  100. # random position around center
  101. txc = (wt - w * fx) / 2
  102. tyc = (ht - h * fy) / 2
  103. freedom_x = max((wt - fx * w) / 2, 0)
  104. freedom_y = max((ht - fy * h) / 2, 0)
  105. tx = txc + np.random.uniform(-freedom_x, freedom_x)
  106. ty = tyc + np.random.uniform(-freedom_y, freedom_y)
  107. # map image into target image
  108. M = np.float32([[fy, 0, ty], [0, fx, tx], [0,0,1]])
  109. M = np.linalg.inv(M)
  110. target = np.ones(self.img_size[::-1]) * 255
  111. img = scp.affine_transform(img, M, output_shape=self.img_size, output=target, mode="nearest")#, borderMode=cv2.BORDER_TRANSPARENT) #cv2.warpAffine(img, M, dsize=self.img_size, dst=target, borderMode=cv2.BORDER_TRANSPARENT)
  112. # photometric data augmentation
  113. if random.random() < 0.5:
  114. img = img * (0.25 + random.random() * 0.75)
  115. if random.random() < 0.25:
  116. img = np.clip(img + (np.random.random(img.shape) - 0.5) * random.randint(1, 25), 0, 255)
  117. if random.random() < 0.1:
  118. img = 255 - img
  119. # no data augmentation
  120. else:
  121. if self.dynamic_width:
  122. ht = self.img_size[1]
  123. h, w = img.shape
  124. f = ht / h
  125. wt = int(f * w + self.padding)
  126. wt = wt + (4 - wt) % 4
  127. tx = (wt - w * f) / 2
  128. ty = 0
  129. else:
  130. wt, ht = self.img_size[0], self.img_size[1]
  131. h, w = img.shape
  132. f = min(wt / w, ht / h)
  133. tx = (wt - w * f) / 2
  134. ty = (ht - h * f) / 2
  135. # map image into target image
  136. M = np.float32([[f, 0, ty], [0, f, tx], [0,0,1]])
  137. M = np.linalg.inv(M)
  138. target = np.ones([ht, wt]) * 255
  139. img = scp.affine_transform(img, M, output_shape=(ht,wt), output=target, mode="nearest")#, borderMode=cv2.BORDER_TRANSPARENT) #cv2.warpAffine(img, M, dsize=(wt, ht), dst=target, borderMode=cv2.BORDER_TRANSPARENT)
  140. # transpose for TF
  141. img = np.transpose(img)
  142. # convert to range [-1, 1]
  143. img = img / 255 - 0.5
  144. return img
  145. def process_batch(self, batch: Batch) -> Batch:
  146. if self.line_mode:
  147. batch = self._simulate_text_line(batch)
  148. res_imgs = [self.process_img(img) for img in batch.imgs]
  149. max_text_len = res_imgs[0].shape[0] // 4
  150. res_gt_texts = [self._truncate_label(gt_text, max_text_len) for gt_text in batch.gt_texts]
  151. return Batch(res_imgs, res_gt_texts, batch.batch_size)
  152. def main():
  153. img = plt.imread("../data/test.png")
  154. img = img = (img[:,:,0]+img[:,:,1]+img[:,:,2])*255//3
  155. img_aug = Preprocessor((256, 32), data_augmentation=True).process_img(img)
  156. plt.subplot(121)
  157. plt.imshow(img, cmap='gray')
  158. plt.subplot(122)
  159. plt.imshow(np.transpose(img_aug) + 0.5, cmap='gray', vmin=0, vmax=1)
  161. if __name__ == '__main__':
  162. main()