[Update] Refactor compute type management: rename CTranslate2 and Whisper compute types to Translation and Transcription

This commit is contained in:
misyaguziya
2025-09-25 22:56:16 +09:00
parent 9d94fd6a5e
commit 92f9d645f8
4 changed files with 39 additions and 33 deletions

View File

@@ -815,15 +815,15 @@ class Config:
self.saveConfig(inspect.currentframe().f_code.co_name, value) self.saveConfig(inspect.currentframe().f_code.co_name, value)
@property @property
@json_serializable('CTRANSLATE2_COMPUTE_TYPE') @json_serializable('TRANSLATION_COMPUTE_TYPE')
def CTRANSLATE2_COMPUTE_TYPE(self): def TRANSLATION_COMPUTE_TYPE(self):
return self._CTRANSLATE2_COMPUTE_TYPE return self._TRANSLATION_COMPUTE_TYPE
@CTRANSLATE2_COMPUTE_TYPE.setter @TRANSLATION_COMPUTE_TYPE.setter
def CTRANSLATE2_COMPUTE_TYPE(self, value): def TRANSLATION_COMPUTE_TYPE(self, value):
if isinstance(value, str): if isinstance(value, str):
if value in self.SELECTED_TRANSLATION_COMPUTE_DEVICE["compute_type"]: if value in self.SELECTED_TRANSLATION_COMPUTE_DEVICE["compute_types"]:
self._CTRANSLATE2_COMPUTE_TYPE = value self._TRANSLATION_COMPUTE_TYPE = value
self.saveConfig(inspect.currentframe().f_code.co_name, value) self.saveConfig(inspect.currentframe().f_code.co_name, value)
@property @property
@@ -839,15 +839,15 @@ class Config:
self.saveConfig(inspect.currentframe().f_code.co_name, value) self.saveConfig(inspect.currentframe().f_code.co_name, value)
@property @property
@json_serializable('WHISPER_COMPUTE_TYPE') @json_serializable('TRANSCRIPTION_COMPUTE_TYPE')
def WHISPER_COMPUTE_TYPE(self): def TRANSCRIPTION_COMPUTE_TYPE(self):
return self._WHISPER_COMPUTE_TYPE return self._TRANSCRIPTION_COMPUTE_TYPE
@WHISPER_COMPUTE_TYPE.setter @TRANSCRIPTION_COMPUTE_TYPE.setter
def WHISPER_COMPUTE_TYPE(self, value): def TRANSCRIPTION_COMPUTE_TYPE(self, value):
if isinstance(value, str): if isinstance(value, str):
if value in self.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["compute_type"]: if value in self.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["compute_types"]:
self._WHISPER_COMPUTE_TYPE = value self._TRANSCRIPTION_COMPUTE_TYPE = value
self.saveConfig(inspect.currentframe().f_code.co_name, value) self.saveConfig(inspect.currentframe().f_code.co_name, value)
@property @property
@@ -1209,9 +1209,9 @@ class Config:
self._SELECTED_TRANSLATION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0]) self._SELECTED_TRANSLATION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0])
self._SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0]) self._SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0])
self._CTRANSLATE2_WEIGHT_TYPE = "small" self._CTRANSLATE2_WEIGHT_TYPE = "small"
self._CTRANSLATE2_COMPUTE_TYPE = "auto" self._TRANSLATION_COMPUTE_TYPE = "auto"
self._WHISPER_WEIGHT_TYPE = "base" self._WHISPER_WEIGHT_TYPE = "base"
self._WHISPER_COMPUTE_TYPE = "auto" self._TRANSCRIPTION_COMPUTE_TYPE = "auto"
self._AUTO_CLEAR_MESSAGE_BOX = True self._AUTO_CLEAR_MESSAGE_BOX = True
self._SEND_ONLY_TRANSLATED_MESSAGES = False self._SEND_ONLY_TRANSLATED_MESSAGES = False
self._OVERLAY_SMALL_LOG = False self._OVERLAY_SMALL_LOG = False

View File

@@ -1437,6 +1437,7 @@ class Controller:
@staticmethod @staticmethod
def setCtranslate2WeightType(data, *args, **kwargs) -> dict: def setCtranslate2WeightType(data, *args, **kwargs) -> dict:
pre_weight_type = config.CTRANSLATE2_WEIGHT_TYPE
config.CTRANSLATE2_WEIGHT_TYPE = str(data) config.CTRANSLATE2_WEIGHT_TYPE = str(data)
if model.checkTranslatorCTranslate2ModelWeight(config.CTRANSLATE2_WEIGHT_TYPE): if model.checkTranslatorCTranslate2ModelWeight(config.CTRANSLATE2_WEIGHT_TYPE):
def callback(): def callback():
@@ -1445,15 +1446,18 @@ class Controller:
th_callback.daemon = True th_callback.daemon = True
th_callback.start() th_callback.start()
th_callback.join() th_callback.join()
else:
config.CTRANSLATE2_WEIGHT_TYPE = pre_weight_type
return {"status":200, "result":config.CTRANSLATE2_WEIGHT_TYPE} return {"status":200, "result":config.CTRANSLATE2_WEIGHT_TYPE}
@staticmethod @staticmethod
def getCtranslate2ComputeType(*args, **kwargs) -> dict: def getTranslationComputeType(*args, **kwargs) -> dict:
return {"status":200, "result":config.CTRANSLATE2_COMPUTE_TYPE} return {"status":200, "result":config.TRANSLATION_COMPUTE_TYPE}
@staticmethod @staticmethod
def setCtranslate2ComputeType(data, *args, **kwargs) -> dict: def setTranslationComputeType(data, *args, **kwargs) -> dict:
config.CTRANSLATE2_COMPUTE_TYPE = str(data) pre_compute_type = config.TRANSLATION_COMPUTE_TYPE
config.TRANSLATION_COMPUTE_TYPE = str(data)
if model.checkTranslatorCTranslate2ModelWeight(config.CTRANSLATE2_WEIGHT_TYPE): if model.checkTranslatorCTranslate2ModelWeight(config.CTRANSLATE2_WEIGHT_TYPE):
def callback(): def callback():
model.changeTranslatorCTranslate2Model() model.changeTranslatorCTranslate2Model()
@@ -1461,7 +1465,9 @@ class Controller:
th_callback.daemon = True th_callback.daemon = True
th_callback.start() th_callback.start()
th_callback.join() th_callback.join()
return {"status":200, "result":config.CTRANSLATE2_COMPUTE_TYPE} else:
config.TRANSLATION_COMPUTE_TYPE = pre_compute_type
return {"status":200, "result":config.TRANSLATION_COMPUTE_TYPE}
@staticmethod @staticmethod
def getWhisperWeightType(*args, **kwargs) -> dict: def getWhisperWeightType(*args, **kwargs) -> dict:
@@ -1473,13 +1479,13 @@ class Controller:
return {"status":200, "result": config.WHISPER_WEIGHT_TYPE} return {"status":200, "result": config.WHISPER_WEIGHT_TYPE}
@staticmethod @staticmethod
def getWhisperComputeType(*args, **kwargs) -> dict: def getTranscriptionComputeType(*args, **kwargs) -> dict:
return {"status":200, "result":config.WHISPER_COMPUTE_TYPE} return {"status":200, "result":config.TRANSCRIPTION_COMPUTE_TYPE}
@staticmethod @staticmethod
def setWhisperComputeType(data, *args, **kwargs) -> dict: def setTranscriptionComputeType(data, *args, **kwargs) -> dict:
config.WHISPER_COMPUTE_TYPE = str(data) config.TRANSCRIPTION_COMPUTE_TYPE = str(data)
return {"status":200, "result":config.WHISPER_COMPUTE_TYPE} return {"status":200, "result":config.TRANSCRIPTION_COMPUTE_TYPE}
@staticmethod @staticmethod
def getSendMessageFormatParts(*args, **kwargs) -> dict: def getSendMessageFormatParts(*args, **kwargs) -> dict:

View File

@@ -162,8 +162,8 @@ mapping = {
"/get/data/ctranslate2_weight_type": {"status": True, "variable":controller.getCtranslate2WeightType}, "/get/data/ctranslate2_weight_type": {"status": True, "variable":controller.getCtranslate2WeightType},
"/set/data/ctranslate2_weight_type": {"status": True, "variable":controller.setCtranslate2WeightType}, "/set/data/ctranslate2_weight_type": {"status": True, "variable":controller.setCtranslate2WeightType},
"/get/data/ctranslate2_compute_type": {"status": True, "variable":controller.getCtranslate2ComputeType}, "/get/data/translation_compute_type": {"status": True, "variable":controller.getTranslationComputeType},
"/set/data/ctranslate2_compute_type": {"status": True, "variable":controller.setCtranslate2ComputeType}, "/set/data/translation_compute_type": {"status": True, "variable":controller.setTranslationComputeType},
"/run/download_ctranslate2_weight": {"status": True, "variable":controller.downloadCtranslate2Weight}, "/run/download_ctranslate2_weight": {"status": True, "variable":controller.downloadCtranslate2Weight},
@@ -268,8 +268,8 @@ mapping = {
"/get/data/whisper_weight_type": {"status": True, "variable":controller.getWhisperWeightType}, "/get/data/whisper_weight_type": {"status": True, "variable":controller.getWhisperWeightType},
"/set/data/whisper_weight_type": {"status": True, "variable":controller.setWhisperWeightType}, "/set/data/whisper_weight_type": {"status": True, "variable":controller.setWhisperWeightType},
"/get/data/whisper_compute_type": {"status": True, "variable":controller.getWhisperComputeType}, "/get/data/transcription_compute_type": {"status": True, "variable":controller.getTranscriptionComputeType},
"/set/data/whisper_compute_type": {"status": True, "variable":controller.setWhisperComputeType}, "/set/data/transcription_compute_type": {"status": True, "variable":controller.setTranscriptionComputeType},
"/run/download_whisper_weight": {"status": True, "variable":controller.downloadWhisperWeight}, "/run/download_whisper_weight": {"status": True, "variable":controller.downloadWhisperWeight},

View File

@@ -116,7 +116,7 @@ class Model:
model_type=config.CTRANSLATE2_WEIGHT_TYPE, model_type=config.CTRANSLATE2_WEIGHT_TYPE,
device=config.SELECTED_TRANSLATION_COMPUTE_DEVICE["device"], device=config.SELECTED_TRANSLATION_COMPUTE_DEVICE["device"],
device_index=config.SELECTED_TRANSLATION_COMPUTE_DEVICE["device_index"], device_index=config.SELECTED_TRANSLATION_COMPUTE_DEVICE["device_index"],
compute_type=config.CTRANSLATE2_COMPUTE_TYPE compute_type=config.TRANSLATION_COMPUTE_TYPE
) )
def downloadCTranslate2ModelWeight(self, weight_type, callback=None, end_callback=None): def downloadCTranslate2ModelWeight(self, weight_type, callback=None, end_callback=None):
@@ -440,7 +440,7 @@ class Model:
whisper_weight_type=config.WHISPER_WEIGHT_TYPE, whisper_weight_type=config.WHISPER_WEIGHT_TYPE,
device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device"], device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device"],
device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"], device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"],
compute_type=config.WHISPER_COMPUTE_TYPE, compute_type=config.TRANSCRIPTION_COMPUTE_TYPE,
) )
def sendMicTranscript(): def sendMicTranscript():
try: try:
@@ -624,7 +624,7 @@ class Model:
whisper_weight_type=config.WHISPER_WEIGHT_TYPE, whisper_weight_type=config.WHISPER_WEIGHT_TYPE,
device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device"], device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device"],
device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"], device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"],
compute_type=config.WHISPER_COMPUTE_TYPE, compute_type=config.TRANSCRIPTION_COMPUTE_TYPE,
) )
def sendSpeakerTranscript(): def sendSpeakerTranscript():
try: try: