diff --git a/src-python/model.py b/src-python/model.py index 0ef25d52..d446b2c1 100644 --- a/src-python/model.py +++ b/src-python/model.py @@ -106,8 +106,8 @@ class Model: def changeTranslatorCTranslate2Model(self): self.translator.changeCTranslate2Model(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE) - def downloadCTranslate2ModelWeight(self, callback=None): - return downloadCTranslate2Weight(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE, callback) + def downloadCTranslate2ModelWeight(self, callbackFunc=None): + return downloadCTranslate2Weight(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE, callbackFunc) def isLoadedCTranslate2Model(self): return self.translator.isLoadedCTranslate2Model() @@ -115,8 +115,8 @@ class Model: def checkTranscriptionWhisperModelWeight(self): return checkWhisperWeight(config.PATH_LOCAL, config.WHISPER_WEIGHT_TYPE) - def downloadWhisperModelWeight(self, callback=None): - return downloadWhisperWeight(config.PATH_LOCAL, config.WHISPER_WEIGHT_TYPE, callback) + def downloadWhisperModelWeight(self, callbackFunc=None): + return downloadWhisperWeight(config.PATH_LOCAL, config.WHISPER_WEIGHT_TYPE, callbackFunc) def resetKeywordProcessor(self): del self.keyword_processor diff --git a/src-python/models/transcription/transcription_whisper.py b/src-python/models/transcription/transcription_whisper.py index 8334f952..bb18d45e 100644 --- a/src-python/models/transcription/transcription_whisper.py +++ b/src-python/models/transcription/transcription_whisper.py @@ -65,6 +65,7 @@ def downloadWhisperWeight(root, weight_type, callbackFunc): path = os_path.join(root, "weights", "whisper", weight_type) os_makedirs(path, exist_ok=True) if checkWhisperWeight(root, weight_type) is True: + callbackFunc(1) return for filename in _FILENAMES: diff --git a/src-python/models/translation/translation_utils.py b/src-python/models/translation/translation_utils.py index 4d273add..13b52571 100644 --- a/src-python/models/translation/translation_utils.py +++ b/src-python/models/translation/translation_utils.py @@ -59,13 +59,14 @@ def checkCTranslate2Weight(path, weight_type="Small"): already_downloaded = True return already_downloaded -def downloadCTranslate2Weight(root, weight_type="Small", func=None): +def downloadCTranslate2Weight(root, weight_type="Small", callbackFunc=None): url = ctranslate2_weights[weight_type]["url"] filename = "weight.zip" path = os_path.join(root, "weights", "ctranslate2") os_makedirs(path, exist_ok=True) if checkCTranslate2Weight(path, weight_type): + callbackFunc(1) return try: @@ -76,9 +77,9 @@ def downloadCTranslate2Weight(root, weight_type="Small", func=None): with open(os_path.join(tmp_path, filename), 'wb') as file: for chunk in res.iter_content(chunk_size=1024*5): file.write(chunk) - if isinstance(func, Callable): + if isinstance(callbackFunc, Callable): total_chunk += len(chunk) - func(total_chunk/file_size) + callbackFunc(total_chunk/file_size) with ZipFile(os_path.join(tmp_path, filename)) as zf: zf.extractall(path)