@@ -0,0 +1,5 @@ | |||
# Ignore everything in this directory | |||
* | |||
# Except this file and wordCharList.txt | |||
!.gitignore | |||
wordCharList.txt |
@@ -0,0 +1,32 @@ | |||
## OCR_paper_form 0.1 | |||
# An application to automatically extract handwritten information from paper forms
March 2023 | |||
Only available in French - developed as a root application for the La Gemme organisation.
Runs on Python3 with Linux (developed on Python3.8, should run with older versions and on Windows)
Needs Scipy, Matplotlib, Numpy, TensorFlow libraries. | |||
## Comment utiliser l'application | |||
Extrayez le dossier dans un répertoire dédié - l'application fonctionne sans installation particulière, mis à part les librairies. Elle s'appuie sur une analyse d'un modèle pré-enregistré pour extraire des données.
C'est un logiciel terminal, l'interface graphique est inexistante. | |||
Lancez le logiciel avec python3 main.py, vous amenant vers un menu. | |||
Pour extraire des données, il faut d'abord se baser sur un modèle, qu'il faut construire (menu 2). Pour ce faire: | |||
- choisissez une image en PNG issue d'un formulaire papier vierge, ou d'un formulaire PDF exporté en PNG. Essayez d'avoir une image complète dont le nombre de pixels est équivalent à ceux des formulaires que vous serez amenés à scanner. | |||
- placez différents rectangles en fonction des informations demandées : case à écrire, ligne allant jusqu'à la fin de l'image, case à cocher (multiples ou exclusives), puis sélectionner les coins en haut à gauche et en bas à droite de votre rectangle en cliquant sur l'image qui s'affiche (veillez à ne pas cliquer sur l'extérieur de l'image) ; si vous sélectionnez des cases à cocher, cliquez bien sur chaque coin haut-gauche et bas-droit de chaque case, la fenêtre se fermera automatiquement après. | |||
- répétez l'opération pour chaque information. Une fois terminé, le modèle sera automatiquement enregistré | |||
Vous pouvez modifier le modèle (pas encore implémenté) | |||
Une fois le modèle utilisable et complet, enregistré, vous pouvez analyser les images PNG scannées, préalablement placées dans le dossier "scanned" du logiciel. Sélectionnez le menu 1 et le formulaire correspondant (la détection automatique ou par type de document n'est pas encore implémentée). La détection est automatique sur tous les champs, et vise à extraire les données manuscrites pour chaque champ, pour chaque image scannée.
(non encore codé -> ) Une fois les données extraites, une vérification visuelle est nécessaire avec une comparaison entre la fraction de l'image extraite contenant les données, et la transcription. La reconnaissance étant faite grâce à un corpus en anglais, sans caractères spéciaux comme le "@" et sans accents, il y aura nécessairement quelques erreurs. Une fois les transcriptions corrigées, et pour chaque fichier, on pourra les enregistrer dans un fichier, afin de les intégrer à une autre base de données par exemple. | |||
## Technologie | |||
Le traitement du signal est assez basique, la reconnaissance de caractère est basée sur un projet sur GitHub, SimpleHTR, à cette adresse : https://github.com/githubharald/SimpleHTR. Ce projet a été modifié pour être incorporé dans le projet, et intervient plus comme un module plutôt qu'une intégration propre. |
@@ -0,0 +1,27 @@ | |||
import argparse | |||
import pickle | |||
import cv2 | |||
import lmdb | |||
from path import Path | |||
# Command-line interface: --data_dir is the dataset root. PNG images are read
# from its "img" sub-directory and the database is written to "lmdb".
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=Path, required=True)
args = parser.parse_args()

lmdb_dir = args.data_dir / 'lmdb'

# Refuse to write into an existing database. This is an explicit check rather
# than an assert so that it is still enforced when Python runs with -O.
if lmdb_dir.exists():
    parser.error(f'LMDB directory already exists: {lmdb_dir}')

# 2GB is enough for IAM dataset
env = lmdb.open(str(lmdb_dir), map_size=1024 * 1024 * 1024 * 2)
try:
    # go over all png files
    fn_imgs = list((args.data_dir / 'img').walkfiles('*.png'))

    # and put the imgs into lmdb as pickled grayscale imgs, keyed by file basename
    with env.begin(write=True) as txn:
        for i, fn_img in enumerate(fn_imgs):
            print(i, len(fn_imgs))
            # NOTE(review): cv2.imread presumably yields None for unreadable
            # files; that value is pickled as-is - confirm readers handle it.
            img = cv2.imread(fn_img, cv2.IMREAD_GRAYSCALE)
            basename = fn_img.basename()
            txn.put(basename.encode("ascii"), pickle.dumps(img))
finally:
    # always release the environment, even if reading or writing failed
    env.close()
@@ -0,0 +1,133 @@ | |||
import pickle | |||
import random | |||
from collections import namedtuple | |||
import matplotlib.pyplot as plt | |||
import lmdb | |||
import numpy as np | |||
# One ground-truth entry: the transcription and the path of its image file.
Sample = namedtuple("Sample", "gt_text, file_path")
# One mini-batch: list of images, list of transcriptions, number of elements.
Batch = namedtuple("Batch", "imgs, gt_texts, batch_size")


class DataLoaderIAM:
    """
    Loads data which corresponds to IAM format,
    see: http://www.fki.inf.unibe.ch/databases/iam-handwriting-database

    The loader reads ``gt/words.txt`` below the data directory, splits the
    samples into a training and a validation part and serves them batch-wise
    through ``has_next()`` / ``get_next()``.
    """

    def __init__(self,
                 data_dir,
                 batch_size,
                 data_split=0.95,
                 fast=True):
        """Loader for dataset.

        Args:
            data_dir: path object of the dataset root (must support ``/`` and
                ``exists()``); it contains ``gt/words.txt``, ``img/`` and,
                when ``fast`` is set, ``lmdb/``.
            batch_size: number of samples per batch.
            data_split: fraction of the samples used for training; the rest
                is used for validation.
            fast: read images from the LMDB database instead of image files.
        """
        assert data_dir.exists()

        self.fast = fast
        if fast:
            self.env = lmdb.open(str(data_dir / "lmdb"), readonly=True)

        self.data_augmentation = False
        self.curr_idx = 0
        self.batch_size = batch_size
        self.samples = []

        chars = set()
        bad_samples_reference = ["a01-117-05-02", "r06-022-03-05"]  # known broken images in IAM dataset

        # the context manager closes the ground-truth file once it is parsed
        with open(data_dir / "gt/words.txt") as f:
            for line in f:
                # ignore empty and comment lines
                line = line.strip()
                if not line or line[0] == "#":
                    continue

                line_split = line.split(" ")
                assert len(line_split) >= 9

                # filename: part1-part2-part3 --> part1/part1-part2/part1-part2-part3.png
                file_name_split = line_split[0].split("-")
                file_name_subdir1 = file_name_split[0]
                file_name_subdir2 = f"{file_name_split[0]}-{file_name_split[1]}"
                file_base_name = line_split[0] + ".png"
                file_name = data_dir / "img" / file_name_subdir1 / file_name_subdir2 / file_base_name

                if line_split[0] in bad_samples_reference:
                    print("Ignoring known broken image:", file_name)
                    continue

                # GT text are columns starting at 9 (1-based); it may contain spaces
                gt_text = " ".join(line_split[8:])
                chars.update(gt_text)

                # put sample into list
                self.samples.append(Sample(gt_text, file_name))

        # split into training and validation set (95% - 5% by default)
        split_idx = int(data_split * len(self.samples))
        self.train_samples = self.samples[:split_idx]
        self.validation_samples = self.samples[split_idx:]

        # put words into lists
        self.train_words = [x.gt_text for x in self.train_samples]
        self.validation_words = [x.gt_text for x in self.validation_samples]

        # start with train set
        self.train_set()

        # list of all chars in dataset
        self.char_list = sorted(chars)

    def train_set(self):
        """Switch to the training set, shuffled in place, and rewind."""
        self.data_augmentation = True
        self.curr_idx = 0
        random.shuffle(self.train_samples)
        self.samples = self.train_samples
        self.curr_set = "train"

    def validation_set(self):
        """Switch to validation set and rewind."""
        self.data_augmentation = False
        self.curr_idx = 0
        self.samples = self.validation_samples
        self.curr_set = "val"

    def get_iterator_info(self):
        """Current batch index (1-based) and overall number of batches."""
        if self.curr_set == "train":
            num_batches = int(np.floor(len(self.samples) / self.batch_size))  # train set: only full-sized batches
        else:
            num_batches = int(np.ceil(len(self.samples) / self.batch_size))  # val set: allow last batch to be smaller
        curr_batch = self.curr_idx // self.batch_size + 1
        return curr_batch, num_batches

    def has_next(self):
        """Is there a next element?"""
        if self.curr_set == "train":
            return self.curr_idx + self.batch_size <= len(self.samples)  # train set: only full-sized batches
        return self.curr_idx < len(self.samples)  # val set: allow last batch to be smaller

    def _get_img(self, i):
        """Return the image of sample ``i`` of the currently selected set."""
        if self.fast:
            # Local import: the LMDB key is the file's base name, and this
            # module has no path helper in scope for it.
            import os

            basename = os.path.basename(str(self.samples[i].file_path))
            with self.env.begin() as txn:
                data = txn.get(basename.encode("ascii"))
                # NOTE: pickle must only be used on a trusted, locally built database.
                img = pickle.loads(data)
        else:
            img = plt.imread(self.samples[i].file_path)
            # average the first three channels and rescale by 255
            img = (img[:, :, 0] + img[:, :, 1] + img[:, :, 2]) * 255 // 3

        return img

    def get_next(self):
        """Get next element: a Batch with the images and texts of the next batch."""
        batch_range = range(self.curr_idx, min(self.curr_idx + self.batch_size, len(self.samples)))

        imgs = [self._get_img(i) for i in batch_range]
        gt_texts = [self.samples[i].gt_text for i in batch_range]

        self.curr_idx += self.batch_size
        return Batch(imgs, gt_texts, len(imgs))
@@ -0,0 +1,313 @@ | |||
import os | |||
import sys | |||
from typing import List, Tuple | |||
import numpy as np | |||
import tensorflow as tf | |||
from bin.OCR_dataloader_iam import Batch | |||
# Disable eager mode | |||
tf.compat.v1.disable_eager_execution() | |||
class DecoderType:
    """Integer identifiers of the supported CTC decoding strategies."""
    # greedy decoding, TF beam search, external word beam search
    BestPath, BeamSearch, WordBeamSearch = range(3)
class Model:
    """Minimalistic TF model for HTR.

    The graph is built with the TF1 API (``tf.compat.v1``): a stack of CNN
    layers, a bidirectional stacked LSTM and a CTC loss with a selectable
    decoder.
    """

    def __init__(self,
                 char_list: List[str],
                 FileNames,
                 decoder_type: int = DecoderType.BestPath,
                 must_restore: bool = False,
                 dump: bool = False) -> None:
        """Init model: add CNN, RNN and CTC and initialize TF.

        Args:
            char_list: characters the model can output; the position of a
                character in the list is its class id.
            FileNames: object providing ``fn_model_path`` (checkpoint
                directory) and ``fn_dump_path`` (directory for CSV dumps).
            decoder_type: one of the ``DecoderType`` constants.
            must_restore: raise if no stored checkpoint is found.
            dump: dump the network output to CSV files during inference.
        """
        self.dump = dump
        self.char_list = char_list
        self.decoder_type = decoder_type
        self.must_restore = must_restore
        self.snap_ID = 0
        self.FileNames = FileNames

        # Whether to use normalization over a batch or a population
        self.is_train = tf.compat.v1.placeholder(tf.bool, name='is_train')

        # input image batch
        self.input_imgs = tf.compat.v1.placeholder(tf.float32, shape=(None, None, None))

        # setup CNN, RNN and CTC
        self.setup_cnn()
        self.setup_rnn()
        self.setup_ctc()

        # setup optimizer to train NN
        self.batches_trained = 0
        self.update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
        # run the collected update ops (batch-norm statistics) with each training step
        with tf.control_dependencies(self.update_ops):
            self.optimizer = tf.compat.v1.train.AdamOptimizer().minimize(self.loss)

        # add less verbosity
        tf.get_logger().setLevel("ERROR")

        # initialize TF
        self.sess, self.saver = self.setup_tf()

    def setup_cnn(self) -> None:
        """Create CNN layers."""
        cnn_in4d = tf.expand_dims(input=self.input_imgs, axis=3)

        # list of parameters for the layers
        kernel_vals = [5, 5, 3, 3, 3]
        feature_vals = [1, 32, 64, 128, 128, 256]
        stride_vals = pool_vals = [(2, 2), (2, 2), (1, 2), (1, 2), (1, 2)]
        num_layers = len(stride_vals)

        # create layers: conv -> batch norm -> relu -> max pool
        pool = cnn_in4d  # input to first CNN layer
        for i in range(num_layers):
            kernel = tf.Variable(
                tf.random.truncated_normal([kernel_vals[i], kernel_vals[i], feature_vals[i], feature_vals[i + 1]],
                                           stddev=0.1))
            conv = tf.nn.conv2d(input=pool, filters=kernel, padding='SAME', strides=(1, 1, 1, 1))
            conv_norm = tf.compat.v1.layers.batch_normalization(conv, training=self.is_train)
            relu = tf.nn.relu(conv_norm)
            pool = tf.nn.max_pool2d(input=relu, ksize=(1, pool_vals[i][0], pool_vals[i][1], 1),
                                    strides=(1, stride_vals[i][0], stride_vals[i][1], 1), padding='VALID')

        self.cnn_out_4d = pool

    def setup_rnn(self) -> None:
        """Create RNN layers."""
        rnn_in3d = tf.squeeze(self.cnn_out_4d, axis=[2])

        # basic cells which is used to build RNN
        num_hidden = 256
        cells = [tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=num_hidden, state_is_tuple=True) for _ in
                 range(2)]  # 2 layers

        # stack basic cells
        stacked = tf.compat.v1.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=True)

        # bidirectional RNN
        # BxTxF -> BxTx2H
        (fw, bw), _ = tf.compat.v1.nn.bidirectional_dynamic_rnn(cell_fw=stacked, cell_bw=stacked, inputs=rnn_in3d,
                                                                dtype=rnn_in3d.dtype)

        # BxTxH + BxTxH -> BxTx2H -> BxTx1X2H
        concat = tf.expand_dims(tf.concat([fw, bw], 2), 2)

        # project output to chars (including blank): BxTx1x2H -> BxTx1xC -> BxTxC
        kernel = tf.Variable(tf.random.truncated_normal([1, 1, num_hidden * 2, len(self.char_list) + 1], stddev=0.1))
        self.rnn_out_3d = tf.squeeze(tf.nn.atrous_conv2d(value=concat, filters=kernel, rate=1, padding='SAME'),
                                     axis=[2])

    def setup_ctc(self) -> None:
        """Create CTC loss and decoder."""
        # BxTxC -> TxBxC
        self.ctc_in_3d_tbc = tf.transpose(a=self.rnn_out_3d, perm=[1, 0, 2])
        # ground truth text as sparse tensor
        self.gt_texts = tf.SparseTensor(tf.compat.v1.placeholder(tf.int64, shape=[None, 2]),
                                        tf.compat.v1.placeholder(tf.int32, [None]),
                                        tf.compat.v1.placeholder(tf.int64, [2]))

        # calc loss for batch
        self.seq_len = tf.compat.v1.placeholder(tf.int32, [None])
        self.loss = tf.reduce_mean(
            input_tensor=tf.compat.v1.nn.ctc_loss(labels=self.gt_texts, inputs=self.ctc_in_3d_tbc,
                                                  sequence_length=self.seq_len,
                                                  ctc_merge_repeated=True))

        # calc loss for each element to compute label probability
        self.saved_ctc_input = tf.compat.v1.placeholder(tf.float32,
                                                        shape=[None, None, len(self.char_list) + 1])
        self.loss_per_element = tf.compat.v1.nn.ctc_loss(labels=self.gt_texts, inputs=self.saved_ctc_input,
                                                         sequence_length=self.seq_len, ctc_merge_repeated=True)

        # best path decoding or beam search decoding
        if self.decoder_type == DecoderType.BestPath:
            self.decoder = tf.nn.ctc_greedy_decoder(inputs=self.ctc_in_3d_tbc, sequence_length=self.seq_len)
        elif self.decoder_type == DecoderType.BeamSearch:
            self.decoder = tf.nn.ctc_beam_search_decoder(inputs=self.ctc_in_3d_tbc, sequence_length=self.seq_len,
                                                         beam_width=50)
        # word beam search decoding (see https://github.com/githubharald/CTCWordBeamSearch)
        elif self.decoder_type == DecoderType.WordBeamSearch:
            # prepare information about language (dictionary, characters in dataset, characters forming words)
            chars = ''.join(self.char_list)
            # context managers make sure both files are closed after reading
            with open('../model/wordCharList.txt') as f:
                word_chars = f.read().splitlines()[0]
            with open('../data/corpus.txt') as f:
                corpus = f.read()

            # decode using the "Words" mode of word beam search
            from word_beam_search import WordBeamSearch
            self.decoder = WordBeamSearch(50, 'Words', 0.0, corpus.encode('utf8'), chars.encode('utf8'),
                                          word_chars.encode('utf8'))

            # the input to the decoder must have softmax already applied
            self.wbs_input = tf.nn.softmax(self.ctc_in_3d_tbc, axis=2)

    def setup_tf(self) -> Tuple[tf.compat.v1.Session, tf.compat.v1.train.Saver]:
        """Initialize TF: create session and saver, restore the latest checkpoint if there is one.

        Raises:
            Exception: if ``must_restore`` is set and no checkpoint exists in
                ``FileNames.fn_model_path``.
        """
        sess = tf.compat.v1.Session()  # TF session

        saver = tf.compat.v1.train.Saver(max_to_keep=1)  # saver saves model to file
        model_dir = self.FileNames.fn_model_path
        latest_snapshot = tf.train.latest_checkpoint(model_dir)  # is there a saved model?

        # if model must be restored (for inference), there must be a snapshot
        if self.must_restore and not latest_snapshot:
            raise Exception('No saved model found in: ' + model_dir)

        # load saved model if available, otherwise start from fresh values
        if latest_snapshot:
            saver.restore(sess, latest_snapshot)
        else:
            sess.run(tf.compat.v1.global_variables_initializer())

        return sess, saver

    def to_sparse(self, texts: List[str]) -> Tuple[List[List[int]], List[int], List[int]]:
        """Put ground truth texts into sparse tensor for ctc_loss.

        Returns:
            ``(indices, values, shape)`` where ``indices`` holds
            ``[batch element, position]`` pairs, ``values`` the class ids and
            ``shape`` is ``[number of texts, length of the longest text]``.

        Raises:
            ValueError: if a text contains a character not in ``char_list``.
        """
        indices = []
        values = []
        shape = [len(texts), 0]  # last entry must be max(labelList[i])

        # go over all texts
        for batch_element, text in enumerate(texts):
            # convert to string of label (i.e. class-ids)
            label_str = [self.char_list.index(c) for c in text]
            # sparse tensor must have size of max. label-string
            if len(label_str) > shape[1]:
                shape[1] = len(label_str)
            # put each label into sparse tensor
            for i, label in enumerate(label_str):
                indices.append([batch_element, i])
                values.append(label)

        return indices, values, shape

    def decoder_output_to_text(self, ctc_output: tuple, batch_size: int) -> List[str]:
        """Extract texts from output of CTC decoder."""

        # word beam search: already contains label strings
        if self.decoder_type == DecoderType.WordBeamSearch:
            label_strs = ctc_output

        # TF decoders: label strings are contained in sparse tensor
        else:
            # ctc returns tuple, first element is SparseTensor
            decoded = ctc_output[0][0]

            # contains string of labels for each batch element
            label_strs = [[] for _ in range(batch_size)]

            # go over all indices and save mapping: batch -> values
            for (idx, idx2d) in enumerate(decoded.indices):
                label = decoded.values[idx]
                batch_element = idx2d[0]  # index according to [b,t]
                label_strs[batch_element].append(label)

        # map labels to chars for all batch elements
        return [''.join([self.char_list[c] for c in label_str]) for label_str in label_strs]

    def train_batch(self, batch: Batch) -> float:
        """Feed a batch into the NN to train it and return the batch loss."""
        num_batch_elements = len(batch.imgs)
        # sequence length: first image axis divided by the CNN's downsizing factor of 4
        max_text_len = batch.imgs[0].shape[0] // 4
        sparse = self.to_sparse(batch.gt_texts)
        eval_list = [self.optimizer, self.loss]
        feed_dict = {self.input_imgs: batch.imgs, self.gt_texts: sparse,
                     self.seq_len: [max_text_len] * num_batch_elements, self.is_train: True}
        _, loss_val = self.sess.run(eval_list, feed_dict)
        self.batches_trained += 1
        return loss_val

    def dump_nn_output(self, rnn_output: np.ndarray) -> None:
        """Dump the output of the NN to CSV file(s).

        One file ``rnnOutput_<b>.csv`` is written per batch element into
        ``FileNames.fn_dump_path``. This is an instance method because the
        dump directory is read from ``self.FileNames``.

        Args:
            rnn_output: array indexed as ``[t, b, c]``.
        """
        dump_dir = self.FileNames.fn_dump_path
        os.makedirs(dump_dir, exist_ok=True)

        # iterate over all batch elements and create a CSV file for each one
        max_t, max_b, max_c = rnn_output.shape
        for b in range(max_b):
            # one line per time-step, every value followed by ';'
            rows = [''.join(str(rnn_output[t, b, c]) + ';' for c in range(max_c)) + '\n'
                    for t in range(max_t)]
            fn = os.path.join(dump_dir, 'rnnOutput_' + str(b) + '.csv')
            with open(fn, 'w') as f:
                f.writelines(rows)

    def infer_batch(self, batch: Batch, calc_probability: bool = False, probability_of_gt: bool = False):
        """Feed a batch into the NN to recognize the texts.

        Args:
            batch: batch whose ``imgs`` are fed to the network.
            calc_probability: also compute the probability of each labeling.
            probability_of_gt: score ``batch.gt_texts`` instead of the
                recognized texts.

        Returns:
            ``(texts, probs)``; ``probs`` is None unless ``calc_probability``.
        """

        # decode, optionally save RNN output
        num_batch_elements = len(batch.imgs)

        # put tensors to be evaluated into list
        eval_list = []

        if self.decoder_type == DecoderType.WordBeamSearch:
            eval_list.append(self.wbs_input)
        else:
            eval_list.append(self.decoder)

        if self.dump or calc_probability:
            eval_list.append(self.ctc_in_3d_tbc)

        # sequence length depends on input image size (model downsizes width by 4)
        max_text_len = batch.imgs[0].shape[0] // 4

        # dict containing all tensor fed into the model
        feed_dict = {self.input_imgs: batch.imgs, self.seq_len: [max_text_len] * num_batch_elements,
                     self.is_train: False}

        # evaluate model
        eval_res = self.sess.run(eval_list, feed_dict)

        # TF decoders: decoding already done in TF graph
        if self.decoder_type != DecoderType.WordBeamSearch:
            decoded = eval_res[0]
        # word beam search decoder: decoding is done in C++ function compute()
        else:
            decoded = self.decoder.compute(eval_res[0])

        # map labels (numbers) to character string
        texts = self.decoder_output_to_text(decoded, num_batch_elements)

        # feed RNN output and recognized text into CTC loss to compute labeling probability
        probs = None
        if calc_probability:
            sparse = self.to_sparse(batch.gt_texts) if probability_of_gt else self.to_sparse(texts)
            ctc_input = eval_res[1]
            eval_list = self.loss_per_element
            feed_dict = {self.saved_ctc_input: ctc_input, self.gt_texts: sparse,
                         self.seq_len: [max_text_len] * num_batch_elements, self.is_train: False}
            loss_vals = self.sess.run(eval_list, feed_dict)
            probs = np.exp(-loss_vals)

        # dump the output of the NN to CSV file(s)
        if self.dump:
            self.dump_nn_output(eval_res[1])

        # NOTE(review): the default graph is reset after every inference,
        # presumably so that a new Model can be built afterwards - confirm.
        tf.compat.v1.reset_default_graph()

        return texts, probs

    def save(self) -> None:
        """Save model to file.

        The snapshot is written into ``FileNames.fn_model_path``, the
        directory ``setup_tf`` restores from.
        """
        self.snap_ID += 1
        snapshot_prefix = os.path.join(self.FileNames.fn_model_path, 'snapshot')
        self.saver.save(self.sess, snapshot_prefix, global_step=self.snap_ID)
@@ -0,0 +1,194 @@ | |||
import random | |||
from typing import Tuple | |||
import matplotlib.pyplot as plt | |||
import numpy as np | |||
import scipy.ndimage as scp | |||
from bin.OCR_dataloader_iam import Batch | |||
class Preprocessor:
    """Prepares images and ground-truth texts for the recognition network.

    Images are resized to the target size (optionally with random data
    augmentation), transposed and scaled; labels are truncated to what the
    CTC loss can handle.
    """

    def __init__(self,
                 img_size: Tuple[int, int],
                 padding: int = 0,
                 dynamic_width: bool = False,
                 data_augmentation: bool = False,
                 line_mode: bool = False) -> None:
        """Store the preprocessing settings.

        Args:
            img_size: target size as ``(width, height)``.
            padding: extra width added in dynamic-width mode.
            dynamic_width: keep the aspect ratio and derive the width from
                the target height instead of using a fixed width.
            data_augmentation: apply random photometric/geometric changes.
            line_mode: build text-line images out of word images first.
        """
        # dynamic width only supported when no data augmentation happens
        assert not (dynamic_width and data_augmentation)
        # when padding is on, we need dynamic width enabled
        assert not (padding > 0 and not dynamic_width)

        self.img_size = img_size
        self.padding = padding
        self.dynamic_width = dynamic_width
        self.data_augmentation = data_augmentation
        self.line_mode = line_mode

    @staticmethod
    def _truncate_label(text: str, max_text_len: int) -> str:
        """
        Function ctc_loss can't compute loss if it cannot find a mapping between text label and input
        labels. Repeat letters cost double because of the blank symbol needing to be inserted.
        If a too-long label is provided, ctc_loss returns an infinite gradient.

        Returns the longest prefix of ``text`` whose cost fits ``max_text_len``.
        """
        cost = 0
        for i, char in enumerate(text):
            # a repeated letter needs an additional blank in between
            cost += 2 if i != 0 and char == text[i - 1] else 1
            if cost > max_text_len:
                return text[:i]
        return text

    def _simulate_text_line(self, batch: "Batch") -> "Batch":
        """Create image of a text line by pasting multiple word images into an image."""

        default_word_sep = 30
        default_num_words = 5

        # go over all batch elements
        res_imgs = []
        res_gt_texts = []
        for i in range(batch.batch_size):
            # number of words to put into current line
            num_words = random.randint(1, 8) if self.data_augmentation else default_num_words

            # concat ground truth texts (word indices wrap around the batch)
            curr_gt = ' '.join([batch.gt_texts[(i + j) % batch.batch_size] for j in range(num_words)])
            res_gt_texts.append(curr_gt)

            # put selected word images into list, compute target image size
            sel_imgs = []
            word_seps = [0]  # no gap in front of the first word
            h = 0
            w = 0
            for j in range(num_words):
                curr_sel_img = batch.imgs[(i + j) % batch.batch_size]
                curr_word_sep = random.randint(20, 50) if self.data_augmentation else default_word_sep
                h = max(h, curr_sel_img.shape[0])
                w += curr_sel_img.shape[1]
                sel_imgs.append(curr_sel_img)
                if j + 1 < num_words:
                    w += curr_word_sep
                    word_seps.append(curr_word_sep)

            # put all selected word images into target image (white background)
            target = np.ones([h, w], np.uint8) * 255
            x = 0
            for curr_sel_img, curr_word_sep in zip(sel_imgs, word_seps):
                x += curr_word_sep
                y = (h - curr_sel_img.shape[0]) // 2  # center the word vertically
                target[y:y + curr_sel_img.shape[0], x:x + curr_sel_img.shape[1]] = curr_sel_img
                x += curr_sel_img.shape[1]

            # put image of line into result
            res_imgs.append(target)

        return Batch(res_imgs, res_gt_texts, batch.batch_size)

    def process_img(self, img: np.ndarray) -> np.ndarray:
        """Resize to target size, apply data augmentation.

        Args:
            img: 2D grayscale image, or None for a damaged file.

        Returns:
            The transposed image (width first), scaled to [-0.5, 0.5].
        """

        # there are damaged files in IAM dataset - just use black image instead
        if img is None:
            img = np.zeros(self.img_size[::-1])

        # np.float64 instead of the np.float alias, which recent NumPy releases no longer provide
        img = img.astype(np.float64)

        if self.data_augmentation:
            # photometric data augmentation
            if random.random() < 0.25:
                def rand_odd():
                    return random.randint(1, 3) * 2 + 1
                # NOTE(review): this replaces a former cv2.GaussianBlur call
                # whose arguments were kernel sizes; here they act as sigmas
                # and the trailing 0 is the derivative order - confirm the
                # blur strength is intended.
                img = scp.gaussian_filter(img, (rand_odd(), rand_odd()), 0)
            if random.random() < 0.25:
                img = scp.grey_dilation(img, structure=np.ones((3, 3)))
            if random.random() < 0.25:
                img = scp.grey_erosion(img, structure=np.ones((3, 3)))

            # geometric data augmentation
            wt, ht = self.img_size[0], self.img_size[1]
            h, w = img.shape[0], img.shape[1]
            f = min(wt / w, ht / h)
            fx = f * np.random.uniform(0.75, 1.05)
            fy = f * np.random.uniform(0.75, 1.05)

            # random position around center
            txc = (wt - w * fx) / 2
            tyc = (ht - h * fy) / 2
            freedom_x = max((wt - fx * w) / 2, 0)
            freedom_y = max((ht - fy * h) / 2, 0)
            tx = txc + np.random.uniform(-freedom_x, freedom_x)
            ty = tyc + np.random.uniform(-freedom_y, freedom_y)

            # map image into target image; the forward transform is written in
            # (row, column) order and inverted before it is passed to SciPy
            M = np.float32([[fy, 0, ty], [0, fx, tx], [0, 0, 1]])
            M = np.linalg.inv(M)
            # The output shape is (rows, columns) = (ht, wt), like in the
            # branch without augmentation. No pre-filled canvas is passed:
            # every output pixel is computed by the transform anyway.
            img = scp.affine_transform(img, M, output_shape=(ht, wt), mode="nearest")

            # photometric data augmentation
            if random.random() < 0.5:
                img = img * (0.25 + random.random() * 0.75)
            if random.random() < 0.25:
                img = np.clip(img + (np.random.random(img.shape) - 0.5) * random.randint(1, 25), 0, 255)
            if random.random() < 0.1:
                img = 255 - img

        # no data augmentation
        else:
            if self.dynamic_width:
                ht = self.img_size[1]
                h, w = img.shape
                f = ht / h
                wt = int(f * w + self.padding)
                wt = wt + (4 - wt) % 4  # round the width up to a multiple of 4
                tx = (wt - w * f) / 2
                ty = 0
            else:
                wt, ht = self.img_size[0], self.img_size[1]
                h, w = img.shape
                f = min(wt / w, ht / h)
                tx = (wt - w * f) / 2
                ty = (ht - h * f) / 2

            # map image into target image
            M = np.float32([[f, 0, ty], [0, f, tx], [0, 0, 1]])
            M = np.linalg.inv(M)
            img = scp.affine_transform(img, M, output_shape=(ht, wt), mode="nearest")

        # transpose for TF
        img = np.transpose(img)

        # convert to range [-0.5, 0.5]
        img = img / 255 - 0.5
        return img

    def process_batch(self, batch: "Batch") -> "Batch":
        """Preprocess all images of a batch and truncate the texts accordingly."""
        if self.line_mode:
            batch = self._simulate_text_line(batch)

        res_imgs = [self.process_img(img) for img in batch.imgs]
        # first axis of a processed image is its width; the label budget is a quarter of it
        max_text_len = res_imgs[0].shape[0] // 4
        res_gt_texts = [self._truncate_label(gt_text, max_text_len) for gt_text in batch.gt_texts]
        return Batch(res_imgs, res_gt_texts, batch.batch_size)
def main():
    """Show a test image next to its augmented version (visual check)."""
    rgb = plt.imread("../data/test.png")
    # average the first three channels and rescale by 255
    img = (rgb[:, :, 0] + rgb[:, :, 1] + rgb[:, :, 2]) * 255 // 3

    augmenter = Preprocessor((256, 32), data_augmentation=True)
    img_aug = augmenter.process_img(img)

    # left: input image
    plt.subplot(121)
    plt.imshow(img, cmap='gray')

    # right: augmented image, transposed back and shifted into [0, 1]
    plt.subplot(122)
    plt.imshow(np.transpose(img_aug) + 0.5, cmap='gray', vmin=0, vmax=1)
    plt.show()


if __name__ == '__main__':
    main()
@@ -0,0 +1,205 @@ | |||
import json | |||
import matplotlib.pyplot as plt | |||
from bin.signal_processing import levenshteinDistance | |||
from bin.OCR_dataloader_iam import DataLoaderIAM, Batch | |||
from bin.OCR_model import Model, DecoderType | |||
from bin.OCR_preprocessor import Preprocessor | |||
# class FilePaths: | |||
# """Filenames and paths to data.""" | |||
# fn_char_list = '../OCR_model/charList.txt' | |||
# fn_summary = '../OCR_model/summary.json' | |||
# fn_corpus = '../data/corpus.txt' | |||
def get_img_height():
    """Return the image height the NN works with (a constant)."""
    height = 32
    return height
def get_img_size(line_mode = False):
    """Return (width, height) of the NN input: the width is 256 for text lines, 128 for single words."""
    width = 256 if line_mode else 128
    return width, get_img_height()
def write_summary(char_error_rates, word_accuracies, FileNames):
    """Store the training summary of the NN as JSON in FileNames.fn_summary."""
    summary = {
        'charErrorRates': char_error_rates,
        'wordAccuracies': word_accuracies,
    }
    with open(FileNames.fn_summary, 'w') as summary_file:
        summary_file.write(json.dumps(summary))
def char_list_from_file(FileNames):
    """Read the character list of the model.

    Every character of the file FileNames.fn_char_list is one entry of the
    returned list. A single trailing newline, if present, is dropped. An
    empty file yields an empty list.
    """
    with open(FileNames.fn_char_list) as f:
        chars = list(f.read())
    # guard against an empty file before looking at the last character
    if chars and chars[-1] == "\n":
        chars.pop()
    return chars
def train(model,
          loader,
          line_mode,
          FileNames,
          early_stopping = 25):
    """Trains NN.

    Runs training epochs until the validation character error rate has not
    improved for ``early_stopping`` epochs in a row. After every epoch the
    summary file is rewritten, and the model is saved whenever the character
    error rate improved.
    """
    epoch = 0  # number of training epochs since start
    summary_char_error_rates = []
    summary_word_accuracies = []
    preprocessor = Preprocessor(get_img_size(line_mode), data_augmentation=True, line_mode=line_mode)
    best_char_error_rate = float('inf')  # best validation character error rate
    no_improvement_since = 0  # number of epochs no improvement of character error rate occurred

    # stop training after this number of epochs without improvement
    while True:
        epoch += 1
        print('Epoch:', epoch)

        # train
        print('Train NN')
        loader.train_set()
        while loader.has_next():
            iter_info = loader.get_iterator_info()
            batch = loader.get_next()
            batch = preprocessor.process_batch(batch)
            loss = model.train_batch(batch)
            print(f"Epoch: {epoch} Batch: {iter_info[0]}/{iter_info[1]} Loss: {loss}")

        # validate
        char_error_rate, word_accuracy = validate(model, loader, line_mode)

        # write summary
        summary_char_error_rates.append(char_error_rate)
        summary_word_accuracies.append(word_accuracy)
        write_summary(summary_char_error_rates, summary_word_accuracies, FileNames)

        # if best validation accuracy so far, save model parameters
        if char_error_rate < best_char_error_rate:
            print('Character error rate improved, save model')
            best_char_error_rate = char_error_rate
            no_improvement_since = 0
            model.save()
        else:
            # report the best rate reached so far, not the rate of this epoch
            print(f"Character error rate not improved, best so far: {best_char_error_rate * 100.0}%")
            no_improvement_since += 1

        # stop training if no more improvement in the last x epochs
        if no_improvement_since >= early_stopping:
            print(f"No more improvement since {early_stopping} epochs. Training stopped.")
            break
def validate(model, loader, line_mode):
    """Validates NN.

    Runs the model on the validation set of ``loader`` and prints every
    ground truth next to the recognized text.

    Returns:
        ``(char_error_rate, word_accuracy)``: the summed edit distance
        divided by the number of ground-truth characters, and the fraction
        of texts recognized exactly.
    """
    print("Validate NN")
    loader.validation_set()
    preprocessor = Preprocessor(get_img_size(line_mode), line_mode=line_mode)
    num_char_err = 0
    num_char_total = 0
    num_word_ok = 0
    num_word_total = 0
    while loader.has_next():
        iter_info = loader.get_iterator_info()
        # progress as "current / total", not as the quotient of the two
        print(f"Batch: {iter_info[0]} / {iter_info[1]}")
        batch = loader.get_next()
        batch = preprocessor.process_batch(batch)
        recognized, _ = model.infer_batch(batch)

        print("Ground truth -> Recognized")
        for gt_text, rec_text in zip(batch.gt_texts, recognized):
            num_word_ok += 1 if gt_text == rec_text else 0
            num_word_total += 1
            dist = levenshteinDistance(rec_text, gt_text)
            num_char_err += dist
            num_char_total += len(gt_text)
            print('[OK]' if dist == 0 else '[ERR:%d]' % dist, '"' + gt_text + '"', '->',
                  '"' + rec_text + '"')

    # print validation result
    char_error_rate = num_char_err / num_char_total
    word_accuracy = num_word_ok / num_word_total
    print(f"Character error rate: {char_error_rate * 100.0}%. Word accuracy: {word_accuracy * 100.0}%.")
    return (char_error_rate, word_accuracy)
def infer(model, fn_img):
    """Recognize the text contained in one image.

    Args:
        model: object exposing ``infer_batch(batch, True)``; the first
            element of its return value is the list of recognized texts.
        fn_img: either the path (``str``) of an image file with at least
            three colour channels, or an already loaded image array that is
            handed to the preprocessor unchanged.

    Returns:
        The first recognized text.

    Raises:
        ValueError: if ``fn_img`` is ``None``.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if fn_img is None:
        raise ValueError("infer() needs an image path or an image array, got None")
    if isinstance(fn_img, str):
        img = plt.imread(fn_img)
        # Average the first three channels and rescale; presumably the file
        # is a PNG that matplotlib loads as floats in [0, 1] - TODO confirm.
        fn_img = (img[:, :, 0] + img[:, :, 1] + img[:, :, 2]) * 255 // 3
    preprocessor = Preprocessor(get_img_size(), dynamic_width=True, padding=16)
    img = preprocessor.process_img(fn_img)
    batch = Batch([img], None, 1)
    recognized, _ = model.infer_batch(batch, True)
    return recognized[0]
def ocr_run(FileNames, img_file='../data/word.png', mode="infer", decoder="bestpath", batch_size=100, data_dir=False, fast=False, line_mode=False, early_stopping=25, dump=False):
    """Train, validate or run the handwriting recognition model.

    Args:
        FileNames: object carrying the file locations; ``fn_char_list`` and
            ``fn_corpus`` are written here in train mode, and the object is
            forwarded to ``Model``, ``train`` and ``char_list_from_file``.
        img_file: image used for inference (path or image array, see
            ``infer``).
        mode: ``'train'``, ``'validate'`` or ``'infer'``.
        decoder: CTC decoder, one of ``'bestpath'``, ``'beamsearch'`` or
            ``'wordbeamsearch'``.
        batch_size: batch size handed to ``DataLoaderIAM``.
        data_dir: directory containing the IAM dataset (train/validate only).
        fast: load samples from LMDB.
        line_mode: train to read text lines instead of single words.
        early_stopping: number of epochs without improvement before training
            stops.
        dump: dump the output of the network to CSV file(s).

    Returns:
        ``0`` after training or validation, the recognized text in infer mode.

    Raises:
        ValueError: if ``mode`` is not one of the three supported values
            (previously the function silently returned ``None``).
        KeyError: if ``decoder`` is not a known decoder name.
    """
    if mode not in ('train', 'validate', 'infer'):
        raise ValueError(
            "Unknown mode {!r}: expected 'train', 'validate' or 'infer'".format(mode))

    # set CTC decoder
    decoder_mapping = {'bestpath': DecoderType.BestPath,
                       'beamsearch': DecoderType.BeamSearch,
                       'wordbeamsearch': DecoderType.WordBeamSearch}
    decoder_type = decoder_mapping[decoder]

    # train the model
    if mode == 'train':
        loader = DataLoaderIAM(data_dir, batch_size, fast=fast)

        # when in line mode, take care to have a whitespace in the char list
        char_list = loader.char_list
        if line_mode and ' ' not in char_list:
            char_list = [' '] + char_list

        # save characters and words
        with open(FileNames.fn_char_list, 'w') as f:
            f.write(''.join(char_list))
        with open(FileNames.fn_corpus, 'w') as f:
            f.write(' '.join(loader.train_words + loader.validation_words))

        model = Model(char_list, FileNames=FileNames, decoder_type=decoder_type)
        train(model, loader, line_mode=line_mode, early_stopping=early_stopping, FileNames = FileNames)
        return 0

    # evaluate it on the validation set
    if mode == 'validate':
        loader = DataLoaderIAM(data_dir, batch_size, fast=fast)
        model = Model(char_list_from_file(FileNames), FileNames=FileNames, decoder_type=decoder_type, must_restore=True)
        validate(model, loader, line_mode)
        return 0

    # infer text on test image
    model = Model(char_list_from_file(FileNames), FileNames=FileNames, decoder_type=decoder_type, must_restore=True, dump=dump)
    recognized = infer(model, img_file)
    # Drop the model explicitly before returning, as the original code did.
    del model
    return recognized
if __name__ == '__main__':
    # ocr_run() needs a FileNames object as first argument; calling it bare,
    # as before, always failed with a TypeError. Exit with an explanation.
    raise SystemExit("Ce module ne se lance pas seul : utilisez python3 main.py "
                     "ou appelez ocr_run(FileNames, ...).")
@@ -0,0 +1,295 @@ | |||
import numpy as np | |||
import matplotlib.pyplot as plt | |||
import matplotlib.image as mpimg | |||
import matplotlib.patches as patches | |||
import pickle | |||
import os | |||
import bin.caracter_recognition as ocr | |||
# Names of the kinds of information a template field can hold.
# NOTE(review): the order presumably matches the numeric ``type`` stored on
# each field (0 = delimited text, 1 = text start, 2 = exclusive check box,
# 3 = multiple-choice check boxes) - confirm against the callers.
informations_types = ["text_box", "text_begin", "1case", "xcases"]
class Coord():
    """Mutable 2-D point holding ``x`` and ``y`` attributes."""

    def __init__(self):
        self.x = 0
        self.y = 0

    def modifier(self, x ,y):
        """Set both coordinates."""
        self.x = x
        self.y = y

    def affine(self, multx, multy, offsetx, offsety):
        """Scale then shift the point in place, truncating each result to int."""
        self.x = int(self.x*multx + offsetx)
        self.y = int(self.y*multy + offsety)

    def reset(self):
        """Move the point back to (0, 0)."""
        self.x = 0
        self.y = 0

    def test_null(self):
        """Reset the point to (0, 0) when either coordinate is ``None``."""
        if self.x is None or self.y is None:
            self.reset()
class Coord_data():
    """Location of one piece of information on a form template.

    Attributes:
        name: label given to the information.
        type: numeric kind of information; values above 1 are check boxes
            and make the constructor ask for the number of boxes and their
            labels.
        content_if_checkbox: list of the labels typed for each check box.
        nb_boxes: number of rectangles expected for this information.
        box: list of ``[Coord, Coord]`` pairs (upper-left and lower-right
            corners), one pair per rectangle.
        temps_coordinates: first corner of the rectangle being drawn, or
            ``None`` when no rectangle is in progress.
    """

    def __init__(self, name, type):
        self.name = name
        self.type = type
        self.content_if_checkbox = []
        if self.type > 1:
            self.nb_boxes = 0
            while self.nb_boxes <= 0:
                try:
                    self.nb_boxes = int(input("Nombre de cases pouvant etre cochees : "))
                except ValueError:
                    # Not a number: ask again instead of crashing.
                    self.nb_boxes = 0
            for i in range(self.nb_boxes):
                text = "Intitule case n°" + str(i+1) + " : "
                # Keep every label (the list used to be overwritten by the
                # last answer).
                self.content_if_checkbox.append(str(input(text)))
        else:
            self.nb_boxes = 1 #default
        self.box = [] #upper left and lower right corners of the boxes, format [[Coord1(), Coord2()],...]
        self.temps_coordinates = None

    def box_coords_min_max(self):
        """Return ``(minx, miny, maxx, maxy)`` enclosing every rectangle.

        Raises:
            IndexError: if no rectangle has been defined yet.
        """
        minx, miny, maxx, maxy = self.box[0][0].x, self.box[0][0].y, self.box[0][1].x, self.box[0][1].y
        for subbox in self.box:
            minx = min(minx, subbox[0].x, subbox[1].x)
            miny = min(miny, subbox[0].y, subbox[1].y)
            maxx = max(maxx, subbox[0].x, subbox[1].x)
            maxy = max(maxy, subbox[0].y, subbox[1].y)
        return minx, miny, maxx, maxy

    def define_box(self, x1, x2, y1, y2, n=1):
        """Store rectangle number ``n`` (1-based).

        ``n == 1`` (or less) restarts the list with a single rectangle, a
        larger ``n`` appends a new rectangle before filling slot ``n - 1``.
        """
        if n > 1:
            self.box.append([Coord(), Coord()])
        else:
            self.box = [[Coord(), Coord()]]
        self.box[n-1][0].modifier(x1, y1)
        self.box[n-1][1].modifier(x2, y2)

    def define_box_begin(self, x1, y1, y2):
        """Store a single rectangle whose two corners share the abscissa ``x1``."""
        self.define_box(x1, x1, y1, y2)

    def define_ckeck_marks(self, liste, n):
        """Store ``n`` rectangles at once.

        ``liste`` is a (n, 4) list whose rows are passed to ``define_box``
        as ``x1, x2, y1, y2``.
        """
        for i in range(n):
            # define_box counts rectangles from 1 and is a method: the call
            # used to miss ``self.`` and pass the 0-based index.
            self.define_box(liste[i][0], liste[i][1], liste[i][2], liste[i][3], i + 1)

    def mouse_one_event(self, event):
        """Consume one click: the first sets a corner, the second closes the rectangle.

        Clicks without data coordinates (``xdata``/``ydata`` equal to
        ``None``) are stored as (0, 0).
        """
        if self.temps_coordinates is None: # first corner
            self.temps_coordinates = Coord()
            self.temps_coordinates.modifier(event.xdata, event.ydata)
            self.temps_coordinates.test_null()
            self.temps_coordinates.modifier(int(self.temps_coordinates.x), int(self.temps_coordinates.y))
        else: # second corner
            temp2 = Coord()
            temp2.modifier(event.xdata, event.ydata)
            temp2.test_null()
            x, y = int(temp2.x), int(temp2.y)
            if self.type != 1:
                self.define_box(self.temps_coordinates.x, x, self.temps_coordinates.y, y, len(self.box)+1)
            else:
                # type 1 only keeps the starting abscissa and the two ordinates
                self.define_box_begin(self.temps_coordinates.x, self.temps_coordinates.y, y)
            self.box[-1][1].test_null()
            self.temps_coordinates = None

    def mouse_event(self, event):
        """Matplotlib click callback: record clicks, close the figures when done."""
        if len(self.box) < self.nb_boxes or self.temps_coordinates is not None:
            self.mouse_one_event(event)
        if len(self.box) == self.nb_boxes and self.temps_coordinates is None:
            plt.close("all")
class Template_File():
    """A form template: its image plus the list of information locations.

    Attributes:
        path_template_img: path of the template image.
        template_img: the image itself once loaded (``0`` before).
        informations_template_objects: list of ``Coord_data`` objects.
        data_path_dir: path of the pickled information list once opened.
        template_name: name of the template.
    """

    def __init__(self):
        self.path_template_img = "" #complete path of the template image
        self.template_img = 0 #will be an image after init
        self.informations_template_objects = []
        self.data_path_dir = ""
        self.template_name = ""

    def open_files(self, path_template_img, path_template_obj):
        """Load a saved template: its image (averaged to one channel) and its pickled data."""
        self.path_template_img = path_template_img
        self.data_path_dir = path_template_obj
        img = plt.imread(path_template_img)
        self.template_img = (img[:,:,0]+img[:,:,1]+img[:,:,2])*255//3
        # NOTE: pickle can execute arbitrary code while loading; only open
        # template files produced by this application.
        with open(path_template_obj, 'rb') as infos_object:
            self.informations_template_objects = pickle.load(infos_object)
        self.template_name = path_template_obj[path_template_obj.rfind("/")+1:]

    def define_template_img(self):
        """Ask the user for the template image path and the template name."""
        self.path_template_img = str(input("Chemin complet de l'image PNG du modèle : "))
        self.template_img = mpimg.imread(self.path_template_img)
        self.template_name = str(input("Nom du template : "))

    def _draw_template_boxes(self):
        """Draw on the current axes the bounding rectangle and name of every stored information."""
        for object_data in self.informations_template_objects:
            if not object_data.box:
                # Nothing was drawn for this entry (window closed before
                # any click), so there is no rectangle to show.
                continue
            minx, miny, maxx, maxy = object_data.box_coords_min_max()
            rect = patches.Rectangle((minx, miny), maxx-minx, maxy-miny, linewidth=1, edgecolor='r', facecolor='none')
            plt.text(minx, miny, str(object_data.name), verticalalignment='top')
            plt.gca().add_patch(rect)

    def add_template_information(self):
        """Interactively add one information: ask its kind and name, then let the user click its rectangle(s)."""
        fig = plt.figure(num="Emplacement de la donnée")
        plt.imshow(self.template_img)
        plt.axis('off')
        self._draw_template_boxes()
        information_type = 0
        while information_type not in (1, 2, 3, 4):
            try:
                information_type = int(input(" 1. Information manuscrite delimitee\n 2. Début d'information manuscrite\n 3. Case exclusive\n 4. Cases à choix multiples\n-> "))
            except ValueError:
                # Not a number: ask again instead of crashing.
                information_type = 0
        information_type = information_type-1
        information_name = str(input("Catégorie de la donnée : "))
        new_information = Coord_data(information_name, information_type)
        self.informations_template_objects.append(new_information)
        cid = fig.canvas.mpl_connect("button_press_event", new_information.mouse_event)
        plt.show()
        fig.canvas.mpl_disconnect(cid)

    def show_template_boxes(self):
        """Display the template image with every stored information outlined."""
        plt.figure(num="Emplacement des données")
        plt.imshow(self.template_img)
        self._draw_template_boxes()
        plt.show()

    def save_data_file(self, folder_path, file_name, extension):
        """Pickle the information list to ``folder_path + file_name + extension``.

        folder_path is the absolute path of the folder; ``file_name`` is
        expected to start with "/". The user is asked for confirmation
        before an existing file is replaced.
        """
        target = folder_path+file_name+extension
        # The previous test looked for file_name + "." + extension in the
        # directory listing, which never matched an existing file.
        if os.path.exists(target):
            answer = str(input("Le fichier "+file_name+extension+" existe déjà. Le remplacer? O/N : "))
            if answer != "O":
                print("Fichier non enregistre.\n")
                return
        # Always store the bare list: open_files() reads it back as such.
        with open(target, "wb") as file:
            pickle.dump(self.informations_template_objects, file)

    def save_img_template(self, folder_path, file_name):
        """Copy the template image to ``folder_path + file_name``.

        When the destination already exists, both images are displayed and
        the user is asked for confirmation before it is replaced.
        """
        import shutil

        source_file = self.path_template_img
        destination_file = folder_path + file_name
        if os.path.exists(destination_file):
            figure = plt.figure()
            ax1 = figure.add_subplot(121)
            ax2 = figure.add_subplot(122)
            ax1.title.set_text("Image-template non-enregistree")
            ax2.title.set_text("Image deja enregistree")
            ax1.axis('off')
            ax2.axis('off')
            ax1.imshow(self.template_img)
            saved_image = mpimg.imread(destination_file)
            ax2.imshow(saved_image)
            figure.suptitle("Une images-templates existe deja avec ce nom. Fermez la fenetre.")
            plt.show()
            answer = str(input("Une images-templates existe deja avec ce nom. La remplacer? O/N : "))
            if answer != "O":
                print("Fichier non enregistre.\n")
                return
        # Portable copy: no shell command built from a user-typed path, and
        # it also works on Windows (where os.name is "nt").
        try:
            shutil.copyfile(source_file, destination_file)
        except shutil.SameFileError:
            # The template image already is the destination file.
            pass

    def save_template(self, folder_path, img_file_name, templ_file_name, extension):
        """Save the information list and the image into ``folder_path``, creating the folder if needed.

        folder_path is the absolute path of the folder.
        """
        os.makedirs(folder_path, exist_ok=True)
        self.save_data_file(folder_path, templ_file_name, extension)
        self.save_img_template(folder_path, img_file_name)
class Handwritten_Content():
    """Handwritten fields cut out of one scanned form, with their recognized text.

    Attributes:
        images: cropped image of each text field.
        name: name of each text field, in the same order.
        result: recognized text of each field, in the same order.
        File_Names: object forwarded to ``ocr.ocr_run`` as ``FileNames``.
    """

    def __init__(self, FilesNames):
        self.images = []
        self.name = []
        self.result = []
        self.File_Names = FilesNames

    def _recognize(self, field_name, crop):
        """Store one cropped field and the text recognized in it."""
        self.images.append(crop)
        self.name.append(field_name)
        sentence = ocr.ocr_run(img_file=crop, FileNames=self.File_Names)
        # ocr_run returns 0 for its non-inference modes: store an empty text.
        if sentence == 0:
            sentence = ""
        self.result.append(sentence)

    def extract_handwritten_content(self, template_object, img_template_resized, img_scanned, ratio, offset):
        """Crop every text field of the template out of the scan and recognize it.

        The coordinates stored in ``template_object`` are transformed IN
        PLACE to the scan (scale ``ratio``, shift ``offset``), so the same
        template object must not be reused for another scan.

        Args:
            template_object: template exposing ``template_img`` and
                ``informations_template_objects``.
            img_template_resized: template image resized to the scan size.
            img_scanned: scanned image the fields are cropped from.
            ratio: scale factor; inverted when the resized template is
                smaller than the original template.
            offset: ``(row offset, column offset)`` added after scaling.
        """
        if img_template_resized.shape[0] < template_object.template_img.shape[0]:
            ratio = 1/ratio
        fields = template_object.informations_template_objects
        for field in fields: # list of Coord_data objects
            for box in field.box: #for each box
                for coord in box: #for each coordinate
                    coord.affine(ratio, ratio, offset[1], offset[0])
        for field in fields:
            if field.type == 0: #if fully delimited
                top_left, bottom_right = field.box[0][0], field.box[0][1]
                crop = img_scanned[top_left.y:bottom_right.y, top_left.x:bottom_right.x]
            elif field.type == 1: #if beginning delimited
                top_left, bottom_right = field.box[0][0], field.box[0][1]
                xlimit = img_scanned.shape[1]-1
                crop = img_scanned[top_left.y:bottom_right.y, top_left.x:xlimit]
            else:
                # check boxes are not handled yet
                continue
            self._recognize(field.name, crop)
        # Debug display of the first two crops; it no longer fails when the
        # template holds fewer than two text fields.
        for image in self.images[:2]:
            plt.imshow(image)
            plt.show()
if __name__ == "__main__":
    # Manual, interactive smoke test: build a template from an image typed
    # by the user, place two pieces of information on it, then display them.
    template = Template_File()
    template.define_template_img()
    template.add_template_information()
    template.add_template_information()
    template.show_template_boxes()
    # Saving is left disabled in this manual test.
    #template.save_template("/home/inc0nnu-rol/Documents/La Gemme/OCR_paper_form/files", "/formulaire1", ".opdf")
    print("Execution success")
@@ -0,0 +1,233 @@ | |||
import os | |||
import numpy as np | |||
import matplotlib.pyplot as plt | |||
import scipy as scp | |||
def levenshteinDistance(s1, s2):
    """Return the edit distance (insertions, deletions, substitutions) between two sequences."""
    # Work row by row over the longer sequence; a row is as wide as the shorter one.
    shorter, longer = (s2, s1) if len(s1) > len(s2) else (s1, s2)
    previous_row = range(len(shorter) + 1)
    for row_number, longer_item in enumerate(longer, start=1):
        current_row = [row_number]
        for column, shorter_item in enumerate(shorter):
            if shorter_item == longer_item:
                cost = previous_row[column]
            else:
                cost = 1 + min((previous_row[column], previous_row[column + 1], current_row[-1]))
            current_row.append(cost)
        previous_row = current_row
    return previous_row[-1]
def unit_gaussian_filter(sigma):
    """Create and return a 2D square gaussian filter.

    The filter is sampled on the integer grid from ``-3*int(sigma+1)`` to
    ``3*int(sigma+1)`` in both directions, with the density
    ``exp(-(x^2 + y^2) / (2 sigma^2)) / (2 pi sigma^2)``.

    Args:
        sigma: standard deviation of the gaussian, strictly positive.

    Returns:
        Square 2-D float array of side ``6*int(sigma+1) + 1``.

    Raises:
        ValueError: if ``sigma`` is not strictly positive.
    """
    if sigma <= 0:
        raise ValueError("sigma must be strictly positive")
    half_width = 3*int(sigma+1)
    x_grid = np.linspace(-half_width, half_width, 2*half_width+1)
    X, Y = np.meshgrid(x_grid, x_grid)
    # Evaluated on the whole grid at once instead of pixel by pixel.
    return 1 / (2*np.pi*sigma*sigma) * np.exp(-(X*X + Y*Y) / (2*sigma*sigma))
def calculate_img_correlation(img1, img2):
    """Cross-correlate two 2-D images and locate the maximum.

    The larger image (by number of elements) is used as the reference, so
    the correlation map has its shape (mode "same").

    Returns:
        Tuple ``(max_corr, index)`` where ``max_corr`` is the maximum of the
        correlation map and ``index`` is the ``np.where`` result (a tuple of
        index arrays, one per axis) of the positions reaching it.
    """
    # Import the submodule explicitly: ``import scipy`` alone does not
    # guarantee that ``scipy.signal`` is loaded.
    from scipy import signal

    if img2.size > img1.size:
        img2, img1 = img1, img2
    img_corr = signal.correlate2d(img1, img2, "same")
    max_corr = np.amax(img_corr)
    index = np.where(img_corr == max_corr)
    return max_corr, index
def otsu_algo(img):
    """Compute a threshold separating the lower and upper values of an image.

    A 255-bin histogram over [0, 255] is built, a separation criterion is
    evaluated for every bin and the lower edge of a bin derived from the
    best one is returned.

    NOTE(review): the criterion below is not the textbook Otsu between-class
    variance, and the best bin is rescaled by ``max(img) / 255`` before the
    edge lookup - confirm that both are intended.
    """
    histogram, bin_edges = np.histogram(img, bins=255, range=(0,255))
    nb_bins = len(histogram)
    total = np.sum(histogram)
    # Cumulative weight and cumulative first moment (bins counted from 1).
    weights = np.cumsum(histogram)/total
    moments = np.cumsum(np.arange(1, nb_bins+1)*histogram)/total
    global_moment = moments[nb_bins-1]
    criterion = weights*(global_moment-moments)*(global_moment-moments) + (1-weights)*moments*moments
    # First bin reaching the maximum of the criterion.
    best_bins = np.where(criterion == np.amax(criterion))[0]
    edge_index = int(np.floor(best_bins[0]/nb_bins*np.amax(img)))
    return bin_edges[edge_index]
def img_cells_zero(img_origin, img_gauss, threshold):
    """Zero, in place, the pixels of ``img_origin`` whose ``img_gauss`` value is <= ``threshold``.

    Args:
        img_origin: image modified in place (nothing is returned).
        img_gauss: image of the same height and width (or larger) used for
            the comparison.
        threshold: pixels of ``img_gauss`` at or below it are switched off.
    """
    rows, cols = img_origin.shape[0], img_origin.shape[1]
    # One boolean mask instead of a Python loop over every pixel.
    img_origin[np.asarray(img_gauss)[:rows, :cols] <= threshold] = 0
def img_one_border_remove(img, start="begin", axis=0):
    """Find the limit of a black border in a 2-D image.

    The image is summed along ``axis``; the result is scanned from its
    beginning (``start="begin"``) or from its end (any other value) until a
    non-zero sum is met.

    Args:
        img: 2-D image whose black pixels are 0.
        start: "begin" for the upper/left border, "end" for the lower/right one.
        axis: axis the image is summed along (0 gives a column index, 1 a
            row index).

    Returns:
        Index of the first (or last) non-zero sum.

    Raises:
        IndexError: if every sum is zero (fully black image).
    """
    axis_sum = np.sum(img, axis=axis)
    content = np.flatnonzero(axis_sum != 0)
    if content.size == 0:
        # The manual scan used to run off the array (wrapping through
        # negative indices for "end") before failing.
        raise IndexError("image is entirely black: no border limit found")
    return int(content[0] if start == "begin" else content[-1])
def img_black_borders_remove(img):
    """Crop the black (all-zero) borders of a 2-D image.

    Returns:
        The slice of ``img`` (not a copy) going from the first to the last
        non-black row and column, both included.
    """
    first_col = img_one_border_remove(img)
    first_row = img_one_border_remove(img, axis=1)
    last_col = img_one_border_remove(img, start="end")
    last_row = img_one_border_remove(img, start="end", axis=1)
    # Slice ends are exclusive: +1 keeps the last non-black row and column,
    # which were previously cut off.
    return img[first_row:last_row + 1, first_col:last_col + 1]
def resize(img, shape):
    """Resize a 2-D image to a different shape with the bilinear method.

    Args:
        img: 2-D image.
        shape: ``(height, width)`` of the result.

    Returns:
        Float array of the requested shape; its corners coincide with the
        corners of ``img``.
    """
    img = np.asarray(img)
    # Fractional source coordinates of every destination row / column.
    y_vect = np.linspace(0, img.shape[0]-1, shape[0])
    x_vect = np.linspace(0, img.shape[1]-1, shape[1])
    y_floor = np.floor(y_vect)
    x_floor = np.floor(x_vect)
    y0 = y_floor.astype(int)[:, np.newaxis]
    y1 = np.ceil(y_vect).astype(int)[:, np.newaxis]
    x0 = x_floor.astype(int)[np.newaxis, :]
    x1 = np.ceil(x_vect).astype(int)[np.newaxis, :]
    # Distance to the lower neighbour, used as interpolation weight.
    y_coma = (y_vect - y_floor)[:, np.newaxis]
    x_coma = (x_vect - x_floor)[np.newaxis, :]
    # Same weighted sum as before, computed for all pixels at once.
    img_final = ((1-x_coma)*(1-y_coma)*img[y0, x0] + (x_coma)*(1-y_coma)*img[y0, x1]
                 + (1-x_coma)*(y_coma)*img[y1, x0] + (x_coma)*(y_coma)*img[y1, x1])
    return np.asarray(img_final, dtype=float)
def signal_processing_process(img_template, img_scanned):
    """Align the template with a scanned form.

    The scan is smoothed, thresholded (``img_scanned`` is modified IN PLACE:
    its pixels under the threshold are zeroed) and cropped to its content.
    The template is then resized with the height ratio and with the width
    ratio; the candidate correlating best with the cropped scan is kept.

    Args:
        img_template: 2-D template image.
        img_scanned: 2-D scanned image.

    Returns:
        Tuple ``(img_final, im_trunc, ratio, offset)``: the resized template,
        the cropped scan, the retained template/scan size ratio and the
        ``(row, column)`` offset between the centre of the resized template
        and the first correlation maximum, as plain ints.
    """
    # Import the submodule explicitly: ``import scipy`` alone does not
    # guarantee that ``scipy.signal`` is loaded.
    from scipy import signal

    img_gauss = signal.correlate2d(img_scanned, unit_gaussian_filter(1.5), "same")
    thresh = otsu_algo(img_gauss)
    img_cells_zero(img_scanned, img_gauss, thresh)
    im_trunc = img_black_borders_remove(img_scanned)

    ratios = (img_template.shape[0]/im_trunc.shape[0], img_template.shape[1]/im_trunc.shape[1])
    shape1 = (int(img_template.shape[0]/ratios[0]), int(img_template.shape[1]/ratios[0]))
    shape2 = (int(img_template.shape[0]/ratios[1]), int(img_template.shape[1]/ratios[1]))

    # First candidate: always evaluated, so the result is defined even when
    # both ratios lead to the same shape (this case used to crash).
    img_final = resize(img_template, shape1)
    max1, location = calculate_img_correlation(img_final, im_trunc)
    ratio = ratios[0]
    if shape1[0] != shape2[0] or shape1[1] != shape2[1]:
        img_final_2 = resize(img_template, shape2)
        max2, location2 = calculate_img_correlation(img_final_2, im_trunc)
        if not max1 > max2:
            img_final = img_final_2
            ratio = ratios[1]
            location = location2

    # ``location`` holds index arrays; keep the first maximum as plain ints.
    offset = (img_final.shape[0]//2 - int(location[0][0]),
              img_final.shape[1]//2 - int(location[1][0]))
    return img_final, im_trunc, ratio, offset
# to be tried: | |||
# - delete black borders (gauss to highlight written content, blacken and delete what is not the paper) | |||
# - resize with root method | |||
# no rotation ! | |||
if __name__ == "__main__":
    # Manual test on local images: align the template (test.png) with a
    # scan (test3.png) and display the resized template.
    im1 = plt.imread("/home/inc0nnu-rol/Documents/La Gemme/OCR_paper_form/test.png")
    im2 = plt.imread("/home/inc0nnu-rol/Documents/La Gemme/OCR_paper_form/test2.png")
    im3 = plt.imread("/home/inc0nnu-rol/Documents/La Gemme/OCR_paper_form/test3.png")
    # Average the first three channels of each image.
    im1 = (im1[:,:,0]+im1[:,:,1]+im1[:,:,2])*255//3
    im2 = (im2[:,:,0]+im2[:,:,1]+im2[:,:,2])*255//3
    im3 = (im3[:,:,0]+im3[:,:,1]+im3[:,:,2])*255//3
    # signal_processing_process returns four values; unpacking three raised
    # a ValueError.
    img_fin, im_trunc, ratios, offset = signal_processing_process(im1, im3)
    print(ratios, offset)
    plt.imshow(img_fin)
    plt.show()
@@ -0,0 +1,92 @@ | |||
import os | |||
import numpy as np | |||
import matplotlib.pyplot as plt | |||
import bin.caracter_recognition as cocr | |||
import bin.signal_processing as ts | |||
import bin.form_data_places as form | |||
class FileNames:
    """File and folder locations used by the application, all relative to this file."""

    # Names used inside a template folder (prefixed with "/" so they can be
    # appended directly to a folder path).
    template_object_name = "/template"
    template_image_name = "/template_img.png"
    template_extension = ".opdf"

    # Folder containing this file and its working sub-folders.
    general_path = os.path.abspath(os.path.dirname(__file__))
    scanned_path = "/".join((general_path, "scanned"))
    templates_path = "/".join((general_path, "templates"))

    # Text-recognition model files.
    ocr_model_path = "/".join((general_path, "OCR_model"))
    fn_model_path = "/".join((general_path, "OCR_model"))
    fn_dump_path = "/".join((general_path, "dump"))
    fn_char_list = "/".join((general_path, "OCR_model", "charList.txt"))
    fn_summary = "/".join((general_path, "OCR_model", "summary.json"))
    fn_corpus = "/".join((general_path, "data", "corpus.txt"))
def image_scanned_processing(scanned_img_path, template_path):
    """Extract the handwritten content of one scanned form.

    Args:
        scanned_img_path: path of the scanned image (at least three channels).
        template_path: folder of the template to use.

    Returns:
        The ``Handwritten_Content`` object filled from the scan.
    """
    # Load the template (image + pickled field locations) from its folder.
    template = form.Template_File()
    template.open_files(
        template_path + FileNames.template_image_name,
        template_path + FileNames.template_object_name + FileNames.template_extension,
    )

    # Load the scan and average its first three channels.
    rgb = plt.imread(scanned_img_path)
    grey = (rgb[:,:,0]+rgb[:,:,1]+rgb[:,:,2])*255//3

    # Align template and scan, then read every field.
    resized_template, cropped_scan, ratio, offset = ts.signal_processing_process(template.template_img, grey)
    content = form.Handwritten_Content(FileNames)
    content.extract_handwritten_content(template, resized_template, cropped_scan, ratio, offset)
    return content
def _ask_int(prompt):
    """Prompt until the user types an integer and return it."""
    while True:
        try:
            return int(input(prompt))
        except ValueError:
            print("Entree invalide, merci de saisir un nombre.")


def _scan_forms():
    """Menu entry 1: choose a template, then process every PNG of the scan folder."""
    template_list = os.listdir(FileNames.templates_path)
    if not template_list:
        print("Aucun modele disponible.")
        return []
    for i, template_name in enumerate(template_list, start=1):
        print(i, template_name)
    template_choice = _ask_int("Modele a utiliser : ")
    while not 1 <= template_choice <= len(template_list):
        template_choice = _ask_int("Modele a utiliser : ")
    template_path = FileNames.templates_path+"/"+template_list[template_choice-1]
    print("Merci de verifier que tous les fichiers scannes sont:\n- en format PNG\n- dans le dossier interne \"scanned\" du logiciel\nSinon définissez le chemin complet a utiliser, ou appuyez sur Entree")
    path = str(input("->"))
    if path != "":
        FileNames.scanned_path = path
    handwritten = []
    for img_file in os.listdir(FileNames.scanned_path):
        # Only PNG scans are supported; skip anything else in the folder.
        if not img_file.lower().endswith(".png"):
            continue
        content = image_scanned_processing(FileNames.scanned_path + "/" + img_file, template_path)
        handwritten.append(content)
        # Show what was read, otherwise the extracted data is lost.
        print(img_file)
        for field_name, text in zip(content.name, content.result):
            print(" ", field_name, ":", text)
    return handwritten


def _create_template():
    """Menu entry 2: build a new template interactively and save it."""
    template_object = form.Template_File()
    template_object.define_template_img()
    add = "O"
    while add != "N":
        template_object.add_template_information()
        add = str(input("Ajouter une information? O/N : "))
    template_object.show_template_boxes()
    template_object.save_template(FileNames.templates_path + "/" + template_object.template_name, FileNames.template_image_name, FileNames.template_object_name, FileNames.template_extension)


def menu():
    """Run the terminal menu until the user chooses to quit (entry 4).

    Entry 1 scans forms, entry 2 creates a template; entry 3 (template
    modification) is not implemented and simply shows the menu again.
    """
    print(" OCR_paper_form 0.1 terminal\nCree par Lilian RM pour la Gemme\n License MIT\n")
    answer = _ask_int(" 1. Scanner formulaires\n 2. Creer modele\n 3. Modifier modele\n 4. Quitter\n->")
    while answer != 4:
        if answer == 1:
            _scan_forms()
        elif answer == 2:
            _create_template()
        answer = _ask_int("\n 1. Scanner formulaires\n 2. Creer modele\n 3. Modifier modele\n 4. Quitter\n->")
    print("Execution success")
if __name__ == "__main__":
    # Script entry point: start the interactive terminal menu.
    menu()