[Add] Model: ctranslate2のテストコードを追加
This commit is contained in:
@@ -3,12 +3,24 @@ from deepl_translate import translate as deepl_web_Translator
|
||||
from translators import translate_text as other_web_Translator
|
||||
from .translation_languages import translation_lang
|
||||
|
||||
from ctranslate2.converters import TransformersConverter
|
||||
import ctranslate2
|
||||
import transformers
|
||||
|
||||
TRANSLATE_MODELS = {
|
||||
"small": "facebook/m2m100_418M",
|
||||
"large": "facebook/m2m100_1.2B"
|
||||
}
|
||||
|
||||
# Translator
|
||||
class Translator():
|
||||
def __init__(self):
|
||||
pass
|
||||
self.translator_status = {}
|
||||
|
||||
self.translator = ctranslate2.Translator("D:\\WORKSPACE\\WORK\\VRChatProject\\VRCT\\weight", device="cpu", device_index=0, compute_type="int8", inter_threads=1, intra_threads=4)
|
||||
self.tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/m2m100_418M")
|
||||
|
||||
def authentication(self, translator_name, authkey=None):
|
||||
result = True
|
||||
match translator_name:
|
||||
@@ -57,4 +69,19 @@ class Translator():
|
||||
with open('error.log', 'a') as f:
|
||||
traceback.print_exc(file=f)
|
||||
result = False
|
||||
return result
|
||||
|
||||
def translate_ctranslate2(self, translator_name, source_language, target_language, message):
|
||||
|
||||
source_language=translation_lang["ctranslate2"]["source"][source_language]
|
||||
target_language=translation_lang["ctranslate2"]["target"][target_language]
|
||||
|
||||
self.tokenizer.src_lang = source_language
|
||||
source = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(message))
|
||||
target_prefix = [self.tokenizer.lang_code_to_token[target_language]]
|
||||
results = self.translator.translate_batch([source], target_prefix=[target_prefix])
|
||||
target = results[0].hypotheses[0][1:]
|
||||
|
||||
result = self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(target))
|
||||
print(result)
|
||||
return result
|
||||
Reference in New Issue
Block a user