[Update] Refactor compute type management: rename properties to 'SELECTED_TRANSLATION_COMPUTE_TYPE' and 'SELECTED_TRANSCRIPTION_COMPUTE_TYPE'

This commit is contained in:
misyaguziya
2025-09-27 07:07:54 +09:00
parent 5366622fca
commit 6effedcce2
4 changed files with 41 additions and 41 deletions

View File

@@ -659,12 +659,12 @@ class Controller:
def setSelectedTranslationComputeDevice(self, device:str, *args, **kwargs) -> dict:
printLog("setSelectedTranslationComputeDevice", device)
pre_device = config.SELECTED_TRANSLATION_COMPUTE_DEVICE
pre_compute_type = config.TRANSLATION_COMPUTE_TYPE
pre_compute_type = config.SELECTED_TRANSLATION_COMPUTE_TYPE
config.SELECTED_TRANSLATION_COMPUTE_DEVICE = device
config.TRANSLATION_COMPUTE_TYPE = "auto"
config.SELECTED_TRANSLATION_COMPUTE_TYPE = "auto"
try:
model.changeTranslatorCTranslate2Model()
self.run(200, self.run_mapping["translation_compute_type"], config.TRANSLATION_COMPUTE_TYPE)
self.run(200, self.run_mapping["selected_translation_compute_type"], config.SELECTED_TRANSLATION_COMPUTE_TYPE)
except Exception as e:
# VRAM不足エラーの検出デバイス切り替え時
is_vram_error, error_message = model.detectVRAMError(e)
@@ -672,7 +672,7 @@ class Controller:
# 前のデバイス設定に戻す
printLog("VRAM error detected, reverting device setting")
config.SELECTED_TRANSLATION_COMPUTE_DEVICE = pre_device
config.TRANSLATION_COMPUTE_TYPE = pre_compute_type
config.SELECTED_TRANSLATION_COMPUTE_TYPE = pre_compute_type
model.changeTranslatorCTranslate2Model()
else:
# その他のエラーは通常通り処理
@@ -690,8 +690,8 @@ class Controller:
def setSelectedTranscriptionComputeDevice(self, device:str, *args, **kwargs) -> dict:
printLog("setSelectedTranscriptionComputeDevice", device)
config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = device
config.TRANSCRIPTION_COMPUTE_TYPE = "auto"
self.run(200, self.run_mapping["transcription_compute_type"], config.TRANSCRIPTION_COMPUTE_TYPE)
config.SELECTED_TRANSCRIPTION_COMPUTE_TYPE = "auto"
self.run(200, self.run_mapping["selected_transcription_compute_type"], config.SELECTED_TRANSCRIPTION_COMPUTE_TYPE)
return {"status":200,"result":config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE}
@staticmethod
@@ -1455,13 +1455,13 @@ class Controller:
return {"status":200, "result":config.CTRANSLATE2_WEIGHT_TYPE}
@staticmethod
def getTranslationComputeType(*args, **kwargs) -> dict:
return {"status":200, "result":config.TRANSLATION_COMPUTE_TYPE}
def getSelectedTranslationComputeType(*args, **kwargs) -> dict:
return {"status":200, "result":config.SELECTED_TRANSLATION_COMPUTE_TYPE}
@staticmethod
def setTranslationComputeType(data, *args, **kwargs) -> dict:
pre_compute_type = config.TRANSLATION_COMPUTE_TYPE
config.TRANSLATION_COMPUTE_TYPE = str(data)
def setSelectedTranslationComputeType(data, *args, **kwargs) -> dict:
pre_compute_type = config.SELECTED_TRANSLATION_COMPUTE_TYPE
config.SELECTED_TRANSLATION_COMPUTE_TYPE = str(data)
if model.checkTranslatorCTranslate2ModelWeight(config.CTRANSLATE2_WEIGHT_TYPE):
def callback():
model.changeTranslatorCTranslate2Model()
@@ -1470,8 +1470,8 @@ class Controller:
th_callback.start()
th_callback.join()
else:
config.TRANSLATION_COMPUTE_TYPE = pre_compute_type
return {"status":200, "result":config.TRANSLATION_COMPUTE_TYPE}
config.SELECTED_TRANSLATION_COMPUTE_TYPE = pre_compute_type
return {"status":200, "result":config.SELECTED_TRANSLATION_COMPUTE_TYPE}
@staticmethod
def getWhisperWeightType(*args, **kwargs) -> dict:
@@ -1483,13 +1483,13 @@ class Controller:
return {"status":200, "result": config.WHISPER_WEIGHT_TYPE}
@staticmethod
def getTranscriptionComputeType(*args, **kwargs) -> dict:
return {"status":200, "result":config.TRANSCRIPTION_COMPUTE_TYPE}
def getSelectedTranscriptionComputeType(*args, **kwargs) -> dict:
return {"status":200, "result":config.SELECTED_TRANSCRIPTION_COMPUTE_TYPE}
@staticmethod
def setTranscriptionComputeType(data, *args, **kwargs) -> dict:
config.TRANSCRIPTION_COMPUTE_TYPE = str(data)
return {"status":200, "result":config.TRANSCRIPTION_COMPUTE_TYPE}
def setSelectedTranscriptionComputeType(data, *args, **kwargs) -> dict:
config.SELECTED_TRANSCRIPTION_COMPUTE_TYPE = str(data)
return {"status":200, "result":config.SELECTED_TRANSCRIPTION_COMPUTE_TYPE}
@staticmethod
def getSendMessageFormatParts(*args, **kwargs) -> dict: