diff --git a/src-python/model.py b/src-python/model.py index 494c72fb..3d5537a8 100644 --- a/src-python/model.py +++ b/src-python/model.py @@ -222,8 +222,12 @@ class Model: transcription_langs = list(transcription_lang.keys()) translation_langs = [] for tl_key in translation_lang.keys(): - for lang in translation_lang[tl_key]["source"]: - translation_langs.append(lang) + if tl_key == "CTranslate2": + for lang in translation_lang[tl_key][config.CTRANSLATE2_WEIGHT_TYPE]["source"]: + translation_langs.append(lang) + else: + for lang in translation_lang[tl_key]["source"]: + translation_langs.append(lang) translation_langs = list(set(translation_langs)) supported_langs = list(filter(lambda x: x in transcription_langs, translation_langs)) @@ -243,7 +247,10 @@ class Model: selectable_engines = [key for key, value in engines_status.items() if value is True] compatible_engines = [] for engine in list(translation_lang.keys()): - languages = translation_lang.get(engine, {}).get("source", {}) + if engine == "CTranslate2": + languages = translation_lang.get(engine, {}).get(config.CTRANSLATE2_WEIGHT_TYPE, {}).get("source", {}) + else: + languages = translation_lang.get(engine, {}).get("source", {}) source_langs = [e["language"] for e in list(source_lang.values()) if e["enable"] is True] target_langs = [e["language"] for e in list(target_lang.values()) if e["enable"] is True] language_list = list(languages.keys()) diff --git a/src-python/models/translation/translation_languages.py b/src-python/models/translation/translation_languages.py index 9a54430a..b08292aa 100644 --- a/src-python/models/translation/translation_languages.py +++ b/src-python/models/translation/translation_languages.py @@ -372,8 +372,9 @@ dict_m2m100_languages = { "Sundanese": "su" } -translation_lang["m2m100_418M-ct2-int8"] = {"source":dict_m2m100_languages, "target":dict_m2m100_languages} -translation_lang["m2m100_1.2B-ct2-int8"] = {"source":dict_m2m100_languages, "target":dict_m2m100_languages} +translation_lang["CTranslate2"] = {} +translation_lang["CTranslate2"]["m2m100_418M-ct2-int8"] = {"source":dict_m2m100_languages, "target":dict_m2m100_languages} +translation_lang["CTranslate2"]["m2m100_1.2B-ct2-int8"] = {"source":dict_m2m100_languages, "target":dict_m2m100_languages} dict_nllb_languages = { "Acehnese (Arabic script)": "ace_Arab", @@ -582,8 +583,8 @@ dict_nllb_languages = { "Zulu": "zul_Latn" } -translation_lang["nllb-200-distilled-1.3B-ct2-int8"] = {"source":dict_nllb_languages, "target":dict_nllb_languages} -translation_lang["nllb-200-3.3B-ct2-int8"] = {"source":dict_nllb_languages, "target":dict_nllb_languages} +translation_lang["CTranslate2"]["nllb-200-distilled-1.3B-ct2-int8"] = {"source":dict_nllb_languages, "target":dict_nllb_languages} +translation_lang["CTranslate2"]["nllb-200-3.3B-ct2-int8"] = {"source":dict_nllb_languages, "target":dict_nllb_languages} dict_plamo_languages = { "English": "English", diff --git a/src-python/models/translation/translation_translator.py b/src-python/models/translation/translation_translator.py index 8b240ffd..bc4d26f1 100644 --- a/src-python/models/translation/translation_translator.py +++ b/src-python/models/translation/translation_translator.py @@ -177,12 +177,14 @@ class Translator: target_language = "Portuguese European" else: target_language = "Portuguese Brazilian" + source_language = translation_lang[translator_name]["source"][source_language] + target_language = translation_lang[translator_name]["target"][target_language] case "CTranslate2": - translator_name = weight_type + source_language = translation_lang[translator_name][weight_type]["source"][source_language] + target_language = translation_lang[translator_name][weight_type]["target"][target_language] case _: - pass - source_language = translation_lang[translator_name]["source"][source_language] - target_language = translation_lang[translator_name]["target"][target_language] + source_language = translation_lang[translator_name]["source"][source_language] + target_language = translation_lang[translator_name]["target"][target_language] return source_language, target_language def translate(self, translator_name: str, weight_type: str, source_language: str, target_language: str, target_country: str, message: str) -> Any: diff --git a/src-python/models/translation/translation_utils.py b/src-python/models/translation/translation_utils.py index 688f131f..895a9680 100644 --- a/src-python/models/translation/translation_utils.py +++ b/src-python/models/translation/translation_utils.py @@ -95,8 +95,8 @@ def downloadCTranslate2Tokenizer(path: str, weight_type: str = "m2m100_418M-ct2- tokenizer = ctranslate2_weights[weight_type]["tokenizer"] tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer") try: - os_makedirs(tokenizer_cache, exist_ok=True) - transformers.AutoTokenizer.from_pretrained(tokenizer_name, cache_dir=tokenizer_cache) + os_makedirs(tokenizer_path, exist_ok=True) + transformers.AutoTokenizer.from_pretrained(tokenizer, cache_dir=tokenizer_path) except Exception: errorLogging() tokenizer_path = os_path.join("./weights", "ctranslate2", directory_name, "tokenizer")