👍️[Update] Main/Controller : Add CTranslate2/Whisper model weight download endpoints

This commit is contained in:
misyaguziya
2024-08-11 13:50:51 +09:00
parent c7ed586bb3
commit 3924b2eb3b
3 changed files with 199 additions and 129 deletions

View File

@@ -24,8 +24,8 @@ from models.transcription.transcription_transcriber import AudioTranscriber
from models.xsoverlay.notification import xsoverlayForVRCT from models.xsoverlay.notification import xsoverlayForVRCT
from models.translation.translation_languages import translation_lang from models.translation.translation_languages import translation_lang
from models.transcription.transcription_languages import transcription_lang from models.transcription.transcription_languages import transcription_lang
from models.translation.translation_utils import checkCTranslate2Weight from models.translation.translation_utils import checkCTranslate2Weight, downloadCTranslate2Weight
from models.transcription.transcription_whisper import checkWhisperWeight from models.transcription.transcription_whisper import checkWhisperWeight, downloadWhisperWeight
from models.overlay.overlay import Overlay from models.overlay.overlay import Overlay
from models.overlay.overlay_image import OverlayImage from models.overlay.overlay_image import OverlayImage
@@ -106,12 +106,18 @@ class Model:
def changeTranslatorCTranslate2Model(self): def changeTranslatorCTranslate2Model(self):
self.translator.changeCTranslate2Model(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE) self.translator.changeCTranslate2Model(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE)
def downloadCTranslate2ModelWeight(self, callback=None):
return downloadCTranslate2Weight(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE, callback)
def isLoadedCTranslate2Model(self): def isLoadedCTranslate2Model(self):
return self.translator.isLoadedCTranslate2Model() return self.translator.isLoadedCTranslate2Model()
def checkTranscriptionWhisperModelWeight(self): def checkTranscriptionWhisperModelWeight(self):
return checkWhisperWeight(config.PATH_LOCAL, config.WHISPER_WEIGHT_TYPE) return checkWhisperWeight(config.PATH_LOCAL, config.WHISPER_WEIGHT_TYPE)
def downloadWhisperModelWeight(self, callback=None):
return downloadWhisperWeight(config.PATH_LOCAL, config.WHISPER_WEIGHT_TYPE, callback)
def resetKeywordProcessor(self): def resetKeywordProcessor(self):
del self.keyword_processor del self.keyword_processor
self.keyword_processor = KeywordProcessor() self.keyword_processor = KeywordProcessor()

View File

@@ -10,7 +10,7 @@ from utils import getKeyByValue, isUniqueStrings, strPctToInt
import argparse import argparse
# Common # Common
class DownloadProgressBar: class DownloadSoftwareProgressBar:
def __init__(self, action): def __init__(self, action):
self.action = action self.action = action
@@ -23,7 +23,7 @@ class DownloadProgressBar:
} }
}) })
class UpdateProgressBar: class UpdateSoftwareProgressBar:
def __init__(self, action): def __init__(self, action):
self.action = action self.action = action
@@ -38,8 +38,8 @@ class UpdateProgressBar:
def callbackUpdateSoftware(data, action, *args, **kwargs) -> dict: def callbackUpdateSoftware(data, action, *args, **kwargs) -> dict:
print(json.dumps({"status":348, "log": "callbackUpdateSoftware"}), flush=True) print(json.dumps({"status":348, "log": "callbackUpdateSoftware"}), flush=True)
download = DownloadProgressBar(action) download = DownloadSoftwareProgressBar(action)
update = UpdateProgressBar(action) update = UpdateSoftwareProgressBar(action)
model.updateSoftware(restart=True, download=download.set, update=update.set) model.updateSoftware(restart=True, download=download.set, update=update.set)
return {"status":200} return {"status":200}
@@ -707,6 +707,25 @@ def callbackSetCtranslate2WeightType(data, *args, **kwargs) -> dict:
}, },
} }
class DownloadCTranslate2ProgressBar:
def __init__(self, action):
self.action = action
def set(self, progress) -> None:
print(json.dumps({"status":348, "log": "CTranslate2 Weight Download Progress", "data":progress}), flush=True)
self.action("download", {
"status":200,
"result":{
"progress":progress
}
})
def callbackDownloadCtranslate2Weight(data, action, *args, **kwargs) -> dict:
print(json.dumps({"status":348, "log": "callbackDownloadCtranslate2Weight"}), flush=True)
download = DownloadCTranslate2ProgressBar(action)
model.downloadCTranslate2ModelWeight(download.set)
return {"status":200}
def callbackSetDeeplAuthKey(data, *args, **kwargs) -> dict: def callbackSetDeeplAuthKey(data, *args, **kwargs) -> dict:
print(json.dumps({"status":348, "log": "callbackSetDeeplAuthKey", "data":data}), flush=True) print(json.dumps({"status":348, "log": "callbackSetDeeplAuthKey", "data":data}), flush=True)
status = 400 status = 400
@@ -873,7 +892,7 @@ def callbackDeleteMicWordFilter(data, *args, **kwargs) -> dict:
model.resetKeywordProcessor() model.resetKeywordProcessor()
model.addKeywords() model.addKeywords()
except Exception: except Exception:
print("There was no the target word in config.INPUT_MIC_WORD_FILTER") print(json.dumps({"status":348, "log": "callbackDeleteMicWordFilter There was no the target word in config.INPUT_MIC_WORD_FILTER"}), flush=True)
return {"status":200, "result":config.INPUT_MIC_WORD_FILTER} return {"status":200, "result":config.INPUT_MIC_WORD_FILTER}
# Transcription (Speaker) # Transcription (Speaker)
@@ -1025,6 +1044,25 @@ def callbackSetWhisperWeightType(data, *args, **kwargs) -> dict:
} }
} }
class DownloadWhisperProgressBar:
def __init__(self, action):
self.action = action
def set(self, progress) -> None:
print(json.dumps({"status":348, "log": "Whisper Weight Download Progress", "data":progress}), flush=True)
self.action("download", {
"status":200,
"result":{
"progress":progress
}
})
def callbackDownloadWhisperWeight(data, action, *args, **kwargs) -> dict:
print(json.dumps({"status":348, "log": "callbackDownloadWhisperWeight"}), flush=True)
download = DownloadCTranslate2ProgressBar(action)
model.downloadWhisperModelWeight(download.set)
return {"status":200}
# VR Tab # VR Tab
def callbackSetOverlaySettingsOpacity(data, *args, **kwargs) -> dict: def callbackSetOverlaySettingsOpacity(data, *args, **kwargs) -> dict:
print(json.dumps({"status":348, "log": "callbackSetOverlaySettingsOpacity", "data":data}), flush=True) print(json.dumps({"status":348, "log": "callbackSetOverlaySettingsOpacity", "data":data}), flush=True)

View File

@@ -134,6 +134,7 @@ controller_mapping = {
"/controller/callback_enable_use_translation_feature": controller.callbackEnableUseTranslationFeature, "/controller/callback_enable_use_translation_feature": controller.callbackEnableUseTranslationFeature,
"/controller/callback_disable_use_translation_feature": controller.callbackDisableUseTranslationFeature, "/controller/callback_disable_use_translation_feature": controller.callbackDisableUseTranslationFeature,
"/controller/callback_set_ctranslate2_weight_type": controller.callbackSetCtranslate2WeightType, "/controller/callback_set_ctranslate2_weight_type": controller.callbackSetCtranslate2WeightType,
"/controller/callback_download_ctranslate2_weight": controller.callbackDownloadCtranslate2Weight,
"/controller/callback_set_deepl_auth_key": controller.callbackSetDeeplAuthKey, "/controller/callback_set_deepl_auth_key": controller.callbackSetDeeplAuthKey,
"/controller/callback_clear_deepl_auth_key": controller.callbackClearDeeplAuthKey, "/controller/callback_clear_deepl_auth_key": controller.callbackClearDeeplAuthKey,
"/controller/callback_set_mic_host": controller.callbackSetMicHost, "/controller/callback_set_mic_host": controller.callbackSetMicHost,
@@ -160,6 +161,7 @@ controller_mapping = {
"/controller/callback_enable_use_whisper_feature": controller.callbackEnableUseWhisperFeature, "/controller/callback_enable_use_whisper_feature": controller.callbackEnableUseWhisperFeature,
"/controller/callback_disable_use_whisper_feature": controller.callbackDisableUseWhisperFeature, "/controller/callback_disable_use_whisper_feature": controller.callbackDisableUseWhisperFeature,
"/controller/callback_set_whisper_weight_type": controller.callbackSetWhisperWeightType, "/controller/callback_set_whisper_weight_type": controller.callbackSetWhisperWeightType,
"/controller/callback_download_whisper_weight": controller.callbackDownloadWhisperWeight,
"/controller/callback_set_overlay_settings_opacity": controller.callbackSetOverlaySettingsOpacity, "/controller/callback_set_overlay_settings_opacity": controller.callbackSetOverlaySettingsOpacity,
"/controller/callback_set_overlay_settings_ui_scaling": controller.callbackSetOverlaySettingsUiScaling, "/controller/callback_set_overlay_settings_ui_scaling": controller.callbackSetOverlaySettingsUiScaling,
"/controller/callback_enable_overlay_small_log": controller.callbackEnableOverlaySmallLog, "/controller/callback_enable_overlay_small_log": controller.callbackEnableOverlaySmallLog,
@@ -237,6 +239,12 @@ action_mapping = {
"/controller/callback_messagebox_press_key_enter": { "/controller/callback_messagebox_press_key_enter": {
"error_translation_engine":"/action/error_translation_engine" "error_translation_engine":"/action/error_translation_engine"
}, },
"/controller/callback_download_ctranslate2_weight": {
"download":"/action/download_ctranslate2_weight"
},
"/controller/callback_download_whisper_weight": {
"download":"/action/download_whisper_weight"
},
} }
def handleConfigRequest(endpoint): def handleConfigRequest(endpoint):
@@ -318,11 +326,39 @@ def main():
print(response, flush=True) print(response, flush=True)
if __name__ == "__main__": if __name__ == "__main__":
response_test = False process = "main"
if response_test: match process:
case "main":
try:
controller.init()
print(json.dumps({"status":348, "log": "Initialization from Python."}), flush=True)
while True:
main()
except Exception:
import traceback
with open('error.log', 'a') as f:
traceback.print_exc(file=f)
case "test":
controller.init()
response_data, status = handleControllerRequest("/controller/callback_download_ctranslate2_weight")
response = {
"status": status,
"endpoint": "/controller/callback_download_ctranslate2_weight",
"result": response_data,
}
response = json.dumps(response)
response_data, status = handleControllerRequest("/controller/callback_download_whisper_weight")
response = {
"status": status,
"endpoint": "/controller/callback_download_whisper_weight",
"result": response_data,
}
response = json.dumps(response)
case "test_all":
import time import time
controller.init() controller.init()
for endpoint, value in config_mapping.items(): for endpoint, value in config_mapping.items():
response_data, status = handleConfigRequest(endpoint) response_data, status = handleConfigRequest(endpoint)
response = { response = {
@@ -432,13 +468,3 @@ if __name__ == "__main__":
response = json.dumps(response) response = json.dumps(response)
print(response, flush=True) print(response, flush=True)
time.sleep(0.5) time.sleep(0.5)
else:
try:
controller.init()
print(json.dumps({"status":348, "log": "Initialization from Python."}), flush=True)
while True:
main()
except Exception:
import traceback
with open('error.log', 'a') as f:
traceback.print_exc(file=f)