|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194 |
- import random
- from typing import Tuple
-
- import matplotlib.pyplot as plt
- import numpy as np
- import scipy.ndimage as scp
-
- from bin.OCR_dataloader_iam import Batch
-
-
class Preprocessor:
    """Prepare word/line images for the recognition model.

    Handles resizing to a fixed (or dynamic-width) target size, optional
    random data augmentation, and optional simulation of whole text lines
    by pasting several word images next to each other.
    """

    def __init__(self,
                 img_size: Tuple[int, int],
                 padding: int = 0,
                 dynamic_width: bool = False,
                 data_augmentation: bool = False,
                 line_mode: bool = False) -> None:
        """
        Args:
            img_size: target size as (width, height).
            padding: extra width in pixels added to the scaled image
                (only meaningful together with dynamic_width).
            dynamic_width: derive the target width from the aspect ratio
                instead of forcing img_size[0].
            data_augmentation: apply random photometric and geometric
                augmentation to each image.
            line_mode: build synthetic text-line images from word images.
        """
        # dynamic width only supported when no data augmentation happens
        assert not (dynamic_width and data_augmentation)
        # when padding is on, we need dynamic width enabled
        assert not (padding > 0 and not dynamic_width)

        self.img_size = img_size
        self.padding = padding
        self.dynamic_width = dynamic_width
        self.data_augmentation = data_augmentation
        self.line_mode = line_mode

    @staticmethod
    def _truncate_label(text: str, max_text_len: int) -> str:
        """Shorten a label so that ctc_loss can still map it onto the input.

        Function ctc_loss can't compute loss if it cannot find a mapping
        between text label and input labels. Repeated letters cost double
        because of the blank symbol needing to be inserted between them.
        If a too-long label is provided, ctc_loss returns an infinite
        gradient, so the label is truncated at the last position that
        still fits into max_text_len time-steps.
        """
        cost = 0
        for i in range(len(text)):
            # a repeated character needs an extra blank symbol -> costs 2
            if i != 0 and text[i] == text[i - 1]:
                cost += 2
            else:
                cost += 1
            if cost > max_text_len:
                return text[:i]
        return text

    def _simulate_text_line(self, batch: 'Batch') -> 'Batch':
        """Create image of a text line by pasting multiple word images into an image."""

        default_word_sep = 30
        default_num_words = 5

        # go over all batch elements
        res_imgs = []
        res_gt_texts = []
        for i in range(batch.batch_size):
            # number of words to put into current line (randomized when augmenting)
            num_words = random.randint(1, 8) if self.data_augmentation else default_num_words

            # concat ground truth texts (wrap around the batch if needed)
            curr_gt = ' '.join([batch.gt_texts[(i + j) % batch.batch_size] for j in range(num_words)])
            res_gt_texts.append(curr_gt)

            # put selected word images into list, compute target image size
            sel_imgs = []
            word_seps = [0]  # no separator before the first word
            h = 0
            w = 0
            for j in range(num_words):
                curr_sel_img = batch.imgs[(i + j) % batch.batch_size]
                curr_word_sep = random.randint(20, 50) if self.data_augmentation else default_word_sep
                h = max(h, curr_sel_img.shape[0])
                w += curr_sel_img.shape[1]
                sel_imgs.append(curr_sel_img)
                if j + 1 < num_words:
                    w += curr_word_sep
                    word_seps.append(curr_word_sep)

            # put all selected word images into target image (white background)
            target = np.ones([h, w], np.uint8) * 255
            x = 0
            for curr_sel_img, curr_word_sep in zip(sel_imgs, word_seps):
                x += curr_word_sep
                # center each word vertically
                y = (h - curr_sel_img.shape[0]) // 2
                target[y:y + curr_sel_img.shape[0]:, x:x + curr_sel_img.shape[1]] = curr_sel_img
                x += curr_sel_img.shape[1]

            # put image of line into result
            res_imgs.append(target)

        return Batch(res_imgs, res_gt_texts, batch.batch_size)

    def process_img(self, img: np.ndarray) -> np.ndarray:
        """Resize to target size, apply data augmentation.

        Returns a transposed (width, height) float array in range [-1, 1],
        as expected by the TF model.
        """

        # there are damaged files in IAM dataset - just use black image instead
        if img is None:
            img = np.zeros(self.img_size[::-1])

        # FIX: np.float was removed in NumPy 1.24 - use the builtin float dtype
        img = img.astype(float)
        if self.data_augmentation:
            # photometric data augmentation
            if random.random() < 0.25:
                def rand_odd():
                    return random.randint(1, 3) * 2 + 1
                img = scp.gaussian_filter(img, (rand_odd(), rand_odd()), 0)
            if random.random() < 0.25:
                img = scp.grey_dilation(img, structure=np.ones((3, 3)))
            if random.random() < 0.25:
                img = scp.grey_erosion(img, structure=np.ones((3, 3)))

            # geometric data augmentation: random scale around the best fit
            wt, ht = self.img_size[0], self.img_size[1]
            h, w = img.shape[0], img.shape[1]
            f = min(wt / w, ht / h)
            fx = f * np.random.uniform(0.75, 1.05)
            fy = f * np.random.uniform(0.75, 1.05)

            # random position around center
            txc = (wt - w * fx) / 2
            tyc = (ht - h * fy) / 2
            freedom_x = max((wt - fx * w) / 2, 0)
            freedom_y = max((ht - fy * h) / 2, 0)
            tx = txc + np.random.uniform(-freedom_x, freedom_x)
            ty = tyc + np.random.uniform(-freedom_y, freedom_y)

            # map image into target image; scipy's affine_transform maps output
            # coordinates to input coordinates, hence the inverted matrix
            # (rows first: row scale fy / shift ty, then col scale fx / shift tx)
            M = np.float32([[fy, 0, ty], [0, fx, tx], [0, 0, 1]])
            M = np.linalg.inv(M)
            target = np.ones(self.img_size[::-1]) * 255
            # FIX: output_shape must be (rows, cols) = img_size reversed, matching
            # the shape of `target` - passing (w, h) made scipy raise for
            # non-square target sizes
            img = scp.affine_transform(img, M, output_shape=self.img_size[::-1], output=target, mode="nearest")

            # photometric data augmentation
            if random.random() < 0.5:
                img = img * (0.25 + random.random() * 0.75)  # random brightness
            if random.random() < 0.25:
                img = np.clip(img + (np.random.random(img.shape) - 0.5) * random.randint(1, 25), 0, 255)  # noise
            if random.random() < 0.1:
                img = 255 - img  # invert

        # no data augmentation
        else:
            if self.dynamic_width:
                # keep aspect ratio, width follows from the scaled height
                ht = self.img_size[1]
                h, w = img.shape
                f = ht / h
                wt = int(f * w + self.padding)
                wt = wt + (4 - wt) % 4  # round width up to a multiple of 4
                tx = (wt - w * f) / 2
                ty = 0
            else:
                # fit image into fixed target size, centered
                wt, ht = self.img_size[0], self.img_size[1]
                h, w = img.shape
                f = min(wt / w, ht / h)
                tx = (wt - w * f) / 2
                ty = (ht - h * f) / 2

            # map image into target image (inverse matrix, see above)
            M = np.float32([[f, 0, ty], [0, f, tx], [0, 0, 1]])
            M = np.linalg.inv(M)
            target = np.ones([ht, wt]) * 255
            img = scp.affine_transform(img, M, output_shape=(ht, wt), output=target, mode="nearest")

        # transpose for TF
        img = np.transpose(img)

        # convert to range [-1, 1]
        img = img / 255 - 0.5
        return img

    def process_batch(self, batch: 'Batch') -> 'Batch':
        """Process all images of a batch; truncate labels to fit ctc_loss."""
        if self.line_mode:
            batch = self._simulate_text_line(batch)

        res_imgs = [self.process_img(img) for img in batch.imgs]
        # after transposition, shape[0] is the image width; the model
        # downsamples by factor 4 along this axis
        max_text_len = res_imgs[0].shape[0] // 4
        res_gt_texts = [self._truncate_label(gt_text, max_text_len) for gt_text in batch.gt_texts]
        return Batch(res_imgs, res_gt_texts, batch.batch_size)
-
-
def main():
    """Demo: augment a sample image and show original vs. result side by side."""
    img = plt.imread("../data/test.png")
    # FIX: removed duplicated assignment typo (`img = img = ...`);
    # collapse RGB channels to a single grey channel scaled to [0, 255]
    img = (img[:, :, 0] + img[:, :, 1] + img[:, :, 2]) * 255 // 3
    img_aug = Preprocessor((256, 32), data_augmentation=True).process_img(img)
    plt.subplot(121)
    plt.imshow(img, cmap='gray')
    plt.subplot(122)
    # undo the TF transpose and the [-1, 1] normalisation for display
    plt.imshow(np.transpose(img_aug) + 0.5, cmap='gray', vmin=0, vmax=1)
    plt.show()


if __name__ == '__main__':
    main()
|