From c5cd80b54264c253f3488dd38fd32bbe1531ce34 Mon Sep 17 00:00:00 2001 From: misyaguziya <53165965+misyaguziya@users.noreply.github.com> Date: Thu, 14 Nov 2024 22:53:42 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=91=8D=EF=B8=8F[Update]=20Controller=20:?= =?UTF-8?q?=20AI=20model=E3=81=AE=E3=83=80=E3=82=A6=E3=83=B3=E3=83=AD?= =?UTF-8?q?=E3=83=BC=E3=83=89=E3=81=AE=E5=90=8C=E6=9C=9F=E5=87=A6=E7=90=86?= =?UTF-8?q?=E3=82=92=E8=BF=BD=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src-python/webui_controller.py | 36 ++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/src-python/webui_controller.py b/src-python/webui_controller.py index 6616de8e..158e66bf 100644 --- a/src-python/webui_controller.py +++ b/src-python/webui_controller.py @@ -150,7 +150,7 @@ class Controller: def downloaded(self) -> None: weight_type_dict = config.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT - weight_type_dict["self.weight_type"] = True + weight_type_dict[self.weight_type] = True config.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT = weight_type_dict self.run( @@ -1441,7 +1441,7 @@ class Controller: th_start_update_cuda_software.start() return {"status":200, "result":True} - def downloadCtranslate2Weight(self, data:str, *args, **kwargs) -> dict: + def downloadCtranslate2Weight(self, data:str, asynchronous:bool=True, *args, **kwargs) -> dict: weight_type = str(data) download_ctranslate2 = self.DownloadCTranslate2( self.run_mapping, @@ -1449,25 +1449,31 @@ class Controller: self.run ) - self.startThreadingDownloadCtranslate2Weight( - weight_type, - download_ctranslate2.progressBar, - download_ctranslate2.downloaded, - ) + if asynchronous is True: + self.startThreadingDownloadCtranslate2Weight( + weight_type, + download_ctranslate2.progressBar, + download_ctranslate2.downloaded, + ) + else: + model.downloadCTranslate2ModelWeight(weight_type, download_ctranslate2.progressBar, download_ctranslate2.downloaded) return {"status":200, "result":True} - def downloadWhisperWeight(self, data:str, *args, **kwargs) -> dict: + def downloadWhisperWeight(self, data:str, asynchronous:bool=True, *args, **kwargs) -> dict: weight_type = str(data) download_whisper = self.DownloadWhisper( self.run_mapping, weight_type, self.run ) - self.startThreadingDownloadWhisperWeight( - weight_type, - download_whisper.progressBar, - download_whisper.downloaded, - ) + if asynchronous is True: + self.startThreadingDownloadWhisperWeight( + weight_type, + download_whisper.progressBar, + download_whisper.downloaded, + ) + else: + model.downloadWhisperModelWeight(weight_type, download_whisper.progressBar, download_whisper.downloaded) return {"status":200, "result":True} @staticmethod @@ -1702,7 +1708,7 @@ class Controller: printLog("Download CTranslate2 Model Weight") weight_type = config.CTRANSLATE2_WEIGHT_TYPE if model.checkTranslatorCTranslate2ModelWeight(weight_type) is False: - self.downloadCtranslate2Weight(weight_type) + self.downloadCtranslate2Weight(weight_type, False) # set Translation Engine printLog("Set Translation Engine") @@ -1713,7 +1719,7 @@ class Controller: printLog("Download Whisper Model Weight") weight_type = config.WHISPER_WEIGHT_TYPE if model.checkTranscriptionWhisperModelWeight(weight_type) is False: - self.downloadWhisperWeight(weight_type) + self.downloadWhisperWeight(weight_type, False) # set Transcription Engine printLog("Set Transcription Engine")