Merge branch 'init_optimize' into for_webui

This commit is contained in:
misyaguziya
2024-10-31 17:21:40 +09:00
2 changed files with 26 additions and 18 deletions

View File

@@ -118,6 +118,8 @@ class Model:
return self.translator.isLoadedCTranslate2Model() return self.translator.isLoadedCTranslate2Model()
def checkTranscriptionWhisperModelWeight(self, weight_type:str): def checkTranscriptionWhisperModelWeight(self, weight_type:str):
if weight_type == "none":
return True
return checkWhisperWeight(config.PATH_LOCAL, weight_type) return checkWhisperWeight(config.PATH_LOCAL, weight_type)
def downloadWhisperModelWeight(self, weight_type, callback=None, end_callback=None): def downloadWhisperModelWeight(self, weight_type, callback=None, end_callback=None):
@@ -424,7 +426,7 @@ class Model:
max_phrases=config.MIC_MAX_PHRASES, max_phrases=config.MIC_MAX_PHRASES,
transcription_engine=config.SELECTED_TRANSCRIPTION_ENGINE, transcription_engine=config.SELECTED_TRANSCRIPTION_ENGINE,
root=config.PATH_LOCAL, root=config.PATH_LOCAL,
whisper_weight_type=config.WHISPER_WEIGHT_TYPE, whisper_weight_type=config.WHISPER_WEIGHT_TYPE if config.WHISPER_WEIGHT_TYPE != "none" else None,
device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device"], device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device"],
device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"], device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"],
) )
@@ -588,7 +590,7 @@ class Model:
max_phrases=config.SPEAKER_MAX_PHRASES, max_phrases=config.SPEAKER_MAX_PHRASES,
transcription_engine=config.SELECTED_TRANSCRIPTION_ENGINE, transcription_engine=config.SELECTED_TRANSCRIPTION_ENGINE,
root=config.PATH_LOCAL, root=config.PATH_LOCAL,
whisper_weight_type=config.WHISPER_WEIGHT_TYPE, whisper_weight_type=config.WHISPER_WEIGHT_TYPE if config.WHISPER_WEIGHT_TYPE != "none" else None,
device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device"], device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device"],
device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"], device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"],
) )

View File

@@ -570,8 +570,8 @@ class Controller:
@staticmethod @staticmethod
def setSelectedTranscriptionEngine(data:dict, *args, **kwargs) -> dict: def setSelectedTranscriptionEngine(data:dict, *args, **kwargs) -> dict:
config.SELECTED_TRANSCRIPTION_ENGINE = data.get("engine", "Google") config.SELECTED_TRANSCRIPTION_ENGINE = data["engine"]
config.WHISPER_WEIGHT_TYPE = data.get("weight_type", "base") config.WHISPER_WEIGHT_TYPE = data["weight_type"]
data = { data = {
"engine":config.SELECTED_TRANSCRIPTION_ENGINE, "engine":config.SELECTED_TRANSCRIPTION_ENGINE,
"weight_type":config.WHISPER_WEIGHT_TYPE, "weight_type":config.WHISPER_WEIGHT_TYPE,
@@ -1432,16 +1432,19 @@ class Controller:
def downloadWhisperWeight(self, data:str, *args, **kwargs) -> dict: def downloadWhisperWeight(self, data:str, *args, **kwargs) -> dict:
weight_type = str(data) weight_type = str(data)
download_whisper = self.DownloadWhisper( if weight_type == "none":
self.run_mapping, pass
weight_type, else:
self.run download_whisper = self.DownloadWhisper(
) self.run_mapping,
self.startThreadingDownloadWhisperWeight( weight_type,
weight_type, self.run
download_whisper.progressBar,
download_whisper.downloaded,
) )
self.startThreadingDownloadWhisperWeight(
weight_type,
download_whisper.progressBar,
download_whisper.downloaded,
)
return {"status":200, "result":True} return {"status":200, "result":True}
@staticmethod @staticmethod
@@ -1577,8 +1580,9 @@ class Controller:
def updateTranscriptionEngine(self): def updateTranscriptionEngine(self):
weight_type_dict = config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT weight_type_dict = config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT
weight_type = config.WHISPER_WEIGHT_TYPE weight_type = config.WHISPER_WEIGHT_TYPE
if config.SELECTED_TRANSCRIPTION_ENGINE == "Whisper" and weight_type_dict[weight_type] is False: if config.SELECTED_TRANSCRIPTION_ENGINE == "Whisper" and (weight_type == "none" or weight_type_dict[weight_type] is False):
config.SELECTED_TRANSCRIPTION_ENGINE = "Google" config.SELECTED_TRANSCRIPTION_ENGINE = "Google"
config.WHISPER_WEIGHT_TYPE = "none"
def startCheckMicEnergy(self) -> None: def startCheckMicEnergy(self) -> None:
while self.device_access_status is False: while self.device_access_status is False:
@@ -1674,8 +1678,9 @@ class Controller:
# download CTranslate2 Model Weight # download CTranslate2 Model Weight
printLog("Download CTranslate2 Model Weight") printLog("Download CTranslate2 Model Weight")
if model.checkTranslatorCTranslate2ModelWeight(config.CTRANSLATE2_WEIGHT_TYPE) is False: weight_type = config.CTRANSLATE2_WEIGHT_TYPE
self.downloadCtranslate2Weight(config.CTRANSLATE2_WEIGHT_TYPE) if model.checkTranslatorCTranslate2ModelWeight(weight_type) is False:
self.downloadCtranslate2Weight(weight_type)
# set Translation Engine # set Translation Engine
printLog("Set Translation Engine") printLog("Set Translation Engine")
@@ -1684,8 +1689,9 @@ class Controller:
# download Whisper Model Weight # download Whisper Model Weight
printLog("Download Whisper Model Weight") printLog("Download Whisper Model Weight")
if model.checkTranscriptionWhisperModelWeight(config.WHISPER_WEIGHT_TYPE) is False: weight_type = config.WHISPER_WEIGHT_TYPE
self.downloadWhisperWeight(config.WHISPER_WEIGHT_TYPE) if model.checkTranscriptionWhisperModelWeight(weight_type) is False:
self.downloadWhisperWeight(weight_type)
# set Transcription Engine # set Transcription Engine
printLog("Set Transcription Engine") printLog("Set Transcription Engine")