Merge branch 'traslate_mask' into develop

This commit is contained in:
misyaguziya
2024-05-10 16:49:27 +09:00
3 changed files with 76 additions and 61 deletions

View File

@@ -61,6 +61,7 @@ def messageFormatter(format_type:str, translation, message):
return osc_message return osc_message
def changeToCTranslate2Process(): def changeToCTranslate2Process():
if config.CHOICE_INPUT_TRANSLATOR != "CTranslate2" or config.CHOICE_OUTPUT_TRANSLATOR != "CTranslate2":
config.CHOICE_INPUT_TRANSLATOR = "CTranslate2" config.CHOICE_INPUT_TRANSLATOR = "CTranslate2"
config.CHOICE_OUTPUT_TRANSLATOR = "CTranslate2" config.CHOICE_OUTPUT_TRANSLATOR = "CTranslate2"
updateTranslationEngineAndEngineList() updateTranslationEngineAndEngineList()
@@ -399,10 +400,10 @@ def callbackSelectedTranslationEngine(selected_translation_engine):
def callbackToggleTranslation(is_turned_on): def callbackToggleTranslation(is_turned_on):
config.ENABLE_TRANSLATION = is_turned_on config.ENABLE_TRANSLATION = is_turned_on
if config.ENABLE_TRANSLATION is True: if config.ENABLE_TRANSLATION is True:
if model.isLoadedCTranslate2Model() is False:
model.changeTranslatorCTranslate2Model() model.changeTranslatorCTranslate2Model()
view.printToTextbox_enableTranslation() view.printToTextbox_enableTranslation()
else: else:
model.clearTranslatorCTranslate2Model()
view.printToTextbox_disableTranslation() view.printToTextbox_disableTranslation()
def callbackToggleTranscriptionSend(is_turned_on): def callbackToggleTranscriptionSend(is_turned_on):

View File

@@ -103,8 +103,8 @@ class Model:
def changeTranslatorCTranslate2Model(self): def changeTranslatorCTranslate2Model(self):
self.translator.changeCTranslate2Model(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE) self.translator.changeCTranslate2Model(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE)
def clearTranslatorCTranslate2Model(self): def isLoadedCTranslate2Model(self):
self.translator.clearCTranslate2Model() return self.translator.isLoadedCTranslate2Model()
def checkTranscriptionWhisperModelWeight(self): def checkTranscriptionWhisperModelWeight(self):
return checkWhisperWeight(config.PATH_LOCAL, config.WHISPER_WEIGHT_TYPE) return checkWhisperWeight(config.PATH_LOCAL, config.WHISPER_WEIGHT_TYPE)
@@ -165,59 +165,62 @@ class Model:
compatible_engines.remove('DeepL_API') compatible_engines.remove('DeepL_API')
return compatible_engines return compatible_engines
def getTranslate(self, translator_name, source_language, target_language, target_country, message):
success_flag = False
translation = self.translator.translate(
translator_name=translator_name,
source_language=source_language,
target_language=target_language,
target_country=target_country,
message=message
)
# 翻訳失敗時のフェールセーフ処理
if translation is True:
success_flag = True
else:
while True:
translation = self.translator.translate(
translator_name="CTranslate2",
source_language=source_language,
target_language=target_language,
target_country=target_country,
message=message
)
if translation is not False:
break
sleep(0.1)
return translation, success_flag
def getInputTranslate(self, message): def getInputTranslate(self, message):
translation_success_flag = True
translator_name=config.CHOICE_INPUT_TRANSLATOR translator_name=config.CHOICE_INPUT_TRANSLATOR
source_language=config.SOURCE_LANGUAGE source_language=config.SOURCE_LANGUAGE
target_language=config.TARGET_LANGUAGE target_language=config.TARGET_LANGUAGE
target_country = config.TARGET_COUNTRY target_country = config.TARGET_COUNTRY
translation = self.translator.translate( translation, success_flag = self.getTranslate(
translator_name=translator_name, translator_name,
source_language=source_language, source_language,
target_language=target_language, target_language,
target_country=target_country, target_country,
message=message message
) )
return translation, success_flag
# 翻訳失敗時のフェールセーフ処理
if translation is False:
translation_success_flag = False
translation = self.translator.translate(
translator_name="CTranslate2",
source_language=source_language,
target_language=target_language,
target_country=target_country,
message=message
)
return translation, translation_success_flag
def getOutputTranslate(self, message): def getOutputTranslate(self, message):
translation_success_flag = True
translator_name=config.CHOICE_OUTPUT_TRANSLATOR translator_name=config.CHOICE_OUTPUT_TRANSLATOR
source_language=config.TARGET_LANGUAGE source_language=config.TARGET_LANGUAGE
target_language=config.SOURCE_LANGUAGE target_language=config.SOURCE_LANGUAGE
target_country=config.SOURCE_COUNTRY target_country=config.SOURCE_COUNTRY
translation = self.translator.translate( translation, success_flag = self.getTranslate(
translator_name=translator_name, translator_name,
source_language=source_language, source_language,
target_language=target_language, target_language,
target_country=target_country, target_country,
message=message message
) )
return translation, success_flag
# 翻訳失敗時のフェールセーフ処理
if translation is False:
translation_success_flag = False
translation = self.translator.translate(
translator_name="CTranslate2",
source_language=source_language,
target_language=target_language,
target_country=target_country,
message=message
)
return translation, translation_success_flag
def addKeywords(self): def addKeywords(self):
for f in config.INPUT_MIC_WORD_FILTER: for f in config.INPUT_MIC_WORD_FILTER:

View File

@@ -1,4 +1,3 @@
import gc
import os import os
from deepl import Translator as deepl_Translator from deepl import Translator as deepl_Translator
from translators import translate_text as other_web_Translator from translators import translate_text as other_web_Translator
@@ -14,6 +13,7 @@ class Translator():
self.deepl_client = None self.deepl_client = None
self.ctranslate2_translator = None self.ctranslate2_translator = None
self.ctranslate2_tokenizer = None self.ctranslate2_tokenizer = None
self.is_loaded_ctranslate2_model = False
def authenticationDeepLAuthKey(self, authkey): def authenticationDeepLAuthKey(self, authkey):
result = True result = True
@@ -26,6 +26,7 @@ class Translator():
return result return result
def changeCTranslate2Model(self, path, model_type): def changeCTranslate2Model(self, path, model_type):
self.is_loaded_ctranslate2_model = False
directory_name = ctranslate2_weights[model_type]["directory_name"] directory_name = ctranslate2_weights[model_type]["directory_name"]
tokenizer = ctranslate2_weights[model_type]["tokenizer"] tokenizer = ctranslate2_weights[model_type]["tokenizer"]
weight_path = os.path.join(path, "weights", "ctranslate2", directory_name) weight_path = os.path.join(path, "weights", "ctranslate2", directory_name)
@@ -44,13 +45,24 @@ class Translator():
print("Error: changeCTranslate2Model()", e) print("Error: changeCTranslate2Model()", e)
tokenizer_path = os.path.join("./weights", "ctranslate2", directory_name, "tokenizer") tokenizer_path = os.path.join("./weights", "ctranslate2", directory_name, "tokenizer")
self.ctranslate2_tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer, cache_dir=tokenizer_path) self.ctranslate2_tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer, cache_dir=tokenizer_path)
self.is_loaded_ctranslate2_model = True
def clearCTranslate2Model(self): def isLoadedCTranslate2Model(self):
del self.ctranslate2_translator return self.is_loaded_ctranslate2_model
del self.ctranslate2_tokenizer
gc.collect() def translateCTranslate2(self, message, source_language, target_language):
self.ctranslate2_translator = None result = False
self.ctranslate2_tokenizer = None if self.is_loaded_ctranslate2_model is True:
try:
self.ctranslate2_tokenizer.src_lang = source_language
source = self.ctranslate2_tokenizer.convert_ids_to_tokens(self.ctranslate2_tokenizer.encode(message))
target_prefix = [self.ctranslate2_tokenizer.lang_code_to_token[target_language]]
results = self.ctranslate2_translator.translate_batch([source], target_prefix=[target_prefix])
target = results[0].hypotheses[0][1:]
result = self.ctranslate2_tokenizer.decode(self.ctranslate2_tokenizer.convert_tokens_to_ids(target))
except Exception:
pass
return result
@staticmethod @staticmethod
def getLanguageCode(translator_name, target_country, source_language, target_language): def getLanguageCode(translator_name, target_country, source_language, target_language):
@@ -115,12 +127,11 @@ class Translator():
to_language=target_language, to_language=target_language,
) )
case "CTranslate2": case "CTranslate2":
self.ctranslate2_tokenizer.src_lang = source_language result = self.translateCTranslate2(
source = self.ctranslate2_tokenizer.convert_ids_to_tokens(self.ctranslate2_tokenizer.encode(message)) message=message,
target_prefix = [self.ctranslate2_tokenizer.lang_code_to_token[target_language]] source_language=source_language,
results = self.ctranslate2_translator.translate_batch([source], target_prefix=[target_prefix]) target_language=target_language,
target = results[0].hypotheses[0][1:] )
result = self.ctranslate2_tokenizer.decode(self.ctranslate2_tokenizer.convert_tokens_to_ids(target))
except Exception: except Exception:
import traceback import traceback
with open('error.log', 'a') as f: with open('error.log', 'a') as f: