Merge branch 'bugfix_translation' into develop
This commit is contained in:
@@ -1476,6 +1476,7 @@ class Controller:
|
||||
)
|
||||
else:
|
||||
model.downloadCTranslate2ModelWeight(weight_type, download_ctranslate2.progressBar, download_ctranslate2.downloaded)
|
||||
model.downloadCTranslate2ModelTokenizer(weight_type)
|
||||
return {"status":200, "result":True}
|
||||
|
||||
def downloadWhisperWeight(self, data:str, asynchronous:bool=True, *args, **kwargs) -> dict:
|
||||
@@ -1637,10 +1638,22 @@ class Controller:
|
||||
config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT = weight_type_dict
|
||||
|
||||
def updateTranscriptionEngine(self):
|
||||
weight_type_dict = config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT
|
||||
weight_type = config.WHISPER_WEIGHT_TYPE
|
||||
if config.SELECTED_TRANSCRIPTION_ENGINE == "Whisper" and weight_type_dict[weight_type] is False:
|
||||
config.SELECTED_TRANSCRIPTION_ENGINE = "Google"
|
||||
weight_type_dict = config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT
|
||||
weight_available = bool(weight_type_dict.get(weight_type))
|
||||
current_engine = config.SELECTED_TRANSCRIPTION_ENGINE
|
||||
selected_engines = [key for key, value in config.SELECTABLE_TRANSCRIPTION_ENGINE_STATUS.items() if value is True]
|
||||
|
||||
# 選択可能なエンジンがなければ、Whisper に変更
|
||||
if current_engine in {"Whisper", "Google"}:
|
||||
if current_engine not in selected_engines:
|
||||
if weight_available:
|
||||
alternate = "Google" if current_engine == "Whisper" else "Whisper"
|
||||
config.SELECTED_TRANSCRIPTION_ENGINE = alternate if alternate in selected_engines else None
|
||||
else:
|
||||
config.SELECTED_TRANSCRIPTION_ENGINE = "Whisper"
|
||||
else:
|
||||
config.SELECTED_TRANSCRIPTION_ENGINE = "Whisper"
|
||||
|
||||
def startCheckMicEnergy(self) -> None:
|
||||
while self.device_access_status is False:
|
||||
@@ -1728,20 +1741,6 @@ class Controller:
|
||||
self.disconnectedNetwork()
|
||||
printLog(f"Connected Network: {connected_network}")
|
||||
|
||||
printLog("Init Translation Engine Status")
|
||||
for engine in config.SELECTABLE_TRANSLATION_ENGINE_LIST:
|
||||
match engine:
|
||||
case "DeepL_API":
|
||||
printLog("Start check DeepL API Key")
|
||||
config.SELECTABLE_TRANSLATION_ENGINE_STATUS[engine] = False
|
||||
if config.AUTH_KEYS[engine] is not None:
|
||||
if model.authenticationTranslatorDeepLAuthKey(auth_key=config.AUTH_KEYS[engine]) is True:
|
||||
config.SELECTABLE_TRANSLATION_ENGINE_STATUS[engine] = True
|
||||
else:
|
||||
# error update Auth key
|
||||
auth_keys = config.AUTH_KEYS
|
||||
auth_keys[engine] = None
|
||||
config.AUTH_KEYS = auth_keys
|
||||
self.initializationProgress(1)
|
||||
|
||||
if connected_network is True:
|
||||
@@ -1774,6 +1773,7 @@ class Controller:
|
||||
else:
|
||||
self.enableAiModels()
|
||||
|
||||
printLog("Init Translation Engine Status")
|
||||
for engine in config.SELECTABLE_TRANSLATION_ENGINE_LIST:
|
||||
match engine:
|
||||
case "CTranslate2":
|
||||
@@ -1781,6 +1781,17 @@ class Controller:
|
||||
config.SELECTABLE_TRANSLATION_ENGINE_STATUS[engine] = True
|
||||
else:
|
||||
config.SELECTABLE_TRANSLATION_ENGINE_STATUS[engine] = False
|
||||
case "DeepL_API":
|
||||
printLog("Start check DeepL API Key")
|
||||
config.SELECTABLE_TRANSLATION_ENGINE_STATUS[engine] = False
|
||||
if config.AUTH_KEYS[engine] is not None:
|
||||
if model.authenticationTranslatorDeepLAuthKey(auth_key=config.AUTH_KEYS[engine]) is True:
|
||||
config.SELECTABLE_TRANSLATION_ENGINE_STATUS[engine] = True
|
||||
else:
|
||||
# error update Auth key
|
||||
auth_keys = config.AUTH_KEYS
|
||||
auth_keys[engine] = None
|
||||
config.AUTH_KEYS = auth_keys
|
||||
case _:
|
||||
if connected_network is True:
|
||||
config.SELECTABLE_TRANSLATION_ENGINE_STATUS[engine] = True
|
||||
|
||||
@@ -24,7 +24,7 @@ from models.transcription.transcription_recorder import SelectedMicEnergyRecorde
|
||||
from models.transcription.transcription_transcriber import AudioTranscriber
|
||||
from models.translation.translation_languages import translation_lang
|
||||
from models.transcription.transcription_languages import transcription_lang
|
||||
from models.translation.translation_utils import checkCTranslate2Weight, downloadCTranslate2Weight
|
||||
from models.translation.translation_utils import checkCTranslate2Weight, downloadCTranslate2Weight, downloadCTranslate2Tokenizer
|
||||
from models.transcription.transcription_whisper import checkWhisperWeight, downloadWhisperWeight
|
||||
from models.overlay.overlay import Overlay
|
||||
from models.overlay.overlay_image import OverlayImage
|
||||
@@ -113,6 +113,9 @@ class Model:
|
||||
def downloadCTranslate2ModelWeight(self, weight_type, callback=None, end_callback=None):
|
||||
return downloadCTranslate2Weight(config.PATH_LOCAL, weight_type, callback, end_callback)
|
||||
|
||||
def downloadCTranslate2ModelTokenizer(self, weight_type):
|
||||
return downloadCTranslate2Tokenizer(config.PATH_LOCAL, weight_type)
|
||||
|
||||
def isLoadedCTranslate2Model(self):
|
||||
return self.translator.isLoadedCTranslate2Model()
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from os import makedirs as os_makedirs
|
||||
from requests import get as requests_get
|
||||
from typing import Callable
|
||||
import hashlib
|
||||
import transformers
|
||||
from utils import errorLogging
|
||||
|
||||
ctranslate2_weights = {
|
||||
@@ -87,3 +88,16 @@ def downloadCTranslate2Weight(root, weight_type="small", callback=None, end_call
|
||||
|
||||
if isinstance(end_callback, Callable):
|
||||
end_callback()
|
||||
|
||||
def downloadCTranslate2Tokenizer(path, weight_type="small"):
|
||||
directory_name = ctranslate2_weights[weight_type]["directory_name"]
|
||||
tokenizer = ctranslate2_weights[weight_type]["tokenizer"]
|
||||
tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer")
|
||||
|
||||
try:
|
||||
os_makedirs(tokenizer_path, exist_ok=True)
|
||||
transformers.AutoTokenizer.from_pretrained(tokenizer, cache_dir=tokenizer_path)
|
||||
except Exception:
|
||||
errorLogging()
|
||||
tokenizer_path = os_path.join("./weights", "ctranslate2", directory_name, "tokenizer")
|
||||
transformers.AutoTokenizer.from_pretrained(tokenizer, cache_dir=tokenizer_path)
|
||||
Reference in New Issue
Block a user