diff --git a/config.py b/config.py index 7079fe2c..be98a68f 100644 --- a/config.py +++ b/config.py @@ -7,7 +7,6 @@ import tkinter as tk from tkinter import font from models.translation.translation_languages import translation_lang from models.transcription.transcription_utils import getInputDevices, getDefaultInputDevice -from models.translation.utils import ctranslate2_weights from utils import generatePercentageStringsList, isUniqueStrings json_serializable_vars = {} @@ -95,10 +94,6 @@ class Config: def SELECTABLE_UI_LANGUAGES_DICT(self): return self._SELECTABLE_UI_LANGUAGES_DICT - @property - def CTRANSLATE2_WEIGHTS(self): - return self._CTRANSLATE2_WEIGHTS - @property def MAX_MIC_ENERGY_THRESHOLD(self): return self._MAX_MIC_ENERGY_THRESHOLD @@ -731,7 +726,6 @@ class Config: "ko": "한국어" # If you want to add a new language and key, please append it here. } - self._CTRANSLATE2_WEIGHTS = ctranslate2_weights self._MAX_MIC_ENERGY_THRESHOLD = 2000 self._MAX_SPEAKER_ENERGY_THRESHOLD = 4000 @@ -805,7 +799,7 @@ class Config: self._AUTH_KEYS = { "DeepL_API": None, } - self._WEIGHT_TYPE = "small" + self._WEIGHT_TYPE = "m2m100_418m" self._SEND_MESSAGE_FORMAT = "[message]" self._SEND_MESSAGE_FORMAT_WITH_T = "[message]([translation])" self._RECEIVED_MESSAGE_FORMAT = "[message]" diff --git a/main.py b/main.py index 00bd351a..f990a89d 100644 --- a/main.py +++ b/main.py @@ -9,7 +9,7 @@ if __name__ == "__main__": from config import config from models.translation.utils import downloadCTranslate2Weight - downloadCTranslate2Weight(config.PATH_LOCAL, config.WEIGHT_TYPE, config.CTRANSLATE2_WEIGHTS, splash.updateDownloadProgress) + downloadCTranslate2Weight(config.PATH_LOCAL, config.WEIGHT_TYPE, splash.updateDownloadProgress) import controller controller.createMainWindow() diff --git a/model.py b/model.py index 33c30973..185e6dcc 100644 --- a/model.py +++ b/model.py @@ -63,12 +63,11 @@ class Model: self.speaker_audio_recorder = None self.speaker_energy_recorder = None 
self.speaker_energy_plot_progressbar = None - self.translator = Translator(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHTS[config.WEIGHT_TYPE]) + self.translator = Translator(config.PATH_LOCAL, config.WEIGHT_TYPE) self.keyword_processor = KeywordProcessor() - def resetTranslator(self): - del self.translator - self.translator = Translator(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHTS[config.WEIGHT_TYPE]) + def updateTranslator(self): + self.translator.changeCTranslate2Model(config.PATH_LOCAL, config.WEIGHT_TYPE) def resetKeywordProcessor(self): del self.keyword_processor diff --git a/models/translation/translation_translator.py b/models/translation/translation_translator.py index 8c3e401b..47bd7dd1 100644 --- a/models/translation/translation_translator.py +++ b/models/translation/translation_translator.py @@ -2,20 +2,27 @@ import os from deepl import Translator as deepl_Translator from translators import translate_text as other_web_Translator from .translation_languages import translation_lang +from .utils import ctranslate2_weights import ctranslate2 import transformers # Translator class Translator(): - def __init__(self, path, weight_config): - self.translator_status = {} - directory_name = weight_config["directory_name"] - tokenizer = weight_config["tokenizer"] - self.weight_path = os.path.join(path, "weight", directory_name) - self.translator = ctranslate2.Translator(self.weight_path, device="cpu", device_index=0, compute_type="int8", inter_threads=1, intra_threads=4) - self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer) + def __init__(self, path, model_type): self.deepl_client = None + directory_name = ctranslate2_weights[model_type]["directory_name"] + tokenizer = ctranslate2_weights[model_type]["tokenizer"] + weight_path = os.path.join(path, "weight", directory_name) + self.ctranslate2_translator = ctranslate2.Translator( + weight_path, + device="cpu", + device_index=0, + compute_type="int8", + inter_threads=1, + intra_threads=4 + ) + 
self.ctranslate2_tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer) def authenticationDeepLAuthKey(self, authkey): result = True @@ -27,6 +34,20 @@ class Translator(): result = False return result + def changeCTranslate2Model(self, path, model_type): + directory_name = ctranslate2_weights[model_type]["directory_name"] + tokenizer = ctranslate2_weights[model_type]["tokenizer"] + weight_path = os.path.join(path, "weight", directory_name) + self.ctranslate2_translator = ctranslate2.Translator( + weight_path, + device="cpu", + device_index=0, + compute_type="int8", + inter_threads=1, + intra_threads=4 + ) + self.ctranslate2_tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer) + def translate(self, translator_name, source_language, target_language, target_country, message): try: result = "" @@ -81,12 +102,12 @@ class Translator(): to_language=target_language, ) case "CTranslate2": - self.tokenizer.src_lang = source_language - source = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(message)) - target_prefix = [self.tokenizer.lang_code_to_token[target_language]] - results = self.translator.translate_batch([source], target_prefix=[target_prefix]) + self.ctranslate2_tokenizer.src_lang = source_language + source = self.ctranslate2_tokenizer.convert_ids_to_tokens(self.ctranslate2_tokenizer.encode(message)) + target_prefix = [self.ctranslate2_tokenizer.lang_code_to_token[target_language]] + results = self.ctranslate2_translator.translate_batch([source], target_prefix=[target_prefix]) target = results[0].hypotheses[0][1:] - result = self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(target)) + result = self.ctranslate2_tokenizer.decode(self.ctranslate2_tokenizer.convert_tokens_to_ids(target)) except Exception: import traceback with open('error.log', 'a') as f: diff --git a/models/translation/utils.py b/models/translation/utils.py index f1c40b9e..3b71f87e 100644 --- a/models/translation/utils.py +++ b/models/translation/utils.py @@ 
-8,7 +8,7 @@ from typing import Callable import hashlib ctranslate2_weights = { - "small": { # M2M-100 418M-parameter model + "m2m100_418m": { # M2M-100 418M-parameter model "url": "https://bit.ly/33fM1AO", "directory_name": "m2m100_418m", "tokenizer": "facebook/m2m100_418M", @@ -18,7 +18,7 @@ ctranslate2_weights = { "shared_vocabulary.txt": "bd440aa21b8ca3453fc792a0018a1f3fe68b3464aadddd4d16a4b72f73c86d8c", } }, - "large": { # M2M-100 1.2B-parameter model + "m2m100_12b": { # M2M-100 1.2B-parameter model "url": "https://bit.ly/3GYiaed", "directory_name": "m2m100_12b", "tokenizer": "facebook/m2m100_1.2b", @@ -39,7 +39,7 @@ def calculate_file_hash(file_path, block_size=65536): return hash_object.hexdigest() -def downloadCTranslate2Weight(path, weight_type="small", ctranslate2_weights=ctranslate2_weights, func=None): +def downloadCTranslate2Weight(path, weight_type="m2m100_418m", func=None): url = ctranslate2_weights[weight_type]["url"] filename = 'weight.zip' directory_name = 'weight'