mirror of
https://gitee.com/sui-feng-cb/AzurLaneAutoScript1
synced 2026-03-11 23:18:22 +08:00
Opt: Delete low confidence result in OCR
This commit is contained in:
@@ -5,6 +5,7 @@ import numpy as np
|
||||
from PIL import Image
|
||||
from cnocr import CnOcr
|
||||
from cnocr.cn_ocr import data_dir, read_charset, check_model_name, load_module, gen_network
|
||||
from cnocr.fit.ctc_metrics import CtcMetrics
|
||||
from cnocr.hyperparams.cn_hyperparams import CnHyperparams as Hyperparams
|
||||
|
||||
from module.logger import logger
|
||||
@@ -109,6 +110,29 @@ class AlOcr(CnOcr):
|
||||
img = np.expand_dims(img, 0).astype('float32') / 255.0
|
||||
return img
|
||||
|
||||
def _gen_line_pred_chars(self, line_prob, img_width, max_img_width):
    """
    Decode the per-step class probabilities of one text line into characters.

    :param line_prob: with shape of [seq_length, num_classes]
    :param img_width: compared against max_img_width; when smaller, every
        step from ``img_width // seq_len_cmpr_ratio`` onward is reset to
        class 0
    :param max_img_width: upper bound that img_width is compared against
    :return: list of characters, with '<space>' entries replaced by ' '
    """
    best_ids = np.argmax(line_prob, axis=-1)
    best_prob = np.max(line_prob, axis=-1)

    # Delete low confidence result: any step whose top probability is not
    # above 0.1 is forced to class 0 (presumably the CTC blank -- confirm).
    best_ids[~(best_prob > 0.1)] = 0

    if img_width < max_img_width:
        end_idx = img_width // self._hp.seq_len_cmpr_ratio
        if end_idx < len(best_ids):
            best_ids[end_idx:] = 0

    # ctc_label also returns start/end indices, which are not needed here.
    prediction, _ = CtcMetrics.ctc_label(best_ids.tolist())

    alphabet = self._alphabet
    chars = []
    for idx in prediction:
        char = alphabet[idx]
        chars.append(' ' if char == '<space>' else char)

    return chars
|
||||
|
||||
def debug(self, img_list):
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -64,10 +64,6 @@ class Ocr:
|
||||
"""
|
||||
image = extract_letters(image, letter=self.letter, threshold=self.threshold)
|
||||
|
||||
|
||||
# un-comment to see the image feed to ocr model
|
||||
# Image.fromarray(image.astype('uint8')).show()
|
||||
|
||||
return image.astype(np.uint8)
|
||||
|
||||
def after_process(self, result):
|
||||
@@ -82,12 +78,15 @@ class Ocr:
|
||||
|
||||
return result
|
||||
|
||||
def ocr(self, image):
|
||||
def ocr(self, image, direct_ocr=False):
|
||||
start_time = time.time()
|
||||
|
||||
if self.alphabet is not None:
|
||||
self.cnocr.set_cand_alphabet(self.alphabet)
|
||||
image_list = [self.pre_process(np.array(image.crop(area))) for area in self.buttons]
|
||||
if direct_ocr:
|
||||
image_list = [self.pre_process(np.array(i)) for i in image]
|
||||
else:
|
||||
image_list = [self.pre_process(np.array(image.crop(area))) for area in self.buttons]
|
||||
|
||||
# This will show the images feed to OCR model
|
||||
# self.cnocr.debug(image_list)
|
||||
|
||||
Reference in New Issue
Block a user