[Update] Refactor compute device management: change methods to instance methods and set compute types to "auto"

This commit is contained in:
misyaguziya
2025-09-26 23:30:39 +09:00
parent 8c5f1b5db2
commit 5366622fca
2 changed files with 11 additions and 4 deletions

View File

@@ -656,13 +656,15 @@ class Controller:
def getSelectedTranslationComputeDevice(*args, **kwargs) -> dict: def getSelectedTranslationComputeDevice(*args, **kwargs) -> dict:
return {"status":200, "result":config.SELECTED_TRANSLATION_COMPUTE_DEVICE} return {"status":200, "result":config.SELECTED_TRANSLATION_COMPUTE_DEVICE}
@staticmethod def setSelectedTranslationComputeDevice(self, device:str, *args, **kwargs) -> dict:
def setSelectedTranslationComputeDevice(device:str, *args, **kwargs) -> dict:
printLog("setSelectedTranslationComputeDevice", device) printLog("setSelectedTranslationComputeDevice", device)
pre_device = config.SELECTED_TRANSLATION_COMPUTE_DEVICE pre_device = config.SELECTED_TRANSLATION_COMPUTE_DEVICE
pre_compute_type = config.TRANSLATION_COMPUTE_TYPE
config.SELECTED_TRANSLATION_COMPUTE_DEVICE = device config.SELECTED_TRANSLATION_COMPUTE_DEVICE = device
config.TRANSLATION_COMPUTE_TYPE = "auto"
try: try:
model.changeTranslatorCTranslate2Model() model.changeTranslatorCTranslate2Model()
self.run(200, self.run_mapping["translation_compute_type"], config.TRANSLATION_COMPUTE_TYPE)
except Exception as e: except Exception as e:
# VRAM不足エラーの検出デバイス切り替え時 # VRAM不足エラーの検出デバイス切り替え時
is_vram_error, error_message = model.detectVRAMError(e) is_vram_error, error_message = model.detectVRAMError(e)
@@ -670,6 +672,7 @@ class Controller:
# 前のデバイス設定に戻す # 前のデバイス設定に戻す
printLog("VRAM error detected, reverting device setting") printLog("VRAM error detected, reverting device setting")
config.SELECTED_TRANSLATION_COMPUTE_DEVICE = pre_device config.SELECTED_TRANSLATION_COMPUTE_DEVICE = pre_device
config.TRANSLATION_COMPUTE_TYPE = pre_compute_type
model.changeTranslatorCTranslate2Model() model.changeTranslatorCTranslate2Model()
else: else:
# その他のエラーは通常通り処理 # その他のエラーは通常通り処理
@@ -684,10 +687,11 @@ class Controller:
def getSelectedTranscriptionComputeDevice(*args, **kwargs) -> dict: def getSelectedTranscriptionComputeDevice(*args, **kwargs) -> dict:
return {"status":200, "result":config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE} return {"status":200, "result":config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE}
@staticmethod def setSelectedTranscriptionComputeDevice(self, device:str, *args, **kwargs) -> dict:
def setSelectedTranscriptionComputeDevice(device:str, *args, **kwargs) -> dict:
printLog("setSelectedTranscriptionComputeDevice", device) printLog("setSelectedTranscriptionComputeDevice", device)
config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = device config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = device
config.TRANSCRIPTION_COMPUTE_TYPE = "auto"
self.run(200, self.run_mapping["transcription_compute_type"], config.TRANSCRIPTION_COMPUTE_TYPE)
return {"status":200,"result":config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE} return {"status":200,"result":config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE}
@staticmethod @staticmethod

View File

@@ -48,6 +48,9 @@ run_mapping = {
"selected_translation_engines":"/run/selected_translation_engines", "selected_translation_engines":"/run/selected_translation_engines",
"translation_engines":"/run/translation_engines", "translation_engines":"/run/translation_engines",
"translation_compute_type":"/run/translation_compute_type",
"transcription_compute_type":"/run/transcription_compute_type",
"mic_host_list":"/run/mic_host_list", "mic_host_list":"/run/mic_host_list",
"mic_device_list":"/run/mic_device_list", "mic_device_list":"/run/mic_device_list",
"speaker_device_list":"/run/speaker_device_list", "speaker_device_list":"/run/speaker_device_list",