1
0
mirror of https://gitee.com/sui-feng-cb/AzurLaneAutoScript1 synced 2026-04-30 16:00:14 +08:00
Files
AzurLaneAutoScript/module/base/ocr.py

174 lines
6.0 KiB
Python
Raw Normal View History

import time
import cv2
import numpy as np
from PIL import Image
from cnocr import CnOcr
from module.base.button import Button
from module.base.utils import extract_letters
from module.logger import logger
# cnocr models instantiated once at import time, keyed by the `lang` name that Ocr looks up.
OCR_MODELS = {
    # Font: Impact, AgencyFB
    # Charset: 0123456789
    'digit': CnOcr(root='./cnocr_models/digit', model_epoch=60),
    # Font: Impact
    # Charset: 0123456789ABCDEFSP-:/
    'stage': CnOcr(root='./cnocr_models/stage', model_epoch=56),
    'cnocr': CnOcr(root='./cnocr_models/cnocr', model_epoch=20)
}
# Model input size as (width, height): Ocr.pre_process resizes to height image_shape[1]
# and then pads or crops the width to image_shape[0].
image_shape = (280, 32)
# NOTE(review): the four settings below are not referenced in the visible part of this module;
# presumably they are consumed elsewhere (e.g. training data generation) - confirm before changing.
width_range = (0.6, 1.4)
text_length = (1, 6)
text_interval = (0, 10)
y_range = (-2, 2)
class Ocr:
    """Recognise single-line text inside one or more button areas of a screenshot.

    Every button area is cropped from the screenshot, normalised by
    `pre_process`, recognised by the cnocr model selected with `lang`, and the
    recognised characters are joined and sanity-checked by `after_process`.
    """

    def __init__(self, buttons, lang, letter=(255, 255, 255), back=(0, 0, 0), mid_process_height=70, threshold=127,
                 additional_preprocess=None, use_binary=True, length=None, white_list=None, name='OCR'):
        """
        Args:
            buttons (Button, List[Button]): Button or list of Button instance.
            lang (str): OCR model, a key of OCR_MODELS. in ['digit', 'stage', 'cnocr'].
            letter (tuple(int)): Letter RGB.
            back (tuple(int)): Background RGB.
            mid_process_height (int): Height the crop is resized to before letter extraction. Default 70.
            threshold (int): Binarization threshold, only used when `use_binary` is True.
            additional_preprocess (callable): Optional hook applied to the image after letter extraction,
                see `additional_preprocess_example` for the expected signature.
            use_binary (bool): Whether to binarize the image with `threshold`.
            length (int, tuple(int)): Expected length. An int means exactly that length,
                a tuple means (min, max), both inclusive.
            white_list (str): Expected str. Characters outside of it only trigger a warning.
            name (str): Label used in logs. Ignored when `buttons` is a single Button,
                in which case str(buttons) is used.

        Raises:
            KeyError: If `lang` is not a key of OCR_MODELS.
        """
        self.lang = lang
        self.cnocr = OCR_MODELS[lang]
        self.letter = letter
        self.back = back
        self.mid_process_height = mid_process_height
        self.threshold = threshold
        self.additional_preprocess = additional_preprocess
        self.use_binary = use_binary
        # Normalise an exact length to a (min, max) pair so after_process has a single code path.
        self.length = (length, length) if isinstance(length, int) else length
        self.white_list = white_list
        self.buttons = buttons if isinstance(buttons, list) else [buttons]
        self.name = str(buttons) if isinstance(buttons, Button) else name

    def additional_preprocess_example(self, image):
        """
        Signature template for the `additional_preprocess` hook. Does nothing itself.

        Args:
            image (np.ndarray): data range: [0, 255], dtype: float. shape: [?, 70]

        Returns:
            np.ndarray: data range: [0, 255], dtype: float.
        """
        pass

    def pre_process(self, image):
        """
        Convert a cropped screenshot into the fixed-size array fed to the OCR model.

        Args:
            image: A cropped screenshot. Must provide `size` and `resize()` (used here like a PIL image).

        Returns:
            np.ndarray: shape: [32, 280], i.e. (image_shape[1], image_shape[0]). data range: [0, 1]
        """
        # Resize to height=mid_process_height (70 by default), keeping the aspect ratio.
        size = (int(image.size[0] / image.size[1] * self.mid_process_height), self.mid_process_height)
        image = image.resize(size, Image.BILINEAR)

        # Set letter color to black, set background color to white.
        image = extract_letters(image, letter=self.letter, back=self.back)

        # Additional preprocess.
        if self.additional_preprocess is not None:
            image = self.additional_preprocess(image)

        # Binarization.
        if self.use_binary:
            _, image = cv2.threshold(image, self.threshold, 255, cv2.THRESH_BINARY)

        # Resize to input size, height=image_shape[1], keeping the aspect ratio.
        size = (int(image.shape[1] / image.shape[0] * image_shape[1]), image_shape[1])
        image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)

        # Left align: find the first column that is not almost pure white and crop to 2px before it.
        x = np.where(np.mean(image, axis=0) < 220)[0]
        if len(x):
            # NOTE(review): with the `>= 2` bound nothing is cropped unless the text starts at column 4
            # or later. Possibly `>= 0` was intended; kept as is because it changes the model input.
            x = x[0] - 2 if x[0] - 2 >= 2 else 0
            image = image[:, x:]

        # Pad with white, or crop, to image_shape=(280, 32)
        diff_x = image_shape[0] - image.shape[1]
        if diff_x > 0:
            image = np.pad(image, ((0, 0), (0, diff_x)), mode='constant', constant_values=255)
        else:
            image = image[:, :image_shape[0]]

        return image / 255.0

    def after_process(self, result):
        """
        Join the recognised characters and warn about unexpected results.
        Unexpected length or letters are only logged, the result is returned unchanged.

        Args:
            result (list[str]): ['', '', '']

        Returns:
            str:
        """
        result = ''.join(result)

        if self.length is not None:
            if len(result) > self.length[1] or len(result) < self.length[0]:
                logger.warning(f'OCR result length unexpected. Expect: {self.length}. Result: {len(result)}')

        if self.white_list:
            for letter in result:
                if letter not in self.white_list:
                    logger.warning(f'OCR letter unexpected. Letter: {letter}. White_list: {self.white_list}')

        return result

    def ocr(self, image):
        """
        Run OCR on every button area of a screenshot and log the result with the time cost.

        Args:
            image: Screenshot. Must provide `crop(area)`.

        Returns:
            The processed result of the only button if there is exactly one button,
            otherwise a list with one processed result per button.
        """
        # perf_counter is monotonic, so the measured cost is not affected by system clock changes.
        start_time = time.perf_counter()

        image_list = [self.pre_process(image.crop(button.area)) for button in self.buttons]
        result_list = self.cnocr.ocr_for_single_lines(image_list)
        result_list = [self.after_process(result) for result in result_list]

        if len(self.buttons) == 1:
            result_list = result_list[0]

        cost = str(round(time.perf_counter() - start_time, 3)).ljust(5, '0')
        logger.attr(name=f'{self.name} {cost}s', text=str(result_list))

        return result_list
class Digit(Ocr):
    """Ocr that always uses the 'digit' model and returns integers, optionally clamped to a range."""

    def __init__(self, buttons, letter=(255, 255, 255), back=(0, 0, 0), mid_process_height=70, threshold=127,
                 additional_preprocess=None, length=None, white_list=None, limit=None, name='OCR'):
        """
        Args:
            buttons (Button, List[Button]): Button or list of Button instance.
            letter (tuple(int)): Letter RGB.
            back (tuple(int)): Background RGB.
            mid_process_height (int): Height the crop is resized to before letter extraction.
            threshold (int): Binarization threshold.
            additional_preprocess (callable): Optional hook applied to the image after letter extraction.
            length (int, tuple(int)): Expected length.
            white_list (str): Expected str.
            limit (int, tuple(int)): Expected value range. An int means (0, limit),
                a tuple means (min, max). Results outside of it are clamped.
            name (str): Label used in logs.
        """
        super().__init__(buttons=buttons, lang='digit', letter=letter, back=back, mid_process_height=mid_process_height,
                         threshold=threshold,
                         additional_preprocess=additional_preprocess, length=length, white_list=white_list, name=name)
        self.limit = (0, limit) if isinstance(limit, int) else limit

    def after_process(self, raw):
        """
        Convert the recognised characters to an int and clamp it into `limit`.

        Args:
            raw (list[str]): Recognised characters.

        Returns:
            int: 0 if nothing was recognised.

        Raises:
            ValueError: If the recognised text is not a valid integer literal.
        """
        raw = super().after_process(raw)
        result = int(raw) if raw else 0

        if self.limit:
            # Log the bound that is actually returned, not the out-of-range value being discarded.
            if result < self.limit[0]:
                result = self.limit[0]
                logger.info(f'OCR result smaller than expected. Expect: {self.limit}. Raw: {raw}. Treat as: {result}')
            if result > self.limit[1]:
                result = self.limit[1]
                logger.info(f'OCR result bigger than expected. Expect: {self.limit}. Raw: {raw}. Treat as: {result}')

        return result