👍️[Update] Main/Controller : Add CTranslate2/Whisper model weight download endpoints

This commit is contained in:
misyaguziya
2024-08-11 13:50:51 +09:00
parent c7ed586bb3
commit 3924b2eb3b
3 changed files with 199 additions and 129 deletions

View File

@@ -24,8 +24,8 @@ from models.transcription.transcription_transcriber import AudioTranscriber
from models.xsoverlay.notification import xsoverlayForVRCT from models.xsoverlay.notification import xsoverlayForVRCT
from models.translation.translation_languages import translation_lang from models.translation.translation_languages import translation_lang
from models.transcription.transcription_languages import transcription_lang from models.transcription.transcription_languages import transcription_lang
from models.translation.translation_utils import checkCTranslate2Weight from models.translation.translation_utils import checkCTranslate2Weight, downloadCTranslate2Weight
from models.transcription.transcription_whisper import checkWhisperWeight from models.transcription.transcription_whisper import checkWhisperWeight, downloadWhisperWeight
from models.overlay.overlay import Overlay from models.overlay.overlay import Overlay
from models.overlay.overlay_image import OverlayImage from models.overlay.overlay_image import OverlayImage
@@ -106,12 +106,18 @@ class Model:
def changeTranslatorCTranslate2Model(self): def changeTranslatorCTranslate2Model(self):
self.translator.changeCTranslate2Model(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE) self.translator.changeCTranslate2Model(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE)
def downloadCTranslate2ModelWeight(self, callback=None):
return downloadCTranslate2Weight(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE, callback)
def isLoadedCTranslate2Model(self): def isLoadedCTranslate2Model(self):
return self.translator.isLoadedCTranslate2Model() return self.translator.isLoadedCTranslate2Model()
def checkTranscriptionWhisperModelWeight(self): def checkTranscriptionWhisperModelWeight(self):
return checkWhisperWeight(config.PATH_LOCAL, config.WHISPER_WEIGHT_TYPE) return checkWhisperWeight(config.PATH_LOCAL, config.WHISPER_WEIGHT_TYPE)
def downloadWhisperModelWeight(self, callback=None):
return downloadWhisperWeight(config.PATH_LOCAL, config.WHISPER_WEIGHT_TYPE, callback)
def resetKeywordProcessor(self): def resetKeywordProcessor(self):
del self.keyword_processor del self.keyword_processor
self.keyword_processor = KeywordProcessor() self.keyword_processor = KeywordProcessor()

View File

@@ -10,7 +10,7 @@ from utils import getKeyByValue, isUniqueStrings, strPctToInt
import argparse import argparse
# Common # Common
class DownloadProgressBar: class DownloadSoftwareProgressBar:
def __init__(self, action): def __init__(self, action):
self.action = action self.action = action
@@ -23,7 +23,7 @@ class DownloadProgressBar:
} }
}) })
class UpdateProgressBar: class UpdateSoftwareProgressBar:
def __init__(self, action): def __init__(self, action):
self.action = action self.action = action
@@ -38,8 +38,8 @@ class UpdateProgressBar:
def callbackUpdateSoftware(data, action, *args, **kwargs) -> dict: def callbackUpdateSoftware(data, action, *args, **kwargs) -> dict:
print(json.dumps({"status":348, "log": "callbackUpdateSoftware"}), flush=True) print(json.dumps({"status":348, "log": "callbackUpdateSoftware"}), flush=True)
download = DownloadProgressBar(action) download = DownloadSoftwareProgressBar(action)
update = UpdateProgressBar(action) update = UpdateSoftwareProgressBar(action)
model.updateSoftware(restart=True, download=download.set, update=update.set) model.updateSoftware(restart=True, download=download.set, update=update.set)
return {"status":200} return {"status":200}
@@ -707,6 +707,25 @@ def callbackSetCtranslate2WeightType(data, *args, **kwargs) -> dict:
}, },
} }
class DownloadCTranslate2ProgressBar:
def __init__(self, action):
self.action = action
def set(self, progress) -> None:
print(json.dumps({"status":348, "log": "CTranslate2 Weight Download Progress", "data":progress}), flush=True)
self.action("download", {
"status":200,
"result":{
"progress":progress
}
})
def callbackDownloadCtranslate2Weight(data, action, *args, **kwargs) -> dict:
print(json.dumps({"status":348, "log": "callbackDownloadCtranslate2Weight"}), flush=True)
download = DownloadCTranslate2ProgressBar(action)
model.downloadCTranslate2ModelWeight(download.set)
return {"status":200}
def callbackSetDeeplAuthKey(data, *args, **kwargs) -> dict: def callbackSetDeeplAuthKey(data, *args, **kwargs) -> dict:
print(json.dumps({"status":348, "log": "callbackSetDeeplAuthKey", "data":data}), flush=True) print(json.dumps({"status":348, "log": "callbackSetDeeplAuthKey", "data":data}), flush=True)
status = 400 status = 400
@@ -873,7 +892,7 @@ def callbackDeleteMicWordFilter(data, *args, **kwargs) -> dict:
model.resetKeywordProcessor() model.resetKeywordProcessor()
model.addKeywords() model.addKeywords()
except Exception: except Exception:
print("There was no the target word in config.INPUT_MIC_WORD_FILTER") print(json.dumps({"status":348, "log": "callbackDeleteMicWordFilter There was no the target word in config.INPUT_MIC_WORD_FILTER"}), flush=True)
return {"status":200, "result":config.INPUT_MIC_WORD_FILTER} return {"status":200, "result":config.INPUT_MIC_WORD_FILTER}
# Transcription (Speaker) # Transcription (Speaker)
@@ -1025,6 +1044,25 @@ def callbackSetWhisperWeightType(data, *args, **kwargs) -> dict:
} }
} }
class DownloadWhisperProgressBar:
def __init__(self, action):
self.action = action
def set(self, progress) -> None:
print(json.dumps({"status":348, "log": "Whisper Weight Download Progress", "data":progress}), flush=True)
self.action("download", {
"status":200,
"result":{
"progress":progress
}
})
def callbackDownloadWhisperWeight(data, action, *args, **kwargs) -> dict:
print(json.dumps({"status":348, "log": "callbackDownloadWhisperWeight"}), flush=True)
download = DownloadCTranslate2ProgressBar(action)
model.downloadWhisperModelWeight(download.set)
return {"status":200}
# VR Tab # VR Tab
def callbackSetOverlaySettingsOpacity(data, *args, **kwargs) -> dict: def callbackSetOverlaySettingsOpacity(data, *args, **kwargs) -> dict:
print(json.dumps({"status":348, "log": "callbackSetOverlaySettingsOpacity", "data":data}), flush=True) print(json.dumps({"status":348, "log": "callbackSetOverlaySettingsOpacity", "data":data}), flush=True)

View File

@@ -134,6 +134,7 @@ controller_mapping = {
"/controller/callback_enable_use_translation_feature": controller.callbackEnableUseTranslationFeature, "/controller/callback_enable_use_translation_feature": controller.callbackEnableUseTranslationFeature,
"/controller/callback_disable_use_translation_feature": controller.callbackDisableUseTranslationFeature, "/controller/callback_disable_use_translation_feature": controller.callbackDisableUseTranslationFeature,
"/controller/callback_set_ctranslate2_weight_type": controller.callbackSetCtranslate2WeightType, "/controller/callback_set_ctranslate2_weight_type": controller.callbackSetCtranslate2WeightType,
"/controller/callback_download_ctranslate2_weight": controller.callbackDownloadCtranslate2Weight,
"/controller/callback_set_deepl_auth_key": controller.callbackSetDeeplAuthKey, "/controller/callback_set_deepl_auth_key": controller.callbackSetDeeplAuthKey,
"/controller/callback_clear_deepl_auth_key": controller.callbackClearDeeplAuthKey, "/controller/callback_clear_deepl_auth_key": controller.callbackClearDeeplAuthKey,
"/controller/callback_set_mic_host": controller.callbackSetMicHost, "/controller/callback_set_mic_host": controller.callbackSetMicHost,
@@ -160,6 +161,7 @@ controller_mapping = {
"/controller/callback_enable_use_whisper_feature": controller.callbackEnableUseWhisperFeature, "/controller/callback_enable_use_whisper_feature": controller.callbackEnableUseWhisperFeature,
"/controller/callback_disable_use_whisper_feature": controller.callbackDisableUseWhisperFeature, "/controller/callback_disable_use_whisper_feature": controller.callbackDisableUseWhisperFeature,
"/controller/callback_set_whisper_weight_type": controller.callbackSetWhisperWeightType, "/controller/callback_set_whisper_weight_type": controller.callbackSetWhisperWeightType,
"/controller/callback_download_whisper_weight": controller.callbackDownloadWhisperWeight,
"/controller/callback_set_overlay_settings_opacity": controller.callbackSetOverlaySettingsOpacity, "/controller/callback_set_overlay_settings_opacity": controller.callbackSetOverlaySettingsOpacity,
"/controller/callback_set_overlay_settings_ui_scaling": controller.callbackSetOverlaySettingsUiScaling, "/controller/callback_set_overlay_settings_ui_scaling": controller.callbackSetOverlaySettingsUiScaling,
"/controller/callback_enable_overlay_small_log": controller.callbackEnableOverlaySmallLog, "/controller/callback_enable_overlay_small_log": controller.callbackEnableOverlaySmallLog,
@@ -237,6 +239,12 @@ action_mapping = {
"/controller/callback_messagebox_press_key_enter": { "/controller/callback_messagebox_press_key_enter": {
"error_translation_engine":"/action/error_translation_engine" "error_translation_engine":"/action/error_translation_engine"
}, },
"/controller/callback_download_ctranslate2_weight": {
"download":"/action/download_ctranslate2_weight"
},
"/controller/callback_download_whisper_weight": {
"download":"/action/download_whisper_weight"
},
} }
def handleConfigRequest(endpoint): def handleConfigRequest(endpoint):
@@ -318,127 +326,145 @@ def main():
print(response, flush=True) print(response, flush=True)
if __name__ == "__main__": if __name__ == "__main__":
response_test = False process = "main"
if response_test: match process:
import time case "main":
controller.init() try:
controller.init()
print(json.dumps({"status":348, "log": "Initialization from Python."}), flush=True)
while True:
main()
except Exception:
import traceback
with open('error.log', 'a') as f:
traceback.print_exc(file=f)
for endpoint, value in config_mapping.items(): case "test":
response_data, status = handleConfigRequest(endpoint)
response = {
"status": status,
"endpoint": endpoint,
"result": response_data,
}
response = json.dumps(response)
print(response, flush=True)
time.sleep(0.1)
for endpoint, value in controller_mapping.items():
print(json.dumps({"status":348, "log": f"endpoint: {endpoint}"}))
match endpoint:
case "/controller/callback_messagebox_press_key_enter":
data = "テスト"
case "/controller/set_your_language_and_country":
data = {"language": "English", "country": "Hong Kong"}
case "/controller/set_target_language_and_country":
data = {"language": "Japanese", "country": "Japan"}
case "/controller/callback_set_transparency":
data = 0.5
case "/controller/callback_set_appearance":
data = "Dark"
case "/controller/callback_set_ui_scaling":
data = 1.5
case "/controller/callback_set_textbox_ui_scaling":
data = 1.5
case "/controller/callback_set_message_box_ratio":
data = 0.5
case "/controller/callback_set_font_family":
data = "Yu Gothic UI"
case "/controller/callback_set_ui_language":
data = "ja"
case "/controller/callback_set_ctranslate2_weight_type":
data = "Small"
case "/controller/callback_set_deepl_auth_key":
data = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee:fx"
case "/controller/callback_set_mic_host":
data = "MME"
case "/controller/callback_set_mic_device":
data = "マイク (Realtek High Definition Audio)"
case "/controller/callback_set_mic_energy_threshold":
data = 0.5
case "/controller/callback_set_mic_record_timeout":
data = 5
case "/controller/callback_set_mic_phrase_timeout":
data = 5
case "/controller/callback_set_mic_max_phrases":
data = 5
case "/controller/callback_set_mic_word_filter":
data = "test0, test1, test2"
case "/controller/callback_delete_mic_word_filter":
data = "test1"
case "/controller/callback_set_speaker_device":
data = "スピーカー (Realtek High Definition Audio)"
case "/controller/callback_set_speaker_energy_threshold":
data = 0.5
case "/controller/callback_set_speaker_record_timeout":
data = 5
case "/controller/callback_set_speaker_phrase_timeout":
data = 5
case "/controller/callback_set_speaker_max_phrases":
data = 5
case "/controller/callback_set_whisper_weight_type":
data = "base"
case "/controller/callback_set_overlay_settings_opacity":
data = 0.5
case "/controller/callback_set_overlay_settings_ui_scaling":
data = 1.5
case "/controller/callback_set_overlay_small_log_settings_x_pos":
data = 0
case "/controller/callback_set_overlay_small_log_settings_y_pos":
data = 0
case "/controller/callback_set_overlay_small_log_settings_z_pos":
data = 0
case "/controller/callback_set_overlay_small_log_settings_x_rotation":
data = 0
case "/controller/callback_set_overlay_small_log_settings_y_rotation":
data = 0
case "/controller/callback_set_overlay_small_log_settings_z_rotation":
data = 0
case "/controller/callback_set_send_message_button_type":
data = "show"
case "/controller/callback_set_send_message_format":
data = "[message]"
case "/controller/callback_set_send_message_format_with_t":
data = "[message]([translation])"
case "/controller/callback_set_received_message_format":
data = "[message]"
case "/controller/callback_set_received_message_format_with_t":
data = "[message]([translation])"
case "/controller/callback_set_osc_ip_address":
data = "127.0.0.1"
case "/controller/callback_set_osc_port":
data = 8000
case _:
data = None
response_data, status = handleControllerRequest(endpoint, data)
response = {
"status": status,
"endpoint": endpoint,
"result": response_data,
}
response = json.dumps(response)
print(response, flush=True)
time.sleep(0.5)
else:
try:
controller.init() controller.init()
print(json.dumps({"status":348, "log": "Initialization from Python."}), flush=True) response_data, status = handleControllerRequest("/controller/callback_download_ctranslate2_weight")
while True: response = {
main() "status": status,
except Exception: "endpoint": "/controller/callback_download_ctranslate2_weight",
import traceback "result": response_data,
with open('error.log', 'a') as f: }
traceback.print_exc(file=f) response = json.dumps(response)
response_data, status = handleControllerRequest("/controller/callback_download_whisper_weight")
response = {
"status": status,
"endpoint": "/controller/callback_download_whisper_weight",
"result": response_data,
}
response = json.dumps(response)
case "test_all":
import time
controller.init()
for endpoint, value in config_mapping.items():
response_data, status = handleConfigRequest(endpoint)
response = {
"status": status,
"endpoint": endpoint,
"result": response_data,
}
response = json.dumps(response)
print(response, flush=True)
time.sleep(0.1)
for endpoint, value in controller_mapping.items():
print(json.dumps({"status":348, "log": f"endpoint: {endpoint}"}))
match endpoint:
case "/controller/callback_messagebox_press_key_enter":
data = "テスト"
case "/controller/set_your_language_and_country":
data = {"language": "English", "country": "Hong Kong"}
case "/controller/set_target_language_and_country":
data = {"language": "Japanese", "country": "Japan"}
case "/controller/callback_set_transparency":
data = 0.5
case "/controller/callback_set_appearance":
data = "Dark"
case "/controller/callback_set_ui_scaling":
data = 1.5
case "/controller/callback_set_textbox_ui_scaling":
data = 1.5
case "/controller/callback_set_message_box_ratio":
data = 0.5
case "/controller/callback_set_font_family":
data = "Yu Gothic UI"
case "/controller/callback_set_ui_language":
data = "ja"
case "/controller/callback_set_ctranslate2_weight_type":
data = "Small"
case "/controller/callback_set_deepl_auth_key":
data = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee:fx"
case "/controller/callback_set_mic_host":
data = "MME"
case "/controller/callback_set_mic_device":
data = "マイク (Realtek High Definition Audio)"
case "/controller/callback_set_mic_energy_threshold":
data = 0.5
case "/controller/callback_set_mic_record_timeout":
data = 5
case "/controller/callback_set_mic_phrase_timeout":
data = 5
case "/controller/callback_set_mic_max_phrases":
data = 5
case "/controller/callback_set_mic_word_filter":
data = "test0, test1, test2"
case "/controller/callback_delete_mic_word_filter":
data = "test1"
case "/controller/callback_set_speaker_device":
data = "スピーカー (Realtek High Definition Audio)"
case "/controller/callback_set_speaker_energy_threshold":
data = 0.5
case "/controller/callback_set_speaker_record_timeout":
data = 5
case "/controller/callback_set_speaker_phrase_timeout":
data = 5
case "/controller/callback_set_speaker_max_phrases":
data = 5
case "/controller/callback_set_whisper_weight_type":
data = "base"
case "/controller/callback_set_overlay_settings_opacity":
data = 0.5
case "/controller/callback_set_overlay_settings_ui_scaling":
data = 1.5
case "/controller/callback_set_overlay_small_log_settings_x_pos":
data = 0
case "/controller/callback_set_overlay_small_log_settings_y_pos":
data = 0
case "/controller/callback_set_overlay_small_log_settings_z_pos":
data = 0
case "/controller/callback_set_overlay_small_log_settings_x_rotation":
data = 0
case "/controller/callback_set_overlay_small_log_settings_y_rotation":
data = 0
case "/controller/callback_set_overlay_small_log_settings_z_rotation":
data = 0
case "/controller/callback_set_send_message_button_type":
data = "show"
case "/controller/callback_set_send_message_format":
data = "[message]"
case "/controller/callback_set_send_message_format_with_t":
data = "[message]([translation])"
case "/controller/callback_set_received_message_format":
data = "[message]"
case "/controller/callback_set_received_message_format_with_t":
data = "[message]([translation])"
case "/controller/callback_set_osc_ip_address":
data = "127.0.0.1"
case "/controller/callback_set_osc_port":
data = 8000
case _:
data = None
response_data, status = handleControllerRequest(endpoint, data)
response = {
"status": status,
"endpoint": endpoint,
"result": response_data,
}
response = json.dumps(response)
print(response, flush=True)
time.sleep(0.5)