diff --git a/config.py b/config.py index 669c2e64..54b057a5 100644 --- a/config.py +++ b/config.py @@ -821,7 +821,7 @@ class Config: "DeepL_API": None, } self._USE_TRANSLATION_FEATURE = True - self._WEIGHT_TYPE = "m2m100_418m" + self._WEIGHT_TYPE = "Small" self._SEND_MESSAGE_FORMAT = "[message]" self._SEND_MESSAGE_FORMAT_WITH_T = "[message]([translation])" self._RECEIVED_MESSAGE_FORMAT = "[message]" diff --git a/controller.py b/controller.py index 90ebf0ac..d65d010d 100644 --- a/controller.py +++ b/controller.py @@ -488,6 +488,11 @@ def callbackSetUseTranslationFeature(value): config.USE_TRANSLATION_FEATURE = value if config.USE_TRANSLATION_FEATURE is True: view.useTranslationFeatureProcess("Normal") + if model.checkCTranslatorCTranslate2ModelWeight(): + model.changeTranslatorCTranslate2Model() + else: + view.useTranslationFeatureProcess("Disable") + # CTranslate2 weight is not downloaded else: view.useTranslationFeatureProcess("Disable") @@ -495,6 +500,11 @@ def callbackSetCtranslate2WeightType(value): print("callbackSetCtranslate2WeightType", value) config.WEIGHT_TYPE = str(value) view.updateSelectedCtranslate2WeightType(config.WEIGHT_TYPE) + if model.checkCTranslatorCTranslate2ModelWeight(): + model.changeTranslatorCTranslate2Model() + else: + view.useTranslationFeatureProcess("Disable") + # CTranslate2 weight is not downloaded def callbackSetDeeplAuthkey(value): print("callbackSetDeeplAuthkey", str(value)) diff --git a/main.py b/main.py index f990a89d..c160383d 100644 --- a/main.py +++ b/main.py @@ -9,7 +9,8 @@ if __name__ == "__main__": from config import config from models.translation.utils import downloadCTranslate2Weight - downloadCTranslate2Weight(config.PATH_LOCAL, config.WEIGHT_TYPE, splash.updateDownloadProgress) + if config.USE_TRANSLATION_FEATURE is True: + downloadCTranslate2Weight(config.PATH_LOCAL, config.WEIGHT_TYPE, splash.updateDownloadProgress) import controller controller.createMainWindow() diff --git a/model.py b/model.py index 01277bb5..8580e85d 100644 --- a/model.py +++ b/model.py @@ -24,6 +24,7 @@ from models.transcription.transcription_transcriber import AudioTranscriber from models.xsoverlay.notification import xsoverlayForVRCT from models.translation.translation_languages import translation_lang from models.transcription.transcription_languages import transcription_lang +from models.translation.utils import checkCTranslate2Weight from config import config class threadFnc(Thread): @@ -63,10 +64,15 @@ class Model: self.speaker_audio_recorder = None self.speaker_energy_recorder = None self.speaker_energy_plot_progressbar = None - self.translator = Translator(config.PATH_LOCAL, config.WEIGHT_TYPE) + self.translator = Translator() + if config.USE_TRANSLATION_FEATURE is True: + self.translator.changeCTranslate2Model(config.PATH_LOCAL, config.WEIGHT_TYPE) self.keyword_processor = KeywordProcessor() - def updateTranslator(self): + def checkCTranslatorCTranslate2ModelWeight(self): + return checkCTranslate2Weight(config.PATH_LOCAL, config.WEIGHT_TYPE) + + def changeTranslatorCTranslate2Model(self): self.translator.changeCTranslate2Model(config.PATH_LOCAL, config.WEIGHT_TYPE) def resetKeywordProcessor(self): diff --git a/models/translation/translation_translator.py b/models/translation/translation_translator.py index 47bd7dd1..fbe6cf3f 100644 --- a/models/translation/translation_translator.py +++ b/models/translation/translation_translator.py @@ -9,20 +9,10 @@ import transformers # Translator class Translator(): - def __init__(self, path, model_type): + def __init__(self): self.deepl_client = None - directory_name = ctranslate2_weights[model_type]["directory_name"] - tokenizer = ctranslate2_weights[model_type]["tokenizer"] - weight_path = os.path.join(path, "weight", directory_name) - self.ctranslate2_translator = ctranslate2.Translator( - weight_path, - device="cpu", - device_index=0, - compute_type="int8", - inter_threads=1, - intra_threads=4 - ) - self.ctranslate2_tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer) + self.ctranslate2_translator = None + self.ctranslate2_tokenizer = None def authenticationDeepLAuthKey(self, authkey): result = True diff --git a/models/translation/utils.py b/models/translation/utils.py index 2f60459b..d47401cf 100644 --- a/models/translation/utils.py +++ b/models/translation/utils.py @@ -7,7 +7,7 @@ from typing import Callable import hashlib ctranslate2_weights = { - "m2m100_418m": { # M2M-100 418M-parameter model + "Small": { # M2M-100 418M-parameter model "url": "https://bit.ly/33fM1AO", "directory_name": "m2m100_418m", "tokenizer": "facebook/m2m100_418M", @@ -17,7 +17,7 @@ ctranslate2_weights = { "shared_vocabulary.txt": "bd440aa21b8ca3453fc792a0018a1f3fe68b3464aadddd4d16a4b72f73c86d8c", } }, - "m2m100_12b": { # M2M-100 1.2B-parameter model + "Large": { # M2M-100 1.2B-parameter model "url": "https://bit.ly/3GYiaed", "directory_name": "m2m100_12b", "tokenizer": "facebook/m2m100_1.2b", @@ -38,9 +38,7 @@ def calculate_file_hash(file_path, block_size=65536): return hash_object.hexdigest() -def downloadCTranslate2Weight(path, weight_type="m2m100_418m", func=None): - url = ctranslate2_weights[weight_type]["url"] - filename = 'weight.zip' +def checkCTranslate2Weight(path, weight_type="Small"): directory_name = 'weight' current_directory = path weight_directory_name = ctranslate2_weights[weight_type]["directory_name"] @@ -48,6 +46,7 @@ def downloadCTranslate2Weight(path, weight_type="m2m100_418m", func=None): files = ["model.bin", "sentencepiece.model", "shared_vocabulary.txt"] # check already downloaded + already_downloaded = False if all(os_path.exists(os_path.join(current_directory, directory_name, weight_directory_name, file)) for file in files): # check hash for file in files: @@ -55,6 +54,16 @@ def downloadCTranslate2Weight(path, weight_type="m2m100_418m", func=None): current_hash = calculate_file_hash(os_path.join(current_directory, directory_name, weight_directory_name, file)) if original_hash != current_hash: break + already_downloaded = True + return already_downloaded + +def downloadCTranslate2Weight(path, weight_type="Small", func=None): + url = ctranslate2_weights[weight_type]["url"] + filename = 'weight.zip' + directory_name = 'weight' + current_directory = path + + if checkCTranslate2Weight(path, weight_type): return try: diff --git a/view.py b/view.py index ae487ccc..52fb174c 100644 --- a/view.py +++ b/view.py @@ -278,8 +278,7 @@ class View(): VAR_DESC_CTRANSLATE2_WEIGHT_TYPE=StringVar(value=i18n.t("config_window.ctranslate2_weight_type.desc")), DICT_CTRANSLATE2_WEIGHT_TYPE=self.getSelectableCtranslate2WeightTypeDict(), CALLBACK_SET_CTRANSLATE2_WEIGHT_TYPE=None, - VAR_CTRANSLATE2_WEIGHT_TYPE=StringVar(value=self.getSelectableCtranslate2WeightTypeDict()["Small"]), - # VAR_CTRANSLATE2_WEIGHT_TYPE=StringVar(value=self.getSelectableCtranslate2WeightTypeDict()[config.WEIGHT_TYPE]), + VAR_CTRANSLATE2_WEIGHT_TYPE=StringVar(value=self.getSelectableCtranslate2WeightTypeDict()[config.WEIGHT_TYPE]), VAR_LABEL_DEEPL_AUTH_KEY=StringVar(value=i18n.t("config_window.deepl_auth_key.label")), VAR_DESC_DEEPL_AUTH_KEY=None, @@ -974,6 +973,11 @@ class View(): additional_widget.grid() self._closeMicWordFilterList() + def showRestartButton(self): + self._showRestartButton() + + def hideRestartButton(self): + self._hideRestartButton() def showRestartButtonIfRequired(self, locale:Union[None,str]=None): is_restart_required = not ( @@ -1013,9 +1017,7 @@ class View(): self.view_variable.VAR_CTRANSLATE2_WEIGHT_TYPE.set(self.getSelectableCtranslate2WeightTypeDict()[selected_weight_type]) def setLatestCTranslate2WeightType(self): - if config.WEIGHT_TYPE == "m2m100_418m": - WEIGHT_TYPE = "Small" - selected_weight_type = self.getSelectableCtranslate2WeightTypeDict()[WEIGHT_TYPE] + selected_weight_type = self.getSelectableCtranslate2WeightTypeDict()[config.WEIGHT_TYPE] self.view_variable.VAR_CTRANSLATE2_WEIGHT_TYPE.set(selected_weight_type)