👍️[Update] Model : AI モデルのダウンロード方法を修正
- AIモデルのダウンロード済み確認辞書を追加 - selectable_ctranslate2_weight_type_dict - selectable_whisper_weight_type_dict - AIモデルダウンロード処理完了のエンドポイントを追加 - /run/download_ctranslate2_weight - /run/downloaded_whisper_weight
This commit is contained in:
@@ -134,21 +134,55 @@ class Controller:
|
||||
energy,
|
||||
)
|
||||
|
||||
def downloadCTranslate2ProgressBar(self, progress) -> None:
|
||||
printLog("CTranslate2 Weight Download Progress", progress)
|
||||
self.run(
|
||||
200,
|
||||
self.run_mapping["download_ctranslate2"],
|
||||
progress,
|
||||
)
|
||||
class DownloadCTranslate2:
|
||||
def __init__(self, run_mapping:dict, weight_type:str, run:Callable[[int, str, Any], None]) -> None:
|
||||
self.run_mapping = run_mapping
|
||||
self.weight_type = weight_type
|
||||
self.run = run
|
||||
|
||||
def downloadWhisperProgressBar(self, progress) -> None:
|
||||
printLog("Whisper Weight Download Progress", progress)
|
||||
self.run(
|
||||
200,
|
||||
self.run_mapping["download_whisper"],
|
||||
progress,
|
||||
)
|
||||
def progressBar(self, progress) -> None:
|
||||
printLog("CTranslate2 Weight Download Progress", progress)
|
||||
self.run(
|
||||
200,
|
||||
self.run_mapping["download_ctranslate2_weight"],
|
||||
{"weight_type": self.weight_type, "progress": progress},
|
||||
)
|
||||
|
||||
def downloaded(self) -> None:
|
||||
weight_type_dict = config.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT
|
||||
weight_type_dict["self.weight_type"] = True
|
||||
config.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT = weight_type_dict
|
||||
|
||||
self.run(
|
||||
200,
|
||||
self.run_mapping["downloaded_ctranslate2_weight"],
|
||||
self.weight_type,
|
||||
)
|
||||
|
||||
class DownloadWhisper:
|
||||
def __init__(self, run_mapping:dict, weight_type:str, run:Callable[[int, str, Any], None]) -> None:
|
||||
self.run_mapping = run_mapping
|
||||
self.weight_type = weight_type
|
||||
self.run = run
|
||||
|
||||
def progressBar(self, progress) -> None:
|
||||
printLog("Whisper Weight Download Progress", progress)
|
||||
self.run(
|
||||
200,
|
||||
self.run_mapping["download_whisper_weight"],
|
||||
{"weight_type": self.weight_type, "progress": progress},
|
||||
)
|
||||
|
||||
def downloaded(self) -> None:
|
||||
weight_type_dict = config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT
|
||||
weight_type_dict[self.weight_type] = True
|
||||
config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT = weight_type_dict
|
||||
|
||||
self.run(
|
||||
200,
|
||||
self.run_mapping["downloaded_whisper_weight"],
|
||||
self.weight_type,
|
||||
)
|
||||
|
||||
def micMessage(self, message: Union[str, bool]) -> None:
|
||||
if isinstance(message, bool) and message is False:
|
||||
@@ -397,8 +431,8 @@ class Controller:
|
||||
return {"status":200,"result":config.SELECTED_TRANSLATION_COMPUTE_DEVICE}
|
||||
|
||||
@staticmethod
|
||||
def getSelectableCtranslate2WeightTypeList(*args, **kwargs) -> dict:
|
||||
return {"status":200, "result":config.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_LIST}
|
||||
def getSelectableCtranslate2WeightTypeDict(*args, **kwargs) -> dict:
|
||||
return {"status":200, "result":config.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT}
|
||||
|
||||
@staticmethod
|
||||
def getSelectedTranscriptionComputeDevice(*args, **kwargs) -> dict:
|
||||
@@ -411,8 +445,8 @@ class Controller:
|
||||
return {"status":200,"result":config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE}
|
||||
|
||||
@staticmethod
|
||||
def getSelectableWhisperWeightTypeList(*args, **kwargs) -> dict:
|
||||
return {"status":200, "result":config.SELECTABLE_WHISPER_WEIGHT_TYPE_LIST}
|
||||
def getSelectableWhisperWeightTypeDict(*args, **kwargs) -> dict:
|
||||
return {"status":200, "result":config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT}
|
||||
|
||||
@staticmethod
|
||||
def getMaxMicThreshold(*args, **kwargs) -> dict:
|
||||
@@ -511,6 +545,24 @@ class Controller:
|
||||
def getSelectedTranscriptionEngine(*args, **kwargs) -> dict:
|
||||
return {"status":200, "result":config.SELECTED_TRANSCRIPTION_ENGINE}
|
||||
|
||||
@staticmethod
|
||||
def setSelectedTranscriptionEngine(data, *args, **kwargs) -> dict:
|
||||
engine = data["engine"]
|
||||
weight_type = data["weight_type"]
|
||||
if engine == "Whisper" and config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT[weight_type] is False:
|
||||
config.SELECTED_TRANSCRIPTION_ENGINE = "Google"
|
||||
config.WHISPER_WEIGHT_TYPE = None
|
||||
else:
|
||||
config.SELECTED_TRANSCRIPTION_ENGINE = engine
|
||||
config.WHISPER_WEIGHT_TYPE = weight_type
|
||||
return {
|
||||
"status":200,
|
||||
"result":{
|
||||
"engine": config.SELECTED_TRANSCRIPTION_ENGINE,
|
||||
"weight_type": config.WHISPER_WEIGHT_TYPE,
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def getMultiLanguageTranslation(*args, **kwargs) -> dict:
|
||||
return {"status":200, "result":config.MULTI_LANGUAGE_TRANSLATION}
|
||||
@@ -1079,7 +1131,7 @@ class Controller:
|
||||
@staticmethod
|
||||
def setEnableUseTranslationFeature(*args, **kwargs) -> dict:
|
||||
config.USE_TRANSLATION_FEATURE = True
|
||||
if model.checkCTranslatorCTranslate2ModelWeight():
|
||||
if model.checkTranslatorCTranslate2ModelWeight(config.CTRANSLATE2_WEIGHT_TYPE):
|
||||
def callback():
|
||||
model.changeTranslatorCTranslate2Model()
|
||||
th_callback = Thread(target=callback)
|
||||
@@ -1099,7 +1151,7 @@ class Controller:
|
||||
@staticmethod
|
||||
def setCtranslate2WeightType(data, *args, **kwargs) -> dict:
|
||||
config.CTRANSLATE2_WEIGHT_TYPE = str(data)
|
||||
if model.checkCTranslatorCTranslate2ModelWeight():
|
||||
if model.checkTranslatorCTranslate2ModelWeight(config.CTRANSLATE2_WEIGHT_TYPE):
|
||||
def callback():
|
||||
model.changeTranslatorCTranslate2Model()
|
||||
th_callback = Thread(target=callback)
|
||||
@@ -1114,16 +1166,7 @@ class Controller:
|
||||
@staticmethod
|
||||
def setWhisperWeightType(data, *args, **kwargs) -> dict:
|
||||
config.WHISPER_WEIGHT_TYPE = str(data)
|
||||
if model.checkTranscriptionWhisperModelWeight() is True:
|
||||
config.SELECTED_TRANSCRIPTION_ENGINE = "Whisper"
|
||||
else:
|
||||
config.SELECTED_TRANSCRIPTION_ENGINE = "Google"
|
||||
return {"status":200,
|
||||
"result":{
|
||||
"weight_type":config.WHISPER_WEIGHT_TYPE,
|
||||
"transcription_engine":config.SELECTED_TRANSCRIPTION_ENGINE,
|
||||
}
|
||||
}
|
||||
return {"status":200, "result": config.WHISPER_WEIGHT_TYPE}
|
||||
|
||||
@staticmethod
|
||||
def getAutoClearMessageBox(*args, **kwargs) -> dict:
|
||||
@@ -1410,12 +1453,33 @@ class Controller:
|
||||
th_start_update_software.start()
|
||||
return {"status":200, "result":True}
|
||||
|
||||
def downloadCtranslate2Weight(self, *args, **kwargs) -> dict:
|
||||
self.startThreadingDownloadCtranslate2Weight(self.downloadCTranslate2ProgressBar)
|
||||
def downloadCtranslate2Weight(self, data:str, *args, **kwargs) -> dict:
|
||||
weight_type = str(data)
|
||||
download_ctranslate2 = self.DownloadCTranslate2(
|
||||
self.run_mapping,
|
||||
weight_type,
|
||||
self.run
|
||||
)
|
||||
|
||||
self.startThreadingDownloadCtranslate2Weight(
|
||||
weight_type,
|
||||
download_ctranslate2.progressBar,
|
||||
download_ctranslate2.downloaded,
|
||||
)
|
||||
return {"status":200, "result":True}
|
||||
|
||||
def downloadWhisperWeight(self, *args, **kwargs) -> dict:
|
||||
self.startThreadingDownloadWhisperWeight(self.downloadWhisperProgressBar)
|
||||
def downloadWhisperWeight(self, data:str, *args, **kwargs) -> dict:
|
||||
weight_type = str(data)
|
||||
download_whisper = self.DownloadWhisper(
|
||||
self.run_mapping,
|
||||
weight_type,
|
||||
self.run
|
||||
)
|
||||
self.startThreadingDownloadWhisperWeight(
|
||||
weight_type,
|
||||
download_whisper.progressBar,
|
||||
download_whisper.downloaded,
|
||||
)
|
||||
return {"status":200, "result":True}
|
||||
|
||||
@staticmethod
|
||||
@@ -1524,14 +1588,35 @@ class Controller:
|
||||
cleaned_text = re.sub(pattern, r'\1', text)
|
||||
return cleaned_text
|
||||
|
||||
def updateDownloadedCTranslate2ModelWeight(self) -> None:
|
||||
weight_type_dict = config.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT
|
||||
for weight_type in weight_type_dict.keys():
|
||||
weight_type_dict[weight_type] = model.checkTranslatorCTranslate2ModelWeight(weight_type)
|
||||
config.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT = weight_type_dict
|
||||
|
||||
def updateTranslationEngineAndEngineList(self):
|
||||
engine = config.SELECTED_TRANSLATION_ENGINES[config.SELECTED_TAB_NO]
|
||||
engines = self.getTranslationEngines()["result"]
|
||||
if engine not in engines:
|
||||
engines = config.SELECTED_TRANSLATION_ENGINES
|
||||
engine = engines[config.SELECTED_TAB_NO]
|
||||
selectable_engines = self.getTranslationEngines()["result"]
|
||||
if engine not in selectable_engines:
|
||||
engine = "CTranslate2"
|
||||
config.SELECTED_TRANSLATION_ENGINES[config.SELECTED_TAB_NO] = engine
|
||||
engines[config.SELECTED_TAB_NO] = engine
|
||||
config.SELECTED_TRANSLATION_ENGINES = engines
|
||||
|
||||
self.run(200, self.run_mapping["selected_translation_engines"], config.SELECTED_TRANSLATION_ENGINES)
|
||||
self.run(200, self.run_mapping["translation_engines"], engines)
|
||||
self.run(200, self.run_mapping["translation_engines"], selectable_engines)
|
||||
|
||||
def updateDownloadedWhisperModelWeight(self) -> None:
|
||||
weight_type_dict = config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT
|
||||
for weight_type in weight_type_dict.keys():
|
||||
weight_type_dict[weight_type] = model.checkTranscriptionWhisperModelWeight(weight_type)
|
||||
config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT = weight_type_dict
|
||||
|
||||
def updateTranscriptionEngine(self):
|
||||
weight_type_dict = config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT
|
||||
weight_type = config.WHISPER_WEIGHT_TYPE
|
||||
if config.SELECTED_TRANSCRIPTION_ENGINE == "Whisper" and weight_type_dict[weight_type] is False:
|
||||
config.SELECTED_TRANSCRIPTION_ENGINE = "Google"
|
||||
|
||||
def startCheckMicEnergy(self) -> None:
|
||||
while self.device_access_status is False:
|
||||
@@ -1576,14 +1661,14 @@ class Controller:
|
||||
th_stopCheckSpeakerEnergy.join()
|
||||
|
||||
@staticmethod
|
||||
def startThreadingDownloadCtranslate2Weight(callback:Callable[[float], None]) -> None:
|
||||
th_download = Thread(target=model.downloadCTranslate2ModelWeight, args=(callback,))
|
||||
def startThreadingDownloadCtranslate2Weight(weight_type:str, callback:Callable[[float], None], end_callback:Callable[[float], None]) -> None:
|
||||
th_download = Thread(target=model.downloadCTranslate2ModelWeight, args=(weight_type, callback, end_callback))
|
||||
th_download.daemon = True
|
||||
th_download.start()
|
||||
|
||||
@staticmethod
|
||||
def startThreadingDownloadWhisperWeight(callback:Callable[[float], None]) -> None:
|
||||
th_download = Thread(target=model.downloadWhisperModelWeight, args=(callback,))
|
||||
def startThreadingDownloadWhisperWeight(weight_type:str, callback:Callable[[float], None], end_callback:Callable[[float], None]) -> None:
|
||||
th_download = Thread(target=model.downloadWhisperModelWeight, args=(weight_type, callback, end_callback))
|
||||
th_download.daemon = True
|
||||
th_download.start()
|
||||
|
||||
@@ -1619,11 +1704,23 @@ class Controller:
|
||||
|
||||
# set Translation Engine
|
||||
printLog("Set Translation Engine")
|
||||
self.updateDownloadedCTranslate2ModelWeight()
|
||||
self.updateTranslationEngineAndEngineList()
|
||||
|
||||
# download CTranslate2 Model Weight
|
||||
printLog("Download CTranslate2 Model Weight")
|
||||
if config.USE_TRANSLATION_FEATURE is True and model.checkTranslatorCTranslate2ModelWeight(config.CTRANSLATE2_WEIGHT_TYPE) is False:
|
||||
self.downloadCtranslate2Weight(config.CTRANSLATE2_WEIGHT_TYPE)
|
||||
|
||||
# set Transcription Engine
|
||||
# printLog("Set Transcription Engine")
|
||||
# self.updateTranscriptionEngineAndEngineList()
|
||||
printLog("Set Transcription Engine")
|
||||
self.updateDownloadedWhisperModelWeight()
|
||||
self.updateTranscriptionEngine()
|
||||
|
||||
# download Whisper Model Weight
|
||||
printLog("Download Whisper Model Weight")
|
||||
if model.checkTranscriptionWhisperModelWeight(config.WHISPER_WEIGHT_TYPE) is False:
|
||||
self.downloadWhisperWeight(config.WHISPER_WEIGHT_TYPE)
|
||||
|
||||
# set word filter
|
||||
printLog("Set Word Filter")
|
||||
|
||||
Reference in New Issue
Block a user