From bdf67ab7c83782a2d002936a5097229de5a924e6 Mon Sep 17 00:00:00 2001 From: misyaguziya <53165965+misyaguziya@users.noreply.github.com> Date: Wed, 12 Feb 2025 13:48:16 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B[bugfix]=20offline=E6=99=82?= =?UTF-8?q?=E3=81=AE=E5=87=A6=E7=90=86=E3=82=92=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src-python/controller.py | 45 ++++++++++++------- src-python/model.py | 5 ++- .../models/translation/translation_utils.py | 16 ++++++- 3 files changed, 47 insertions(+), 19 deletions(-) diff --git a/src-python/controller.py b/src-python/controller.py index 8fb05825..bd193ff0 100644 --- a/src-python/controller.py +++ b/src-python/controller.py @@ -1476,6 +1476,7 @@ class Controller: ) else: model.downloadCTranslate2ModelWeight(weight_type, download_ctranslate2.progressBar, download_ctranslate2.downloaded) + model.downloadCTranslate2ModelTokenizer(weight_type) return {"status":200, "result":True} def downloadWhisperWeight(self, data:str, asynchronous:bool=True, *args, **kwargs) -> dict: @@ -1637,10 +1638,22 @@ class Controller: config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT = weight_type_dict def updateTranscriptionEngine(self): - weight_type_dict = config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT weight_type = config.WHISPER_WEIGHT_TYPE - if config.SELECTED_TRANSCRIPTION_ENGINE == "Whisper" and weight_type_dict[weight_type] is False: - config.SELECTED_TRANSCRIPTION_ENGINE = "Google" + weight_type_dict = config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT + weight_available = bool(weight_type_dict.get(weight_type)) + current_engine = config.SELECTED_TRANSCRIPTION_ENGINE + selected_engines = [key for key, value in config.SELECTABLE_TRANSCRIPTION_ENGINE_STATUS.items() if value is True] + + # 選択可能なエンジンがなければ、Whisper に変更 + if current_engine in {"Whisper", "Google"}: + if current_engine not in selected_engines: + if weight_available: + alternate = "Google" if current_engine == "Whisper" else "Whisper" + config.SELECTED_TRANSCRIPTION_ENGINE = alternate if alternate in selected_engines else None + else: + config.SELECTED_TRANSCRIPTION_ENGINE = "Whisper" + else: + config.SELECTED_TRANSCRIPTION_ENGINE = "Whisper" def startCheckMicEnergy(self) -> None: while self.device_access_status is False: @@ -1728,20 +1741,6 @@ class Controller: self.disconnectedNetwork() printLog(f"Connected Network: {connected_network}") - printLog("Init Translation Engine Status") - for engine in config.SELECTABLE_TRANSLATION_ENGINE_LIST: - match engine: - case "DeepL_API": - printLog("Start check DeepL API Key") - config.SELECTABLE_TRANSLATION_ENGINE_STATUS[engine] = False - if config.AUTH_KEYS[engine] is not None: - if model.authenticationTranslatorDeepLAuthKey(auth_key=config.AUTH_KEYS[engine]) is True: - config.SELECTABLE_TRANSLATION_ENGINE_STATUS[engine] = True - else: - # error update Auth key - auth_keys = config.AUTH_KEYS - auth_keys[engine] = None - config.AUTH_KEYS = auth_keys self.initializationProgress(1) if connected_network is True: @@ -1774,6 +1773,7 @@ class Controller: else: self.enableAiModels() + printLog("Init Translation Engine Status") for engine in config.SELECTABLE_TRANSLATION_ENGINE_LIST: match engine: case "CTranslate2": @@ -1781,6 +1781,17 @@ class Controller: config.SELECTABLE_TRANSLATION_ENGINE_STATUS[engine] = True else: config.SELECTABLE_TRANSLATION_ENGINE_STATUS[engine] = False + case "DeepL_API": + printLog("Start check DeepL API Key") + config.SELECTABLE_TRANSLATION_ENGINE_STATUS[engine] = False + if config.AUTH_KEYS[engine] is not None: + if model.authenticationTranslatorDeepLAuthKey(auth_key=config.AUTH_KEYS[engine]) is True: + config.SELECTABLE_TRANSLATION_ENGINE_STATUS[engine] = True + else: + # error update Auth key + auth_keys = config.AUTH_KEYS + auth_keys[engine] = None + config.AUTH_KEYS = auth_keys case _: if connected_network is True: config.SELECTABLE_TRANSLATION_ENGINE_STATUS[engine] = True diff --git a/src-python/model.py b/src-python/model.py index 19bbdd80..b6a416b4 100644 --- a/src-python/model.py +++ b/src-python/model.py @@ -24,7 +24,7 @@ from models.transcription.transcription_recorder import SelectedMicEnergyRecorde from models.transcription.transcription_transcriber import AudioTranscriber from models.translation.translation_languages import translation_lang from models.transcription.transcription_languages import transcription_lang -from models.translation.translation_utils import checkCTranslate2Weight, downloadCTranslate2Weight +from models.translation.translation_utils import checkCTranslate2Weight, downloadCTranslate2Weight, downloadCTranslate2Tokenizer from models.transcription.transcription_whisper import checkWhisperWeight, downloadWhisperWeight from models.overlay.overlay import Overlay from models.overlay.overlay_image import OverlayImage @@ -113,6 +113,9 @@ class Model: def downloadCTranslate2ModelWeight(self, weight_type, callback=None, end_callback=None): return downloadCTranslate2Weight(config.PATH_LOCAL, weight_type, callback, end_callback) + def downloadCTranslate2ModelTokenizer(self, weight_type): + return downloadCTranslate2Tokenizer(config.PATH_LOCAL, weight_type) + def isLoadedCTranslate2Model(self): return self.translator.isLoadedCTranslate2Model() diff --git a/src-python/models/translation/translation_utils.py b/src-python/models/translation/translation_utils.py index 47c53e05..457a65f1 100644 --- a/src-python/models/translation/translation_utils.py +++ b/src-python/models/translation/translation_utils.py @@ -5,6 +5,7 @@ from os import makedirs as os_makedirs from requests import get as requests_get from typing import Callable import hashlib +import transformers from utils import errorLogging ctranslate2_weights = { @@ -86,4 +87,17 @@ def downloadCTranslate2Weight(root, weight_type="small", callback=None, end_call errorLogging() if isinstance(end_callback, Callable): - end_callback() \ No newline at end of file + end_callback() + +def downloadCTranslate2Tokenizer(path, weight_type="small"): + directory_name = ctranslate2_weights[weight_type]["directory_name"] + tokenizer = ctranslate2_weights[weight_type]["tokenizer"] + tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer") + + try: + os_makedirs(tokenizer_path, exist_ok=True) + transformers.AutoTokenizer.from_pretrained(tokenizer, cache_dir=tokenizer_path) + except Exception: + errorLogging() + tokenizer_path = os_path.join("./weights", "ctranslate2", directory_name, "tokenizer") + transformers.AutoTokenizer.from_pretrained(tokenizer, cache_dir=tokenizer_path) \ No newline at end of file