👍️[Update] Main/Controller : CTranslate2/Whisper model weight download のDownload済みの場合の処理を追加

This commit is contained in:
misyaguziya
2024-08-11 19:02:55 +09:00
parent 633cdf246c
commit 83bf21c20e
3 changed files with 9 additions and 7 deletions

View File

@@ -106,8 +106,8 @@ class Model:
def changeTranslatorCTranslate2Model(self):
self.translator.changeCTranslate2Model(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE)
def downloadCTranslate2ModelWeight(self, callback=None):
return downloadCTranslate2Weight(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE, callback)
def downloadCTranslate2ModelWeight(self, callbackFunc=None):
return downloadCTranslate2Weight(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE, callbackFunc)
def isLoadedCTranslate2Model(self):
return self.translator.isLoadedCTranslate2Model()
@@ -115,8 +115,8 @@ class Model:
def checkTranscriptionWhisperModelWeight(self):
return checkWhisperWeight(config.PATH_LOCAL, config.WHISPER_WEIGHT_TYPE)
def downloadWhisperModelWeight(self, callback=None):
return downloadWhisperWeight(config.PATH_LOCAL, config.WHISPER_WEIGHT_TYPE, callback)
def downloadWhisperModelWeight(self, callbackFunc=None):
return downloadWhisperWeight(config.PATH_LOCAL, config.WHISPER_WEIGHT_TYPE, callbackFunc)
def resetKeywordProcessor(self):
del self.keyword_processor

View File

@@ -65,6 +65,7 @@ def downloadWhisperWeight(root, weight_type, callbackFunc):
path = os_path.join(root, "weights", "whisper", weight_type)
os_makedirs(path, exist_ok=True)
if checkWhisperWeight(root, weight_type) is True:
callbackFunc(1)
return
for filename in _FILENAMES:

View File

@@ -59,13 +59,14 @@ def checkCTranslate2Weight(path, weight_type="Small"):
already_downloaded = True
return already_downloaded
def downloadCTranslate2Weight(root, weight_type="Small", func=None):
def downloadCTranslate2Weight(root, weight_type="Small", callbackFunc=None):
url = ctranslate2_weights[weight_type]["url"]
filename = "weight.zip"
path = os_path.join(root, "weights", "ctranslate2")
os_makedirs(path, exist_ok=True)
if checkCTranslate2Weight(path, weight_type):
callbackFunc(1)
return
try:
@@ -76,9 +77,9 @@ def downloadCTranslate2Weight(root, weight_type="Small", func=None):
with open(os_path.join(tmp_path, filename), 'wb') as file:
for chunk in res.iter_content(chunk_size=1024*5):
file.write(chunk)
if isinstance(func, Callable):
if isinstance(callbackFunc, Callable):
total_chunk += len(chunk)
func(total_chunk/file_size)
callbackFunc(total_chunk/file_size)
with ZipFile(os_path.join(tmp_path, filename)) as zf:
zf.extractall(path)