Merge branch 'bugfix_compute_type' into develop
This commit is contained in:
@@ -747,13 +747,15 @@ class Controller:
|
||||
def getSelectedTranslationComputeDevice(*args, **kwargs) -> dict:
|
||||
return {"status":200, "result":config.SELECTED_TRANSLATION_COMPUTE_DEVICE}
|
||||
|
||||
@staticmethod
|
||||
def setSelectedTranslationComputeDevice(device:str, *args, **kwargs) -> dict:
|
||||
def setSelectedTranslationComputeDevice(self, device:str, *args, **kwargs) -> dict:
|
||||
printLog("setSelectedTranslationComputeDevice", device)
|
||||
pre_device = config.SELECTED_TRANSLATION_COMPUTE_DEVICE
|
||||
pre_compute_type = config.SELECTED_TRANSLATION_COMPUTE_TYPE
|
||||
config.SELECTED_TRANSLATION_COMPUTE_DEVICE = device
|
||||
config.SELECTED_TRANSLATION_COMPUTE_TYPE = "auto"
|
||||
try:
|
||||
model.changeTranslatorCTranslate2Model()
|
||||
self.run(200, self.run_mapping["selected_translation_compute_type"], config.SELECTED_TRANSLATION_COMPUTE_TYPE)
|
||||
except Exception as e:
|
||||
# VRAM不足エラーの検出(デバイス切り替え時)
|
||||
is_vram_error, error_message = model.detectVRAMError(e)
|
||||
@@ -761,6 +763,7 @@ class Controller:
|
||||
# 前のデバイス設定に戻す
|
||||
printLog("VRAM error detected, reverting device setting")
|
||||
config.SELECTED_TRANSLATION_COMPUTE_DEVICE = pre_device
|
||||
config.SELECTED_TRANSLATION_COMPUTE_TYPE = pre_compute_type
|
||||
model.changeTranslatorCTranslate2Model()
|
||||
else:
|
||||
# その他のエラーは通常通り処理
|
||||
@@ -775,10 +778,11 @@ class Controller:
|
||||
def getSelectedTranscriptionComputeDevice(*args, **kwargs) -> dict:
|
||||
return {"status":200, "result":config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE}
|
||||
|
||||
@staticmethod
|
||||
def setSelectedTranscriptionComputeDevice(device:str, *args, **kwargs) -> dict:
|
||||
def setSelectedTranscriptionComputeDevice(self, device:str, *args, **kwargs) -> dict:
|
||||
printLog("setSelectedTranscriptionComputeDevice", device)
|
||||
config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = device
|
||||
config.SELECTED_TRANSCRIPTION_COMPUTE_TYPE = "auto"
|
||||
self.run(200, self.run_mapping["selected_transcription_compute_type"], config.SELECTED_TRANSCRIPTION_COMPUTE_TYPE)
|
||||
return {"status":200,"result":config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE}
|
||||
|
||||
@staticmethod
|
||||
@@ -1549,6 +1553,7 @@ class Controller:
|
||||
|
||||
@staticmethod
|
||||
def setCtranslate2WeightType(data, *args, **kwargs) -> dict:
|
||||
pre_weight_type = config.CTRANSLATE2_WEIGHT_TYPE
|
||||
config.CTRANSLATE2_WEIGHT_TYPE = str(data)
|
||||
if model.checkTranslatorCTranslate2ModelWeight(config.CTRANSLATE2_WEIGHT_TYPE):
|
||||
def callback():
|
||||
@@ -1557,8 +1562,29 @@ class Controller:
|
||||
th_callback.daemon = True
|
||||
th_callback.start()
|
||||
th_callback.join()
|
||||
else:
|
||||
config.CTRANSLATE2_WEIGHT_TYPE = pre_weight_type
|
||||
return {"status":200, "result":config.CTRANSLATE2_WEIGHT_TYPE}
|
||||
|
||||
@staticmethod
|
||||
def getSelectedTranslationComputeType(*args, **kwargs) -> dict:
|
||||
return {"status":200, "result":config.SELECTED_TRANSLATION_COMPUTE_TYPE}
|
||||
|
||||
@staticmethod
|
||||
def setSelectedTranslationComputeType(data, *args, **kwargs) -> dict:
|
||||
pre_compute_type = config.SELECTED_TRANSLATION_COMPUTE_TYPE
|
||||
config.SELECTED_TRANSLATION_COMPUTE_TYPE = str(data)
|
||||
if model.checkTranslatorCTranslate2ModelWeight(config.CTRANSLATE2_WEIGHT_TYPE):
|
||||
def callback():
|
||||
model.changeTranslatorCTranslate2Model()
|
||||
th_callback = Thread(target=callback)
|
||||
th_callback.daemon = True
|
||||
th_callback.start()
|
||||
th_callback.join()
|
||||
else:
|
||||
config.SELECTED_TRANSLATION_COMPUTE_TYPE = pre_compute_type
|
||||
return {"status":200, "result":config.SELECTED_TRANSLATION_COMPUTE_TYPE}
|
||||
|
||||
@staticmethod
|
||||
def getWhisperWeightType(*args, **kwargs) -> dict:
|
||||
return {"status":200, "result":config.WHISPER_WEIGHT_TYPE}
|
||||
@@ -1568,6 +1594,15 @@ class Controller:
|
||||
config.WHISPER_WEIGHT_TYPE = str(data)
|
||||
return {"status":200, "result": config.WHISPER_WEIGHT_TYPE}
|
||||
|
||||
@staticmethod
|
||||
def getSelectedTranscriptionComputeType(*args, **kwargs) -> dict:
|
||||
return {"status":200, "result":config.SELECTED_TRANSCRIPTION_COMPUTE_TYPE}
|
||||
|
||||
@staticmethod
|
||||
def setSelectedTranscriptionComputeType(data, *args, **kwargs) -> dict:
|
||||
config.SELECTED_TRANSCRIPTION_COMPUTE_TYPE = str(data)
|
||||
return {"status":200, "result":config.SELECTED_TRANSCRIPTION_COMPUTE_TYPE}
|
||||
|
||||
@staticmethod
|
||||
def getSendMessageFormatParts(*args, **kwargs) -> dict:
|
||||
return {"status":200, "result":config.SEND_MESSAGE_FORMAT_PARTS}
|
||||
|
||||
Reference in New Issue
Block a user