diff --git a/controller.py b/controller.py index a40282df..34ea3a14 100644 --- a/controller.py +++ b/controller.py @@ -377,8 +377,10 @@ def callbackSelectedTranslationEngine(selected_translation_engine): def callbackToggleTranslation(is_turned_on): config.ENABLE_TRANSLATION = is_turned_on if config.ENABLE_TRANSLATION is True: + model.changeTranslatorCTranslate2Model() view.printToTextbox_enableTranslation() else: + model.clearTranslatorCTranslate2Model() view.printToTextbox_disableTranslation() def callbackToggleTranscriptionSend(is_turned_on): diff --git a/model.py b/model.py index f5913d17..e39b718f 100644 --- a/model.py +++ b/model.py @@ -66,8 +66,6 @@ class Model: self.speaker_energy_recorder = None self.speaker_energy_plot_progressbar = None self.translator = Translator() - if config.USE_TRANSLATION_FEATURE is True: - self.translator.changeCTranslate2Model(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE) self.keyword_processor = KeywordProcessor() def checkCTranslatorCTranslate2ModelWeight(self): @@ -76,6 +74,9 @@ class Model: def changeTranslatorCTranslate2Model(self): self.translator.changeCTranslate2Model(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE) + def clearTranslatorCTranslate2Model(self): + self.translator.clearCTranslate2Model() + def checkTranscriptionWhisperModelWeight(self): return checkWhisperWeight(config.PATH_LOCAL, config.WHISPER_WEIGHT_TYPE) diff --git a/models/translation/translation_translator.py b/models/translation/translation_translator.py index c966c672..a71d0f55 100644 --- a/models/translation/translation_translator.py +++ b/models/translation/translation_translator.py @@ -1,3 +1,4 @@ +import gc import os from deepl import Translator as deepl_Translator from translators import translate_text as other_web_Translator @@ -44,6 +45,13 @@ class Translator(): tokenizer_path = os.path.join("./weights", "ctranslate2", directory_name, "tokenizer") self.ctranslate2_tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer, cache_dir=tokenizer_path) + def clearCTranslate2Model(self): + del self.ctranslate2_translator + del self.ctranslate2_tokenizer + gc.collect() + self.ctranslate2_translator = None + self.ctranslate2_tokenizer = None + @staticmethod def getLanguageCode(translator_name, target_country, source_language, target_language): match translator_name: