diff --git a/controller.py b/controller.py index 634d3181..cb826080 100644 --- a/controller.py +++ b/controller.py @@ -400,10 +400,11 @@ def callbackSelectedTranslationEngine(selected_translation_engine): def callbackToggleTranslation(is_turned_on): config.ENABLE_TRANSLATION = is_turned_on if config.ENABLE_TRANSLATION is True: - model.changeTranslatorCTranslate2Model() + if model.isLoadedCTranslate2Model() is False: + model.changeTranslatorCTranslate2Model() view.printToTextbox_enableTranslation() else: - model.clearTranslatorCTranslate2Model() + # model.clearTranslatorCTranslate2Model() view.printToTextbox_disableTranslation() def callbackToggleTranscriptionSend(is_turned_on): diff --git a/model.py b/model.py index 7d89a0a9..d5d5c3e7 100644 --- a/model.py +++ b/model.py @@ -103,8 +103,11 @@ class Model: def changeTranslatorCTranslate2Model(self): self.translator.changeCTranslate2Model(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE) - def clearTranslatorCTranslate2Model(self): - self.translator.clearCTranslate2Model() + def isLoadedCTranslate2Model(self): + return self.translator.isLoadedCTranslate2Model() + + # def clearTranslatorCTranslate2Model(self): + # self.translator.clearCTranslate2Model() def checkTranscriptionWhisperModelWeight(self): return checkWhisperWeight(config.PATH_LOCAL, config.WHISPER_WEIGHT_TYPE) diff --git a/models/translation/translation_translator.py b/models/translation/translation_translator.py index d2717747..0ef71b88 100644 --- a/models/translation/translation_translator.py +++ b/models/translation/translation_translator.py @@ -14,6 +14,7 @@ class Translator(): self.deepl_client = None self.ctranslate2_translator = None self.ctranslate2_tokenizer = None + self.is_loaded_ctranslate2_model = False def authenticationDeepLAuthKey(self, authkey): result = True @@ -44,24 +45,30 @@ class Translator(): print("Error: changeCTranslate2Model()", e) tokenizer_path = os.path.join("./weights", "ctranslate2", directory_name, "tokenizer") self.ctranslate2_tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer, cache_dir=tokenizer_path) + self.is_loaded_ctranslate2_model = True - def clearCTranslate2Model(self): - del self.ctranslate2_translator - del self.ctranslate2_tokenizer - gc.collect() - self.ctranslate2_translator = None - self.ctranslate2_tokenizer = None + def isLoadedCTranslate2Model(self): + return self.is_loaded_ctranslate2_model + + # def clearCTranslate2Model(self): + # del self.ctranslate2_translator + # del self.ctranslate2_tokenizer + # gc.collect() + # self.ctranslate2_translator = None + # self.ctranslate2_tokenizer = None def translateCTranslate2(self, message, source_language, target_language): - try: - self.ctranslate2_tokenizer.src_lang = source_language - source = self.ctranslate2_tokenizer.convert_ids_to_tokens(self.ctranslate2_tokenizer.encode(message)) - target_prefix = [self.ctranslate2_tokenizer.lang_code_to_token[target_language]] - results = self.ctranslate2_translator.translate_batch([source], target_prefix=[target_prefix]) - target = results[0].hypotheses[0][1:] - result = self.ctranslate2_tokenizer.decode(self.ctranslate2_tokenizer.convert_tokens_to_ids(target)) - except Exception: - result = False + result = False + if self.is_loaded_ctranslate2_model is True: + try: + self.ctranslate2_tokenizer.src_lang = source_language + source = self.ctranslate2_tokenizer.convert_ids_to_tokens(self.ctranslate2_tokenizer.encode(message)) + target_prefix = [self.ctranslate2_tokenizer.lang_code_to_token[target_language]] + results = self.ctranslate2_translator.translate_batch([source], target_prefix=[target_prefix]) + target = results[0].hypotheses[0][1:] + result = self.ctranslate2_tokenizer.decode(self.ctranslate2_tokenizer.convert_tokens_to_ids(target)) + except Exception: + pass return result @staticmethod