
OCR_preprocessor.py 7.5KB

import random
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage as scp

from bin.OCR_dataloader_iam import Batch


class Preprocessor:
    def __init__(self,
                 img_size: Tuple[int, int],
                 padding: int = 0,
                 dynamic_width: bool = False,
                 data_augmentation: bool = False,
                 line_mode: bool = False) -> None:
        # dynamic width is only supported when no data augmentation happens
        assert not (dynamic_width and data_augmentation)
        # when padding is on, dynamic width must be enabled
        assert not (padding > 0 and not dynamic_width)

        self.img_size = img_size
        self.padding = padding
        self.dynamic_width = dynamic_width
        self.data_augmentation = data_augmentation
        self.line_mode = line_mode

    @staticmethod
    def _truncate_label(text: str, max_text_len: int) -> str:
        """
        The ctc_loss function can't compute a loss if it cannot find a mapping between the text
        label and the input labels. Repeated letters cost double because a blank symbol has to be
        inserted between them. If a label that is too long is provided, ctc_loss returns an
        infinite gradient.
        """
        cost = 0
        for i in range(len(text)):
            if i != 0 and text[i] == text[i - 1]:
                cost += 2
            else:
                cost += 1
            if cost > max_text_len:
                return text[:i]
        return text
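
    # Illustrative example: with max_text_len = 4, _truncate_label("hello", 4) accumulates
    # cost 1 + 1 + 1 + 2 = 5 at the repeated 'l' (a CTC blank has to be inserted between
    # repeated characters), exceeds the limit, and returns "hel".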

    def _simulate_text_line(self, batch: Batch) -> Batch:
        """Create an image of a text line by pasting multiple word images into one image."""
        default_word_sep = 30
        default_num_words = 5

        # go over all batch elements
        res_imgs = []
        res_gt_texts = []
        for i in range(batch.batch_size):
            # number of words to put into the current line
            num_words = random.randint(1, 8) if self.data_augmentation else default_num_words

            # concatenate ground-truth texts
            curr_gt = ' '.join([batch.gt_texts[(i + j) % batch.batch_size] for j in range(num_words)])
            res_gt_texts.append(curr_gt)

            # put selected word images into a list, compute target image size
            sel_imgs = []
            word_seps = [0]
            h = 0
            w = 0
            for j in range(num_words):
                curr_sel_img = batch.imgs[(i + j) % batch.batch_size]
                curr_word_sep = random.randint(20, 50) if self.data_augmentation else default_word_sep
                h = max(h, curr_sel_img.shape[0])
                w += curr_sel_img.shape[1]
                sel_imgs.append(curr_sel_img)
                if j + 1 < num_words:
                    w += curr_word_sep
                    word_seps.append(curr_word_sep)

            # put all selected word images into the target image
            target = np.ones([h, w], np.uint8) * 255
            x = 0
            for curr_sel_img, curr_word_sep in zip(sel_imgs, word_seps):
                x += curr_word_sep
                y = (h - curr_sel_img.shape[0]) // 2
                target[y:y + curr_sel_img.shape[0], x:x + curr_sel_img.shape[1]] = curr_sel_img
                x += curr_sel_img.shape[1]

            # put the image of the line into the result
            res_imgs.append(target)

        return Batch(res_imgs, res_gt_texts, batch.batch_size)

    def process_img(self, img: np.ndarray) -> np.ndarray:
        """Resize the image to the target size and optionally apply data augmentation."""
        # there are damaged files in the IAM dataset - just use a black image instead
        if img is None:
            img = np.zeros(self.img_size[::-1])

        img = img.astype(np.float64)
        if self.data_augmentation:
            # photometric data augmentation
            if random.random() < 0.25:
                def rand_odd():
                    return random.randint(1, 3) * 2 + 1
                img = scp.gaussian_filter(img, (rand_odd(), rand_odd()))
            if random.random() < 0.25:
                img = scp.grey_dilation(img, structure=np.ones((3, 3)))
            if random.random() < 0.25:
                img = scp.grey_erosion(img, structure=np.ones((3, 3)))

            # geometric data augmentation
            wt, ht = self.img_size[0], self.img_size[1]
            h, w = img.shape[0], img.shape[1]
            f = min(wt / w, ht / h)
            fx = f * np.random.uniform(0.75, 1.05)
            fy = f * np.random.uniform(0.75, 1.05)

            # random position around the center
            txc = (wt - w * fx) / 2
            tyc = (ht - h * fy) / 2
            freedom_x = max((wt - fx * w) / 2, 0)
            freedom_y = max((ht - fy * h) / 2, 0)
            tx = txc + np.random.uniform(-freedom_x, freedom_x)
            ty = tyc + np.random.uniform(-freedom_y, freedom_y)

            # map the image into the target image; scipy's affine_transform expects the matrix
            # that maps output (row, col) coordinates back to input coordinates, hence the
            # inversion of the forward transform
            M = np.float32([[fy, 0, ty], [0, fx, tx], [0, 0, 1]])
            M = np.linalg.inv(M)
            target = np.ones(self.img_size[::-1]) * 255
            # fill regions outside the source image with white (255) to match the white background
            img = scp.affine_transform(img, M, output_shape=self.img_size[::-1],
                                       output=target, mode="constant", cval=255)

            # photometric data augmentation
            if random.random() < 0.5:
                img = img * (0.25 + random.random() * 0.75)
            if random.random() < 0.25:
                img = np.clip(img + (np.random.random(img.shape) - 0.5) * random.randint(1, 25), 0, 255)
            if random.random() < 0.1:
                img = 255 - img

        # no data augmentation
        else:
            if self.dynamic_width:
                ht = self.img_size[1]
                h, w = img.shape
                f = ht / h
                wt = int(f * w + self.padding)
                wt = wt + (4 - wt) % 4
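                # Worked example (illustrative): with ht = 32, padding = 0 and an input of
                # shape (50, 253), f = 32 / 50 = 0.64 and int(0.64 * 253) = 161; the
                # "(4 - wt) % 4" term then rounds wt up to 164, the next multiple of 4.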
                tx = (wt - w * f) / 2
                ty = 0
            else:
                wt, ht = self.img_size[0], self.img_size[1]
                h, w = img.shape
                f = min(wt / w, ht / h)
                tx = (wt - w * f) / 2
                ty = (ht - h * f) / 2

            # map the image into the target image
            M = np.float32([[f, 0, ty], [0, f, tx], [0, 0, 1]])
            M = np.linalg.inv(M)
            target = np.ones([ht, wt]) * 255
            img = scp.affine_transform(img, M, output_shape=(ht, wt),
                                       output=target, mode="constant", cval=255)

        # transpose for TF
        img = np.transpose(img)

        # convert to range [-0.5, 0.5]
        img = img / 255 - 0.5
        return img

    def process_batch(self, batch: Batch) -> Batch:
        if self.line_mode:
            batch = self._simulate_text_line(batch)

        res_imgs = [self.process_img(img) for img in batch.imgs]
        max_text_len = res_imgs[0].shape[0] // 4
        res_gt_texts = [self._truncate_label(gt_text, max_text_len) for gt_text in batch.gt_texts]
        return Batch(res_imgs, res_gt_texts, batch.batch_size)
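

# Usage sketch (illustrative): assumes the Batch(imgs, gt_texts, batch_size) interface
# imported above and a hypothetical list of grayscale word images `word_imgs` with
# matching labels `labels`:
#
#   preprocessor = Preprocessor(img_size=(128, 32), data_augmentation=True, line_mode=True)
#   batch = preprocessor.process_batch(Batch(word_imgs, labels, len(word_imgs)))
#   # batch.imgs now holds transposed images of shape (128, 32) with values in [-0.5, 0.5]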


def main():
    img = plt.imread("../data/test.png")
    # convert the RGB float image (values in [0, 1]) to a grayscale image in [0, 255]
    img = (img[:, :, 0] + img[:, :, 1] + img[:, :, 2]) * 255 // 3
    img_aug = Preprocessor((256, 32), data_augmentation=True).process_img(img)

    plt.subplot(121)
    plt.imshow(img, cmap='gray')
    plt.subplot(122)
    plt.imshow(np.transpose(img_aug) + 0.5, cmap='gray', vmin=0, vmax=1)
    plt.show()


if __name__ == '__main__':
    main()