diff --git a/src-python/config.py b/src-python/config.py index bd3af244..07926324 100644 --- a/src-python/config.py +++ b/src-python/config.py @@ -815,15 +815,15 @@ class Config: self.saveConfig(inspect.currentframe().f_code.co_name, value) @property - @json_serializable('CTRANSLATE2_COMPUTE_TYPE') - def CTRANSLATE2_COMPUTE_TYPE(self): - return self._CTRANSLATE2_COMPUTE_TYPE + @json_serializable('TRANSLATION_COMPUTE_TYPE') + def TRANSLATION_COMPUTE_TYPE(self): + return self._TRANSLATION_COMPUTE_TYPE - @CTRANSLATE2_COMPUTE_TYPE.setter - def CTRANSLATE2_COMPUTE_TYPE(self, value): + @TRANSLATION_COMPUTE_TYPE.setter + def TRANSLATION_COMPUTE_TYPE(self, value): if isinstance(value, str): - if value in self.SELECTED_TRANSLATION_COMPUTE_DEVICE["compute_type"]: - self._CTRANSLATE2_COMPUTE_TYPE = value + if value in self.SELECTED_TRANSLATION_COMPUTE_DEVICE["compute_types"]: + self._TRANSLATION_COMPUTE_TYPE = value self.saveConfig(inspect.currentframe().f_code.co_name, value) @property @@ -839,15 +839,15 @@ class Config: self.saveConfig(inspect.currentframe().f_code.co_name, value) @property - @json_serializable('WHISPER_COMPUTE_TYPE') - def WHISPER_COMPUTE_TYPE(self): - return self._WHISPER_COMPUTE_TYPE + @json_serializable('TRANSCRIPTION_COMPUTE_TYPE') + def TRANSCRIPTION_COMPUTE_TYPE(self): + return self._TRANSCRIPTION_COMPUTE_TYPE - @WHISPER_COMPUTE_TYPE.setter - def WHISPER_COMPUTE_TYPE(self, value): + @TRANSCRIPTION_COMPUTE_TYPE.setter + def TRANSCRIPTION_COMPUTE_TYPE(self, value): if isinstance(value, str): - if value in self.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["compute_type"]: - self._WHISPER_COMPUTE_TYPE = value + if value in self.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["compute_types"]: + self._TRANSCRIPTION_COMPUTE_TYPE = value self.saveConfig(inspect.currentframe().f_code.co_name, value) @property @@ -1209,9 +1209,9 @@ class Config: self._SELECTED_TRANSLATION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0]) self._SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0]) self._CTRANSLATE2_WEIGHT_TYPE = "small" - self._CTRANSLATE2_COMPUTE_TYPE = "auto" + self._TRANSLATION_COMPUTE_TYPE = "auto" self._WHISPER_WEIGHT_TYPE = "base" - self._WHISPER_COMPUTE_TYPE = "auto" + self._TRANSCRIPTION_COMPUTE_TYPE = "auto" self._AUTO_CLEAR_MESSAGE_BOX = True self._SEND_ONLY_TRANSLATED_MESSAGES = False self._OVERLAY_SMALL_LOG = False diff --git a/src-python/controller.py b/src-python/controller.py index 5ea86430..a746ecee 100644 --- a/src-python/controller.py +++ b/src-python/controller.py @@ -1437,6 +1437,7 @@ class Controller: @staticmethod def setCtranslate2WeightType(data, *args, **kwargs) -> dict: + pre_weight_type = config.CTRANSLATE2_WEIGHT_TYPE config.CTRANSLATE2_WEIGHT_TYPE = str(data) if model.checkTranslatorCTranslate2ModelWeight(config.CTRANSLATE2_WEIGHT_TYPE): def callback(): @@ -1445,15 +1446,18 @@ class Controller: th_callback.daemon = True th_callback.start() th_callback.join() + else: + config.CTRANSLATE2_WEIGHT_TYPE = pre_weight_type return {"status":200, "result":config.CTRANSLATE2_WEIGHT_TYPE} @staticmethod - def getCtranslate2ComputeType(*args, **kwargs) -> dict: - return {"status":200, "result":config.CTRANSLATE2_COMPUTE_TYPE} + def getTranslationComputeType(*args, **kwargs) -> dict: + return {"status":200, "result":config.TRANSLATION_COMPUTE_TYPE} @staticmethod - def setCtranslate2ComputeType(data, *args, **kwargs) -> dict: - config.CTRANSLATE2_COMPUTE_TYPE = str(data) + def setTranslationComputeType(data, *args, **kwargs) -> dict: + pre_compute_type = config.TRANSLATION_COMPUTE_TYPE + config.TRANSLATION_COMPUTE_TYPE = str(data) if model.checkTranslatorCTranslate2ModelWeight(config.CTRANSLATE2_WEIGHT_TYPE): def callback(): model.changeTranslatorCTranslate2Model() @@ -1461,7 +1465,9 @@ class Controller: th_callback.daemon = True th_callback.start() th_callback.join() - return {"status":200, "result":config.CTRANSLATE2_COMPUTE_TYPE} + else: + config.TRANSLATION_COMPUTE_TYPE = pre_compute_type + return {"status":200, "result":config.TRANSLATION_COMPUTE_TYPE} @staticmethod def getWhisperWeightType(*args, **kwargs) -> dict: @@ -1473,13 +1479,13 @@ class Controller: return {"status":200, "result": config.WHISPER_WEIGHT_TYPE} @staticmethod - def getWhisperComputeType(*args, **kwargs) -> dict: - return {"status":200, "result":config.WHISPER_COMPUTE_TYPE} + def getTranscriptionComputeType(*args, **kwargs) -> dict: + return {"status":200, "result":config.TRANSCRIPTION_COMPUTE_TYPE} @staticmethod - def setWhisperComputeType(data, *args, **kwargs) -> dict: - config.WHISPER_COMPUTE_TYPE = str(data) - return {"status":200, "result":config.WHISPER_COMPUTE_TYPE} + def setTranscriptionComputeType(data, *args, **kwargs) -> dict: + config.TRANSCRIPTION_COMPUTE_TYPE = str(data) + return {"status":200, "result":config.TRANSCRIPTION_COMPUTE_TYPE} @staticmethod def getSendMessageFormatParts(*args, **kwargs) -> dict: diff --git a/src-python/mainloop.py b/src-python/mainloop.py index 73e75594..a32fef8a 100644 --- a/src-python/mainloop.py +++ b/src-python/mainloop.py @@ -162,8 +162,8 @@ mapping = { "/get/data/ctranslate2_weight_type": {"status": True, "variable":controller.getCtranslate2WeightType}, "/set/data/ctranslate2_weight_type": {"status": True, "variable":controller.setCtranslate2WeightType}, - "/get/data/ctranslate2_compute_type": {"status": True, "variable":controller.getCtranslate2ComputeType}, - "/set/data/ctranslate2_compute_type": {"status": True, "variable":controller.setCtranslate2ComputeType}, + "/get/data/translation_compute_type": {"status": True, "variable":controller.getTranslationComputeType}, + "/set/data/translation_compute_type": {"status": True, "variable":controller.setTranslationComputeType}, "/run/download_ctranslate2_weight": {"status": True, "variable":controller.downloadCtranslate2Weight}, @@ -268,8 +268,8 @@ mapping = { "/get/data/whisper_weight_type": {"status": True, "variable":controller.getWhisperWeightType}, "/set/data/whisper_weight_type": {"status": True, "variable":controller.setWhisperWeightType}, - "/get/data/whisper_compute_type": {"status": True, "variable":controller.getWhisperComputeType}, - "/set/data/whisper_compute_type": {"status": True, "variable":controller.setWhisperComputeType}, + "/get/data/transcription_compute_type": {"status": True, "variable":controller.getTranscriptionComputeType}, + "/set/data/transcription_compute_type": {"status": True, "variable":controller.setTranscriptionComputeType}, "/run/download_whisper_weight": {"status": True, "variable":controller.downloadWhisperWeight}, diff --git a/src-python/model.py b/src-python/model.py index 445b0a5e..639d375f 100644 --- a/src-python/model.py +++ b/src-python/model.py @@ -116,7 +116,7 @@ class Model: model_type=config.CTRANSLATE2_WEIGHT_TYPE, device=config.SELECTED_TRANSLATION_COMPUTE_DEVICE["device"], device_index=config.SELECTED_TRANSLATION_COMPUTE_DEVICE["device_index"], - compute_type=config.CTRANSLATE2_COMPUTE_TYPE + compute_type=config.TRANSLATION_COMPUTE_TYPE ) def downloadCTranslate2ModelWeight(self, weight_type, callback=None, end_callback=None): @@ -440,7 +440,7 @@ class Model: whisper_weight_type=config.WHISPER_WEIGHT_TYPE, device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device"], device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"], - compute_type=config.WHISPER_COMPUTE_TYPE, + compute_type=config.TRANSCRIPTION_COMPUTE_TYPE, ) def sendMicTranscript(): try: @@ -624,7 +624,7 @@ class Model: whisper_weight_type=config.WHISPER_WEIGHT_TYPE, device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device"], device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"], - compute_type=config.WHISPER_COMPUTE_TYPE, + compute_type=config.TRANSCRIPTION_COMPUTE_TYPE, ) def sendSpeakerTranscript(): try: