[Update] Controller and Model: Refactor translation device management and add parameter change tracking
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import copy
|
||||||
from typing import Callable, Any
|
from typing import Callable, Any
|
||||||
from time import sleep
|
from time import sleep
|
||||||
from subprocess import Popen
|
from subprocess import Popen
|
||||||
@@ -753,25 +754,10 @@ class Controller:
|
|||||||
|
|
||||||
def setSelectedTranslationComputeDevice(self, device:str, *args, **kwargs) -> dict:
|
def setSelectedTranslationComputeDevice(self, device:str, *args, **kwargs) -> dict:
|
||||||
printLog("setSelectedTranslationComputeDevice", device)
|
printLog("setSelectedTranslationComputeDevice", device)
|
||||||
pre_device = config.SELECTED_TRANSLATION_COMPUTE_DEVICE
|
|
||||||
pre_compute_type = config.SELECTED_TRANSLATION_COMPUTE_TYPE
|
|
||||||
config.SELECTED_TRANSLATION_COMPUTE_DEVICE = device
|
config.SELECTED_TRANSLATION_COMPUTE_DEVICE = device
|
||||||
config.SELECTED_TRANSLATION_COMPUTE_TYPE = "auto"
|
config.SELECTED_TRANSLATION_COMPUTE_TYPE = "auto"
|
||||||
try:
|
self.run(200, self.run_mapping["selected_translation_compute_type"], config.SELECTED_TRANSLATION_COMPUTE_TYPE)
|
||||||
model.changeTranslatorCTranslate2Model()
|
model.setChangedTranslatorParameters(True)
|
||||||
self.run(200, self.run_mapping["selected_translation_compute_type"], config.SELECTED_TRANSLATION_COMPUTE_TYPE)
|
|
||||||
except Exception as e:
|
|
||||||
# VRAM不足エラーの検出(デバイス切り替え時)
|
|
||||||
is_vram_error, error_message = model.detectVRAMError(e)
|
|
||||||
if is_vram_error:
|
|
||||||
# 前のデバイス設定に戻す
|
|
||||||
printLog("VRAM error detected, reverting device setting")
|
|
||||||
config.SELECTED_TRANSLATION_COMPUTE_DEVICE = pre_device
|
|
||||||
config.SELECTED_TRANSLATION_COMPUTE_TYPE = pre_compute_type
|
|
||||||
model.changeTranslatorCTranslate2Model()
|
|
||||||
else:
|
|
||||||
# その他のエラーは通常通り処理
|
|
||||||
errorLogging()
|
|
||||||
return {"status":200,"result":config.SELECTED_TRANSLATION_COMPUTE_DEVICE}
|
return {"status":200,"result":config.SELECTED_TRANSLATION_COMPUTE_DEVICE}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -801,12 +787,39 @@ class Controller:
|
|||||||
# def getMaxSpeakerThreshold(*args, **kwargs) -> dict:
|
# def getMaxSpeakerThreshold(*args, **kwargs) -> dict:
|
||||||
# return {"status":200, "result":config.MAX_SPEAKER_THRESHOLD}
|
# return {"status":200, "result":config.MAX_SPEAKER_THRESHOLD}
|
||||||
|
|
||||||
@staticmethod
|
def setEnableTranslation(self, *args, **kwargs) -> dict:
|
||||||
def setEnableTranslation(*args, **kwargs) -> dict:
|
|
||||||
if config.ENABLE_TRANSLATION is False:
|
if config.ENABLE_TRANSLATION is False:
|
||||||
if model.isLoadedCTranslate2Model() is False:
|
if model.isLoadedCTranslate2Model() is False or model.isChangedTranslatorParameters() is True:
|
||||||
model.changeTranslatorCTranslate2Model()
|
try:
|
||||||
config.ENABLE_TRANSLATION = True
|
model.changeTranslatorCTranslate2Model()
|
||||||
|
model.setChangedTranslatorParameters(False)
|
||||||
|
config.ENABLE_TRANSLATION = True
|
||||||
|
except Exception as e:
|
||||||
|
# VRAM不足エラーの検出(デバイス切り替え時)
|
||||||
|
is_vram_error, error_message = model.detectVRAMError(e)
|
||||||
|
if is_vram_error:
|
||||||
|
# Defaultのデバイス設定に戻す
|
||||||
|
printLog("VRAM error detected, reverting device setting")
|
||||||
|
self.setDisableTranslation()
|
||||||
|
config.SELECTED_TRANSLATION_COMPUTE_DEVICE = copy.deepcopy(config.SELECTABLE_COMPUTE_DEVICE_LIST[0])
|
||||||
|
config.SELECTED_TRANSLATION_COMPUTE_TYPE = "auto"
|
||||||
|
self.run(200, self.run_mapping["selected_translation_compute_device"], config.SELECTED_TRANSLATION_COMPUTE_DEVICE)
|
||||||
|
self.run(200, self.run_mapping["selected_translation_compute_type"], config.SELECTED_TRANSLATION_COMPUTE_TYPE)
|
||||||
|
self.run(
|
||||||
|
400,
|
||||||
|
self.run_mapping["enable_translation"],
|
||||||
|
{
|
||||||
|
"message":"Translation disabled due to VRAM overflow",
|
||||||
|
"data": False
|
||||||
|
},
|
||||||
|
)
|
||||||
|
model.changeTranslatorCTranslate2Model()
|
||||||
|
model.setChangedTranslatorParameters(False)
|
||||||
|
else:
|
||||||
|
# その他のエラーは通常通り処理
|
||||||
|
errorLogging()
|
||||||
|
else:
|
||||||
|
config.ENABLE_TRANSLATION = True
|
||||||
return {"status":200, "result":config.ENABLE_TRANSLATION}
|
return {"status":200, "result":config.ENABLE_TRANSLATION}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -1571,17 +1584,8 @@ class Controller:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def setCtranslate2WeightType(data, *args, **kwargs) -> dict:
|
def setCtranslate2WeightType(data, *args, **kwargs) -> dict:
|
||||||
pre_weight_type = config.CTRANSLATE2_WEIGHT_TYPE
|
|
||||||
config.CTRANSLATE2_WEIGHT_TYPE = str(data)
|
config.CTRANSLATE2_WEIGHT_TYPE = str(data)
|
||||||
if model.checkTranslatorCTranslate2ModelWeight(config.CTRANSLATE2_WEIGHT_TYPE):
|
model.setChangedTranslatorParameters(True)
|
||||||
def callback():
|
|
||||||
model.changeTranslatorCTranslate2Model()
|
|
||||||
th_callback = Thread(target=callback)
|
|
||||||
th_callback.daemon = True
|
|
||||||
th_callback.start()
|
|
||||||
th_callback.join()
|
|
||||||
else:
|
|
||||||
config.CTRANSLATE2_WEIGHT_TYPE = pre_weight_type
|
|
||||||
return {"status":200, "result":config.CTRANSLATE2_WEIGHT_TYPE}
|
return {"status":200, "result":config.CTRANSLATE2_WEIGHT_TYPE}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -1590,17 +1594,8 @@ class Controller:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def setSelectedTranslationComputeType(data, *args, **kwargs) -> dict:
|
def setSelectedTranslationComputeType(data, *args, **kwargs) -> dict:
|
||||||
pre_compute_type = config.SELECTED_TRANSLATION_COMPUTE_TYPE
|
|
||||||
config.SELECTED_TRANSLATION_COMPUTE_TYPE = str(data)
|
config.SELECTED_TRANSLATION_COMPUTE_TYPE = str(data)
|
||||||
if model.checkTranslatorCTranslate2ModelWeight(config.CTRANSLATE2_WEIGHT_TYPE):
|
model.setChangedTranslatorParameters(True)
|
||||||
def callback():
|
|
||||||
model.changeTranslatorCTranslate2Model()
|
|
||||||
th_callback = Thread(target=callback)
|
|
||||||
th_callback.daemon = True
|
|
||||||
th_callback.start()
|
|
||||||
th_callback.join()
|
|
||||||
else:
|
|
||||||
config.SELECTED_TRANSLATION_COMPUTE_TYPE = pre_compute_type
|
|
||||||
return {"status":200, "result":config.SELECTED_TRANSLATION_COMPUTE_TYPE}
|
return {"status":200, "result":config.SELECTED_TRANSLATION_COMPUTE_TYPE}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -128,6 +128,12 @@ class Model:
|
|||||||
def isLoadedCTranslate2Model(self):
|
def isLoadedCTranslate2Model(self):
|
||||||
return self.translator.isLoadedCTranslate2Model()
|
return self.translator.isLoadedCTranslate2Model()
|
||||||
|
|
||||||
|
def isChangedTranslatorParameters(self):
|
||||||
|
return self.translator.isChangedTranslatorParameters()
|
||||||
|
|
||||||
|
def setChangedTranslatorParameters(self, is_changed):
|
||||||
|
self.translator.setChangedTranslatorParameters(is_changed)
|
||||||
|
|
||||||
def checkTranscriptionWhisperModelWeight(self, weight_type:str):
|
def checkTranscriptionWhisperModelWeight(self, weight_type:str):
|
||||||
return checkWhisperWeight(config.PATH_LOCAL, weight_type)
|
return checkWhisperWeight(config.PATH_LOCAL, weight_type)
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ class Translator():
|
|||||||
self.ctranslate2_translator = None
|
self.ctranslate2_translator = None
|
||||||
self.ctranslate2_tokenizer = None
|
self.ctranslate2_tokenizer = None
|
||||||
self.is_loaded_ctranslate2_model = False
|
self.is_loaded_ctranslate2_model = False
|
||||||
|
self.is_changed_translator_parameters = False
|
||||||
self.is_enable_translators = ENABLE_TRANSLATORS
|
self.is_enable_translators = ENABLE_TRANSLATORS
|
||||||
|
|
||||||
def authenticationDeepLAuthKey(self, authkey):
|
def authenticationDeepLAuthKey(self, authkey):
|
||||||
@@ -64,6 +65,12 @@ class Translator():
|
|||||||
def isLoadedCTranslate2Model(self):
|
def isLoadedCTranslate2Model(self):
|
||||||
return self.is_loaded_ctranslate2_model
|
return self.is_loaded_ctranslate2_model
|
||||||
|
|
||||||
|
def isChangedTranslatorParameters(self):
|
||||||
|
return self.is_changed_translator_parameters
|
||||||
|
|
||||||
|
def setChangedTranslatorParameters(self, is_changed):
|
||||||
|
self.is_changed_translator_parameters = is_changed
|
||||||
|
|
||||||
def translateCTranslate2(self, message, source_language, target_language):
|
def translateCTranslate2(self, message, source_language, target_language):
|
||||||
result = False
|
result = False
|
||||||
if self.is_loaded_ctranslate2_model is True:
|
if self.is_loaded_ctranslate2_model is True:
|
||||||
|
|||||||
Reference in New Issue
Block a user