👍️[Update] Controller : AI modelのダウンロードの同期処理を追加

This commit is contained in:
misyaguziya
2024-11-14 22:53:42 +09:00
parent a3d257fa18
commit c5cd80b542

View File

@@ -150,7 +150,7 @@ class Controller:
def downloaded(self) -> None:
weight_type_dict = config.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT
weight_type_dict["self.weight_type"] = True
weight_type_dict[self.weight_type] = True
config.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT = weight_type_dict
self.run(
@@ -1441,7 +1441,7 @@ class Controller:
th_start_update_cuda_software.start()
return {"status":200, "result":True}
def downloadCtranslate2Weight(self, data:str, *args, **kwargs) -> dict:
def downloadCtranslate2Weight(self, data:str, asynchronous:bool=True, *args, **kwargs) -> dict:
weight_type = str(data)
download_ctranslate2 = self.DownloadCTranslate2(
self.run_mapping,
@@ -1449,25 +1449,31 @@ class Controller:
self.run
)
self.startThreadingDownloadCtranslate2Weight(
weight_type,
download_ctranslate2.progressBar,
download_ctranslate2.downloaded,
)
if asynchronous is True:
self.startThreadingDownloadCtranslate2Weight(
weight_type,
download_ctranslate2.progressBar,
download_ctranslate2.downloaded,
)
else:
model.downloadCTranslate2ModelWeight(weight_type, download_ctranslate2.progressBar, download_ctranslate2.downloaded)
return {"status":200, "result":True}
def downloadWhisperWeight(self, data:str, *args, **kwargs) -> dict:
def downloadWhisperWeight(self, data:str, asynchronous:bool=True, *args, **kwargs) -> dict:
weight_type = str(data)
download_whisper = self.DownloadWhisper(
self.run_mapping,
weight_type,
self.run
)
self.startThreadingDownloadWhisperWeight(
weight_type,
download_whisper.progressBar,
download_whisper.downloaded,
)
if asynchronous is True:
self.startThreadingDownloadWhisperWeight(
weight_type,
download_whisper.progressBar,
download_whisper.downloaded,
)
else:
model.downloadWhisperModelWeight(weight_type, download_whisper.progressBar, download_whisper.downloaded)
return {"status":200, "result":True}
@staticmethod
@@ -1702,7 +1708,7 @@ class Controller:
printLog("Download CTranslate2 Model Weight")
weight_type = config.CTRANSLATE2_WEIGHT_TYPE
if model.checkTranslatorCTranslate2ModelWeight(weight_type) is False:
self.downloadCtranslate2Weight(weight_type)
self.downloadCtranslate2Weight(weight_type, False)
# set Translation Engine
printLog("Set Translation Engine")
@@ -1713,7 +1719,7 @@ class Controller:
printLog("Download Whisper Model Weight")
weight_type = config.WHISPER_WEIGHT_TYPE
if model.checkTranscriptionWhisperModelWeight(weight_type) is False:
self.downloadWhisperWeight(weight_type)
self.downloadWhisperWeight(weight_type, False)
# set Transcription Engine
printLog("Set Transcription Engine")