From 245855d0ca9d0ee316f0e183be800b2959404c17 Mon Sep 17 00:00:00 2001 From: misyaguziya <53165965+misyaguziya@users.noreply.github.com> Date: Wed, 17 Sep 2025 10:35:34 +0900 Subject: [PATCH] [Update] Add compute type management for CTranslate2 and Whisper models --- src-python/config.py | 38 ++++++++++++++++++- src-python/controller.py | 33 ++++++++++++++++ src-python/mainloop.py | 8 ++++ src-python/model.py | 12 ++++-- .../transcription_transcriber.py | 4 +- .../transcription/transcription_whisper.py | 5 ++- .../translation/translation_translator.py | 5 ++- src-python/utils.py | 5 ++- 8 files changed, 98 insertions(+), 12 deletions(-) diff --git a/src-python/config.py b/src-python/config.py index c544727d..1a605701 100644 --- a/src-python/config.py +++ b/src-python/config.py @@ -11,7 +11,7 @@ from models.translation.translation_languages import translation_lang from models.translation.translation_utils import ctranslate2_weights from models.transcription.transcription_languages import transcription_lang from models.transcription.transcription_whisper import _MODELS as whisper_models -from utils import errorLogging, validateDictStructure +from utils import errorLogging, validateDictStructure, getComputeTypeList json_serializable_vars = {} def json_serializable(var_name): @@ -135,6 +135,14 @@ class Config: def SELECTABLE_COMPUTE_DEVICE_LIST(self): return self._SELECTABLE_COMPUTE_DEVICE_LIST + @property + def SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST(self): + return self._SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST + + @property + def SELECTABLE_WHISPER_COMPUTE_TYPE_LIST(self): + return self._SELECTABLE_WHISPER_COMPUTE_TYPE_LIST + @property def SEND_MESSAGE_BUTTON_TYPE_LIST(self): return self._SEND_MESSAGE_BUTTON_TYPE_LIST @@ -814,6 +822,18 @@ class Config: self._CTRANSLATE2_WEIGHT_TYPE = value self.saveConfig(inspect.currentframe().f_code.co_name, value) + @property + @json_serializable('CTRANSLATE2_COMPUTE_TYPE') + def CTRANSLATE2_COMPUTE_TYPE(self): + return self._CTRANSLATE2_COMPUTE_TYPE + + @CTRANSLATE2_COMPUTE_TYPE.setter + def CTRANSLATE2_COMPUTE_TYPE(self, value): + if isinstance(value, str): + if value in self.SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST: + self._CTRANSLATE2_COMPUTE_TYPE = value + self.saveConfig(inspect.currentframe().f_code.co_name, value) + @property @json_serializable('WHISPER_WEIGHT_TYPE') def WHISPER_WEIGHT_TYPE(self): @@ -826,6 +846,18 @@ class Config: self._WHISPER_WEIGHT_TYPE = value self.saveConfig(inspect.currentframe().f_code.co_name, value) + @property + @json_serializable('WHISPER_COMPUTE_TYPE') + def WHISPER_COMPUTE_TYPE(self): + return self._WHISPER_COMPUTE_TYPE + + @WHISPER_COMPUTE_TYPE.setter + def WHISPER_COMPUTE_TYPE(self, value): + if isinstance(value, str): + if value in self.SELECTABLE_WHISPER_COMPUTE_TYPE_LIST: + self._WHISPER_COMPUTE_TYPE = value + self.saveConfig(inspect.currentframe().f_code.co_name, value) + @property @json_serializable('AUTO_CLEAR_MESSAGE_BOX') def AUTO_CLEAR_MESSAGE_BOX(self): @@ -1051,6 +1083,8 @@ class Config: for i in range(torch.cuda.device_count()): self._SELECTABLE_COMPUTE_DEVICE_LIST.append({"device":"cuda", "device_index": i, "device_name": torch.cuda.get_device_name(i)}) self._SELECTABLE_COMPUTE_DEVICE_LIST.append({"device":"cpu", "device_index": 0, "device_name": "cpu"}) + self._SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST = ["auto"] + getComputeTypeList() + self._SELECTABLE_WHISPER_COMPUTE_TYPE_LIST = ["auto"] + getComputeTypeList() self._SEND_MESSAGE_BUTTON_TYPE_LIST = ["show", "hide", "show_and_disable_enter_key"] self._SEND_MESSAGE_FORMAT_PARTS = { "message": { @@ -1189,7 +1223,9 @@ class Config: self._SELECTED_TRANSLATION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0]) self._SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0]) self._CTRANSLATE2_WEIGHT_TYPE = "small" + self._CTRANSLATE2_COMPUTE_TYPE = "auto" self._WHISPER_WEIGHT_TYPE = "base" + self._WHISPER_COMPUTE_TYPE = "auto" self._AUTO_CLEAR_MESSAGE_BOX = True self._SEND_ONLY_TRANSLATED_MESSAGES = False self._OVERLAY_SMALL_LOG = False diff --git a/src-python/controller.py b/src-python/controller.py index c34abaf8..afa9d266 100644 --- a/src-python/controller.py +++ b/src-python/controller.py @@ -652,6 +652,14 @@ class Controller: def getComputeDeviceList(*args, **kwargs) -> dict: return {"status":200, "result":config.SELECTABLE_COMPUTE_DEVICE_LIST} + @staticmethod + def getCTranslate2ComputeTypeList(*args, **kwargs) -> dict: + return {"status":200, "result":config.SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST} + + @staticmethod + def getWhisperComputeTypeList(*args, **kwargs) -> dict: + return {"status":200, "result":config.SELECTABLE_WHISPER_COMPUTE_TYPE_LIST} + @staticmethod def getSelectedTranslationComputeDevice(*args, **kwargs) -> dict: return {"status":200, "result":config.SELECTED_TRANSLATION_COMPUTE_DEVICE} @@ -1447,6 +1455,22 @@ class Controller: th_callback.join() return {"status":200, "result":config.CTRANSLATE2_WEIGHT_TYPE} + @staticmethod + def getCtranslateComputeType(*args, **kwargs) -> dict: + return {"status":200, "result":config.CTRANSLATE2_COMPUTE_TYPE} + + @staticmethod + def setCtranslateComputeType(data, *args, **kwargs) -> dict: + config.CTRANSLATE2_COMPUTE_TYPE = str(data) + if model.checkTranslatorCTranslate2ModelWeight(config.CTRANSLATE2_WEIGHT_TYPE): + def callback(): + model.changeTranslatorCTranslate2Model() + th_callback = Thread(target=callback) + th_callback.daemon = True + th_callback.start() + th_callback.join() + return {"status":200, "result":config.CTRANSLATE2_COMPUTE_TYPE} + @staticmethod def getWhisperWeightType(*args, **kwargs) -> dict: return {"status":200, "result":config.WHISPER_WEIGHT_TYPE} @@ -1456,6 +1480,15 @@ class Controller: config.WHISPER_WEIGHT_TYPE = str(data) return {"status":200, "result": config.WHISPER_WEIGHT_TYPE} + @staticmethod + def getWhisperComputeType(*args, **kwargs) -> dict: + return {"status":200, "result":config.WHISPER_COMPUTE_TYPE} + + @staticmethod + def setWhisperComputeType(data, *args, **kwargs) -> dict: + config.WHISPER_COMPUTE_TYPE = str(data) + return {"status":200, "result":config.WHISPER_COMPUTE_TYPE} + @staticmethod def getSendMessageFormatParts(*args, **kwargs) -> dict: return {"status":200, "result":config.SEND_MESSAGE_FORMAT_PARTS} diff --git a/src-python/mainloop.py b/src-python/mainloop.py index 0010b98a..2ad6e078 100644 --- a/src-python/mainloop.py +++ b/src-python/mainloop.py @@ -162,6 +162,9 @@ mapping = { "/get/data/ctranslate2_weight_type": {"status": True, "variable":controller.getCtranslate2WeightType}, "/set/data/ctranslate2_weight_type": {"status": True, "variable":controller.setCtranslate2WeightType}, + "/get/data/ctranslate2_compute_type": {"status": True, "variable":controller.getCtranslateComputeType}, + "/set/data/ctranslate2_compute_type": {"status": True, "variable":controller.setCtranslateComputeType}, + "/run/download_ctranslate2_weight": {"status": True, "variable":controller.downloadCtranslate2Weight}, "/get/data/deepl_auth_key": {"status": False, "variable":controller.getDeepLAuthKey}, @@ -261,8 +264,13 @@ mapping = { "/set/disable/check_speaker_threshold": {"status": True, "variable":controller.setDisableCheckSpeakerThreshold}, "/get/data/selectable_whisper_weight_type_dict": {"status": True, "variable":controller.getSelectableWhisperWeightTypeDict}, + "/get/data/whisper_weight_type": {"status": True, "variable":controller.getWhisperWeightType}, "/set/data/whisper_weight_type": {"status": True, "variable":controller.setWhisperWeightType}, + + "/get/data/whisper_compute_type": {"status": True, "variable":controller.getWhisperComputeType}, + "/set/data/whisper_compute_type": {"status": True, "variable":controller.setWhisperComputeType}, + "/run/download_whisper_weight": {"status": True, "variable":controller.downloadWhisperWeight}, # VR diff --git a/src-python/model.py b/src-python/model.py index 333f1394..445b0a5e 100644 --- a/src-python/model.py +++ b/src-python/model.py @@ -112,10 +112,12 @@ class Model: def changeTranslatorCTranslate2Model(self): self.translator.changeCTranslate2Model( - config.PATH_LOCAL, - config.CTRANSLATE2_WEIGHT_TYPE, - config.SELECTED_TRANSLATION_COMPUTE_DEVICE["device"], - config.SELECTED_TRANSLATION_COMPUTE_DEVICE["device_index"]) + path=config.PATH_LOCAL, + model_type=config.CTRANSLATE2_WEIGHT_TYPE, + device=config.SELECTED_TRANSLATION_COMPUTE_DEVICE["device"], + device_index=config.SELECTED_TRANSLATION_COMPUTE_DEVICE["device_index"], + compute_type=config.CTRANSLATE2_COMPUTE_TYPE + ) def downloadCTranslate2ModelWeight(self, weight_type, callback=None, end_callback=None): return downloadCTranslate2Weight(config.PATH_LOCAL, weight_type, callback, end_callback) @@ -438,6 +440,7 @@ class Model: whisper_weight_type=config.WHISPER_WEIGHT_TYPE, device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device"], device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"], + compute_type=config.WHISPER_COMPUTE_TYPE, ) def sendMicTranscript(): try: @@ -621,6 +624,7 @@ class Model: whisper_weight_type=config.WHISPER_WEIGHT_TYPE, device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device"], device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"], + compute_type=config.WHISPER_COMPUTE_TYPE, ) def sendSpeakerTranscript(): try: diff --git a/src-python/models/transcription/transcription_transcriber.py b/src-python/models/transcription/transcription_transcriber.py index 5407253a..9d874b30 100644 --- a/src-python/models/transcription/transcription_transcriber.py +++ b/src-python/models/transcription/transcription_transcriber.py @@ -21,7 +21,7 @@ PHRASE_TIMEOUT = 3 MAX_PHRASES = 10 class AudioTranscriber: - def __init__(self, speaker, source, phrase_timeout, max_phrases, transcription_engine, root=None, whisper_weight_type=None, device="cpu", device_index=0): + def __init__(self, speaker, source, phrase_timeout, max_phrases, transcription_engine, root=None, whisper_weight_type=None, device="cpu", device_index=0, compute_type="auto"): self.speaker = speaker self.phrase_timeout = phrase_timeout self.max_phrases = max_phrases @@ -41,7 +41,7 @@ class AudioTranscriber: } if transcription_engine == "Whisper" and checkWhisperWeight(root, whisper_weight_type) is True: - self.whisper_model = getWhisperModel(root, whisper_weight_type, device=device, device_index=device_index) + self.whisper_model = getWhisperModel(root, whisper_weight_type, device=device, device_index=device_index, compute_type=compute_type) self.transcription_engine = "Whisper" def transcribeAudioQueue(self, audio_queue, languages, countries, avg_logprob=-0.8, no_speech_prob=0.6): diff --git a/src-python/models/transcription/transcription_whisper.py b/src-python/models/transcription/transcription_whisper.py index 04f89626..5f61a121 100644 --- a/src-python/models/transcription/transcription_whisper.py +++ b/src-python/models/transcription/transcription_whisper.py @@ -74,9 +74,10 @@ def downloadWhisperWeight(root, weight_type, callback=None, end_callback=None): if isinstance(end_callback, Callable): end_callback() -def getWhisperModel(root, weight_type, device="cpu", device_index=0): +def getWhisperModel(root, weight_type, device="cpu", device_index=0, compute_type="auto"): path = os_path.join(root, "weights", "whisper", weight_type) - compute_type = getBestComputeType(device, device_index) + if compute_type == "auto": + compute_type = getBestComputeType(device, device_index) try: model = WhisperModel( path, diff --git a/src-python/models/translation/translation_translator.py b/src-python/models/translation/translation_translator.py index 42eb828e..897fcd1b 100644 --- a/src-python/models/translation/translation_translator.py +++ b/src-python/models/translation/translation_translator.py @@ -36,14 +36,15 @@ class Translator(): result = False return result - def changeCTranslate2Model(self, path, model_type, device="cpu", device_index=0): + def changeCTranslate2Model(self, path, model_type, device="cpu", device_index=0, compute_type="auto"): self.is_loaded_ctranslate2_model = False directory_name = ctranslate2_weights[model_type]["directory_name"] tokenizer = ctranslate2_weights[model_type]["tokenizer"] weight_path = os_path.join(path, "weights", "ctranslate2", directory_name) tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer") - compute_type = getBestComputeType(device, device_index) + if compute_type == "auto": + compute_type = getBestComputeType(device, device_index) self.ctranslate2_translator = ctranslate2.Translator( weight_path, device=device, diff --git a/src-python/utils.py b/src-python/utils.py index c3a857f2..1b28fcf6 100644 --- a/src-python/utils.py +++ b/src-python/utils.py @@ -78,10 +78,13 @@ def isValidIpAddress(ip_address: str) -> bool: except ValueError: return False +def getComputeTypeList() -> list: + return ["int8_bfloat16", "int8_float16", "int8", "bfloat16", "float16", "int8_float32", "float32"] + def getBestComputeType(device, device_index) -> str: compute_types = get_supported_compute_types(device, device_index) compute_types = set(compute_types) - preferred_types = ["int8_bfloat16", "int8_float16", "int8", "bfloat16", "float16", "int8_float32", "float32"] + preferred_types = getComputeTypeList() for preferred_type in preferred_types: if preferred_type in compute_types: