👍️[Update] Model : モデルのダウンロード処理をスレッドで進行するように変更

This commit is contained in:
misyaguziya
2024-09-04 14:19:21 +09:00
parent 93ba4a0155
commit 9d7819b186
5 changed files with 21 additions and 12 deletions

View File

@@ -1,4 +1,3 @@
import json
from typing import Callable, Union
from time import sleep
from subprocess import Popen
@@ -701,10 +700,15 @@ class DownloadCTranslate2ProgressBar:
}
})
def startThreadingDownloadCtranslate2Weight(callback:Callable[[float], None]) -> None:
th_download = Thread(target=model.downloadCTranslate2ModelWeight, args=(callback,))
th_download.daemon = True
th_download.start()
def callbackDownloadCtranslate2Weight(data, action, *args, **kwargs) -> dict:
printLog("Download CTranslate2 Weight")
download = DownloadCTranslate2ProgressBar(action)
model.downloadCTranslate2ModelWeight(download.set)
startThreadingDownloadCtranslate2Weight(download.set)
return {"status":200}
def callbackSetDeeplAuthKey(data, *args, **kwargs) -> dict:
@@ -1034,10 +1038,15 @@ class DownloadWhisperProgressBar:
}
})
def startThreadingDownloadWhisperWeight(callback:Callable[[float], None]) -> None:
th_download = Thread(target=model.downloadWhisperModelWeight, args=(callback,))
th_download.daemon = True
th_download.start()
def callbackDownloadWhisperWeight(data, action, *args, **kwargs) -> dict:
printLog("Download Whisper Weight")
download = DownloadCTranslate2ProgressBar(action)
model.downloadWhisperModelWeight(download.set)
download = DownloadWhisperProgressBar(action)
startThreadingDownloadWhisperWeight(download.set)
return {"status":200}
# VR Tab
@@ -1283,8 +1292,7 @@ def init(endpoints:dict, *args, **kwargs) -> None:
if config.USE_TRANSLATION_FEATURE is True and model.checkCTranslatorCTranslate2ModelWeight() is False:
def callback(progress):
printResponse(200, endpoints["ctranslate2"], {"progress":progress})
printLog("Download CTranslate2 Model Weight")
model.downloadCTranslate2ModelWeight(callback)
startThreadingDownloadCtranslate2Weight(callback)
# set Transcription Engine
printLog("Set Transcription Engine")
@@ -1298,7 +1306,7 @@ def init(endpoints:dict, *args, **kwargs) -> None:
if config.USE_WHISPER_FEATURE is True and model.checkTranscriptionWhisperModelWeight() is False:
def callback(progress):
printResponse(200, endpoints["whisper"], {"progress":progress})
model.downloadWhisperModelWeight(callback)
startThreadingDownloadWhisperWeight(callback)
# set word filter
printLog("Set Word Filter")