"""Name-based gender prediction backed by a saved Keras model."""

__copyright__ = "Copyright (C) 2023 Ali Mustapha"
__license__ = "GPL-3.0-or-later"

import logging
import re
import unicodedata

import numpy as np
import tensorflow as tf
from unidecode import unidecode

from utils import data_utils

logger = logging.getLogger(__name__)


class GenderPredictor:
    """Predict a gender label for a name (or name-like string) with a Keras model.

    ``predict_gender`` returns a ``(label, proba)`` pair. Label ``2`` is the
    fallback label: it is returned for inputs rejected before reaching the
    model (e-mail addresses, short/non-alphabetic strings), when the model's
    column 2 is at least 0.1, and when prediction fails. Presumably it means
    "unknown" -- TODO confirm against the model's training labels; the meaning
    of the other label indices is not visible in this file.
    """

    # An e-mail address: a single "@" with non-whitespace text on both sides.
    _EMAIL_RE = re.compile(r"^[^\s@]+@[^\s@]+$")
    # Each separator character becomes a space so it acts as a word boundary.
    # A dict is used so the table cannot end up with mismatched key/value
    # lengths (str.maketrans(x, y) requires len(x) == len(y)).
    _SEPARATORS_TO_SPACE = str.maketrans({ch: " " for ch in "-._\\/+"})
    # Deletes ASCII digits.
    _STRIP_DIGITS = str.maketrans("", "", "0123456789")

    # Label returned for rejected inputs and failed predictions.
    _FALLBACK_LABEL = 2
    # If the model gives column 2 at least this probability, it wins outright.
    _FALLBACK_THRESHOLD = 0.1

    def __init__(self, model_path):
        """Load the Keras model stored at *model_path*."""
        self.model = self.load_model(model_path)

    def load_model(self, path):
        """Load a saved Keras model from *path* and compile it.

        Only ``compile`` is called here (loss, optimizer and metrics are
        attached); the model is not trained.
        """
        model = tf.keras.models.load_model(path)
        model.compile(
            loss=tf.keras.losses.categorical_crossentropy,
            optimizer=tf.keras.optimizers.Adam(),
            metrics=["accuracy"],
        )
        return model

    def predict_gender(self, name):
        """Predict a gender label for *name*.

        Args:
            name: raw input string (a person's name, a username, an e-mail...).

        Returns:
            A ``(prediction, proba)`` tuple: ``prediction`` is an integer label
            index and ``proba`` an integer percentage (``100`` whenever the
            fallback label 2 is chosen or the input is rejected).
        """
        if self._EMAIL_RE.match(name):
            return self._FALLBACK_LABEL, 100

        name = self._normalize(name)
        if self._is_rejected(name):
            return self._FALLBACK_LABEL, 100

        name = name.translate(self._STRIP_DIGITS)
        tokens = name.split()
        if not tokens:
            # Nothing left once digits are removed: there is no name to classify.
            return self._FALLBACK_LABEL, 100
        # Classify only the first word when it is longer than two characters;
        # otherwise (e.g. initials) keep the whole string.
        if len(tokens[0]) > 2:
            name = tokens[0]

        try:
            predictions_proba = self.model.predict([name], verbose=0).astype("float")
            prediction, proba = self.get_label(predictions_proba)
            if prediction == self._FALLBACK_LABEL:
                parts = name.split()
                if len(parts) > 1:
                    # The whole string got the fallback label: classify each
                    # word separately and combine the word labels with
                    # data_utils.find_common_item (presumably a majority vote
                    # -- TODO confirm). Only the label of each recursive call
                    # is used, not its (label, proba) tuple.
                    part_labels = [self.predict_gender(part)[0] for part in parts]
                    prediction = data_utils.find_common_item(part_labels)
        except Exception:
            # Best-effort: a failed prediction degrades to the fallback label,
            # but is logged instead of being silently swallowed.
            logger.warning("Gender prediction failed for %r", name, exc_info=True)
            return self._FALLBACK_LABEL, 100
        return prediction, proba

    def _normalize(self, name):
        """Turn separator characters into spaces, romanize, and trim *name*."""
        name = name.translate(self._SEPARATORS_TO_SPACE)
        name = data_utils.text_to_romanize(name)
        return data_utils.remove_spaces_from_ends(name)

    def _is_rejected(self, name):
        """Return True when *name* should get the fallback label without the model.

        The checks are delegated to ``data_utils`` helpers whose exact
        semantics are defined elsewhere (e.g. ``is_most_common_char``
        presumably flags strings dominated by one repeated character --
        TODO confirm).
        """
        short_or_repetitive = len(name) < 3 or data_utils.is_most_common_char(name)
        if short_or_repetitive and data_utils.is_roman_language(name):
            return True
        return not data_utils.is_alpha(name)

    def get_label(self, predictions_proba):
        """Map model output probabilities to a ``(label, proba)`` pair.

        Column 2 is chosen whenever its probability is at least 0.1; otherwise
        the arg-max column is chosen. ``proba`` is the chosen column's
        probability as a truncated integer percentage, or ``100`` when label 2
        is chosen.

        Only the last row of *predictions_proba* determines the result;
        ``predict_gender`` passes the output for a single name, i.e. one row.

        Raises:
            ValueError: if *predictions_proba* has no rows.
        """
        if len(predictions_proba) == 0:
            raise ValueError("predictions_proba contains no rows")
        row = predictions_proba[-1]
        if row[2] >= self._FALLBACK_THRESHOLD:
            return self._FALLBACK_LABEL, 100
        max_index = int(np.argmax(row))
        if max_index == self._FALLBACK_LABEL:
            return max_index, 100
        return max_index, int(row[max_index] * 100)