From 4572aee2b71c51160c20b78a703043e39fc841b6 Mon Sep 17 00:00:00 2001 From: misyaguziya <53165965+misyaguziya@users.noreply.github.com> Date: Mon, 6 Oct 2025 16:40:05 +0900 Subject: [PATCH] [Update] Controller and Model: Refactor translation device management and add parameter change tracking --- src-python/controller.py | 79 +++++++++---------- src-python/model.py | 6 ++ .../translation/translation_translator.py | 7 ++ 3 files changed, 50 insertions(+), 42 deletions(-) diff --git a/src-python/controller.py b/src-python/controller.py index 4015016f..f40a25ae 100644 --- a/src-python/controller.py +++ b/src-python/controller.py @@ -1,3 +1,4 @@ +import copy from typing import Callable, Any from time import sleep from subprocess import Popen @@ -753,25 +754,10 @@ class Controller: def setSelectedTranslationComputeDevice(self, device:str, *args, **kwargs) -> dict: printLog("setSelectedTranslationComputeDevice", device) - pre_device = config.SELECTED_TRANSLATION_COMPUTE_DEVICE - pre_compute_type = config.SELECTED_TRANSLATION_COMPUTE_TYPE config.SELECTED_TRANSLATION_COMPUTE_DEVICE = device config.SELECTED_TRANSLATION_COMPUTE_TYPE = "auto" - try: - model.changeTranslatorCTranslate2Model() - self.run(200, self.run_mapping["selected_translation_compute_type"], config.SELECTED_TRANSLATION_COMPUTE_TYPE) - except Exception as e: - # VRAM不足エラーの検出(デバイス切り替え時) - is_vram_error, error_message = model.detectVRAMError(e) - if is_vram_error: - # 前のデバイス設定に戻す - printLog("VRAM error detected, reverting device setting") - config.SELECTED_TRANSLATION_COMPUTE_DEVICE = pre_device - config.SELECTED_TRANSLATION_COMPUTE_TYPE = pre_compute_type - model.changeTranslatorCTranslate2Model() - else: - # その他のエラーは通常通り処理 - errorLogging() + self.run(200, self.run_mapping["selected_translation_compute_type"], config.SELECTED_TRANSLATION_COMPUTE_TYPE) + model.setChangedTranslatorParameters(True) return {"status":200,"result":config.SELECTED_TRANSLATION_COMPUTE_DEVICE} @staticmethod @@ -801,12 +787,39 @@ class Controller: # def getMaxSpeakerThreshold(*args, **kwargs) -> dict: # return {"status":200, "result":config.MAX_SPEAKER_THRESHOLD} - @staticmethod - def setEnableTranslation(*args, **kwargs) -> dict: + def setEnableTranslation(self, *args, **kwargs) -> dict: if config.ENABLE_TRANSLATION is False: - if model.isLoadedCTranslate2Model() is False: - model.changeTranslatorCTranslate2Model() - config.ENABLE_TRANSLATION = True + if model.isLoadedCTranslate2Model() is False or model.isChangedTranslatorParameters() is True: + try: + model.changeTranslatorCTranslate2Model() + model.setChangedTranslatorParameters(False) + config.ENABLE_TRANSLATION = True + except Exception as e: + # VRAM不足エラーの検出(デバイス切り替え時) + is_vram_error, error_message = model.detectVRAMError(e) + if is_vram_error: + # Defaultのデバイス設定に戻す + printLog("VRAM error detected, reverting device setting") + self.setDisableTranslation() + config.SELECTED_TRANSLATION_COMPUTE_DEVICE = copy.deepcopy(config.SELECTABLE_COMPUTE_DEVICE_LIST[0]) + config.SELECTED_TRANSLATION_COMPUTE_TYPE = "auto" + self.run(200, self.run_mapping["selected_translation_compute_device"], config.SELECTED_TRANSLATION_COMPUTE_DEVICE) + self.run(200, self.run_mapping["selected_translation_compute_type"], config.SELECTED_TRANSLATION_COMPUTE_TYPE) + self.run( + 400, + self.run_mapping["enable_translation"], + { + "message":"Translation disabled due to VRAM overflow", + "data": False + }, + ) + model.changeTranslatorCTranslate2Model() + model.setChangedTranslatorParameters(False) + else: + # その他のエラーは通常通り処理 + errorLogging() + else: + config.ENABLE_TRANSLATION = True return {"status":200, "result":config.ENABLE_TRANSLATION} @staticmethod @@ -1571,17 +1584,8 @@ class Controller: @staticmethod def setCtranslate2WeightType(data, *args, **kwargs) -> dict: - pre_weight_type = config.CTRANSLATE2_WEIGHT_TYPE config.CTRANSLATE2_WEIGHT_TYPE = str(data) - if model.checkTranslatorCTranslate2ModelWeight(config.CTRANSLATE2_WEIGHT_TYPE): - def callback(): - model.changeTranslatorCTranslate2Model() - th_callback = Thread(target=callback) - th_callback.daemon = True - th_callback.start() - th_callback.join() - else: - config.CTRANSLATE2_WEIGHT_TYPE = pre_weight_type + model.setChangedTranslatorParameters(True) return {"status":200, "result":config.CTRANSLATE2_WEIGHT_TYPE} @staticmethod @@ -1590,17 +1594,8 @@ class Controller: @staticmethod def setSelectedTranslationComputeType(data, *args, **kwargs) -> dict: - pre_compute_type = config.SELECTED_TRANSLATION_COMPUTE_TYPE config.SELECTED_TRANSLATION_COMPUTE_TYPE = str(data) - if model.checkTranslatorCTranslate2ModelWeight(config.CTRANSLATE2_WEIGHT_TYPE): - def callback(): - model.changeTranslatorCTranslate2Model() - th_callback = Thread(target=callback) - th_callback.daemon = True - th_callback.start() - th_callback.join() - else: - config.SELECTED_TRANSLATION_COMPUTE_TYPE = pre_compute_type + model.setChangedTranslatorParameters(True) return {"status":200, "result":config.SELECTED_TRANSLATION_COMPUTE_TYPE} @staticmethod diff --git a/src-python/model.py b/src-python/model.py index 6048c630..bbf43604 100644 --- a/src-python/model.py +++ b/src-python/model.py @@ -128,6 +128,12 @@ class Model: def isLoadedCTranslate2Model(self): return self.translator.isLoadedCTranslate2Model() + def isChangedTranslatorParameters(self): + return self.translator.isChangedTranslatorParameters() + + def setChangedTranslatorParameters(self, is_changed): + self.translator.setChangedTranslatorParameters(is_changed) + def checkTranscriptionWhisperModelWeight(self, weight_type:str): return checkWhisperWeight(config.PATH_LOCAL, weight_type) diff --git a/src-python/models/translation/translation_translator.py b/src-python/models/translation/translation_translator.py index a9a1a56a..a12b326e 100644 --- a/src-python/models/translation/translation_translator.py +++ b/src-python/models/translation/translation_translator.py @@ -23,6 +23,7 @@ class Translator(): self.ctranslate2_translator = None self.ctranslate2_tokenizer = None self.is_loaded_ctranslate2_model = False + self.is_changed_translator_parameters = False self.is_enable_translators = ENABLE_TRANSLATORS def authenticationDeepLAuthKey(self, authkey): @@ -64,6 +65,12 @@ class Translator(): def isLoadedCTranslate2Model(self): return self.is_loaded_ctranslate2_model + def isChangedTranslatorParameters(self): + return self.is_changed_translator_parameters + + def setChangedTranslatorParameters(self, is_changed): + self.is_changed_translator_parameters = is_changed + def translateCTranslate2(self, message, source_language, target_language): result = False if self.is_loaded_ctranslate2_model is True: