From 83a61e2e8756faf6462ec26e8e3b8e62ae38a40e Mon Sep 17 00:00:00 2001 From: misyaguziya <53165965+misyaguziya@users.noreply.github.com> Date: Mon, 8 Sep 2025 16:27:47 +0900 Subject: [PATCH] [Update] translation: Add support for new translation models and improve weight handling --- src-python/config.py | 2 +- src-python/model.py | 2 + .../translation/translation_languages.py | 230 +++++++++++++++++- .../translation/translation_translator.py | 46 +++- .../models/translation/translation_utils.py | 49 ++-- 5 files changed, 294 insertions(+), 35 deletions(-) diff --git a/src-python/config.py b/src-python/config.py index 76638fbe..ce509dc8 100644 --- a/src-python/config.py +++ b/src-python/config.py @@ -1188,7 +1188,7 @@ class Config: self._USE_EXCLUDE_WORDS = True self._SELECTED_TRANSLATION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0]) self._SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0]) - self._CTRANSLATE2_WEIGHT_TYPE = "small" + self._CTRANSLATE2_WEIGHT_TYPE = "m2m100_418M-ct2-int8" self._WHISPER_WEIGHT_TYPE = "base" self._AUTO_CLEAR_MESSAGE_BOX = True self._SEND_ONLY_TRANSLATED_MESSAGES = False diff --git a/src-python/model.py b/src-python/model.py index 333f1394..b4486bc0 100644 --- a/src-python/model.py +++ b/src-python/model.py @@ -190,6 +190,7 @@ class Model: success_flag = False translation = self.translator.translate( translator_name=translator_name, + weight_type=config.CTRANSLATE2_WEIGHT_TYPE, source_language=source_language, target_language=target_language, target_country=target_country, @@ -203,6 +204,7 @@ class Model: while True: translation = self.translator.translate( translator_name="CTranslate2", + weight_type=config.CTRANSLATE2_WEIGHT_TYPE, source_language=source_language, target_language=target_language, target_country=target_country, diff --git a/src-python/models/translation/translation_languages.py b/src-python/models/translation/translation_languages.py index 
a697960b..5839e249 100644 --- a/src-python/models/translation/translation_languages.py +++ b/src-python/models/translation/translation_languages.py @@ -275,7 +275,7 @@ translation_lang["Papago"] = { "target":dict_papago_languages, } -dict_ctranslate2_languages = { +dict_m2m100_languages = { "English": "en", "Chinese Simplified": "zh", "Chinese Traditional":"zh", @@ -378,7 +378,229 @@ dict_ctranslate2_languages = { "Sundanese": "su" } -translation_lang["CTranslate2"] = { - "source":dict_ctranslate2_languages, - "target":dict_ctranslate2_languages, +translation_lang["m2m100_418M-ct2-int8"] = { + "source":dict_m2m100_languages, + "target":dict_m2m100_languages, +} + +translation_lang["m2m100_1.2B-ct2-int8"] = { + "source":dict_m2m100_languages, + "target":dict_m2m100_languages, +} + +dict_nllb_languages = { + "Acehnese (Arabic script)": "ace_Arab", + "Acehnese (Latin script)": "ace_Latn", + "Mesopotamian Arabic": "acm_Arab", + "Ta’izzi-Adeni Arabic": "acq_Arab", + "Tunisian Arabic": "aeb_Arab", + "Afrikaans": "afr_Latn", + "South Levantine Arabic": "ajp_Arab", + "Akan": "aka_Latn", + "Amharic": "amh_Ethi", + "North Levantine Arabic": "apc_Arab", + "Modern Standard Arabic": "arb_Arab", + "Modern Standard Arabic (Romanized)": "arb_Latn", + "Najdi Arabic": "ars_Arab", + "Moroccan Arabic": "ary_Arab", + "Egyptian Arabic": "arz_Arab", + "Assamese": "asm_Beng", + "Asturian": "ast_Latn", + "Awadhi": "awa_Deva", + "Central Aymara": "ayr_Latn", + "South Azerbaijani": "azb_Arab", + "North Azerbaijani": "azj_Latn", + "Bashkir": "bak_Cyrl", + "Bambara": "bam_Latn", + "Balinese": "ban_Latn", + "Belarusian": "bel_Cyrl", + "Bemba": "bem_Latn", + "Bengali": "ben_Beng", + "Bhojpuri": "bho_Deva", + "Banjar (Arabic script)": "bjn_Arab", + "Banjar (Latin script)": "bjn_Latn", + "Standard Tibetan": "bod_Tibt", + "Bosnian": "bos_Latn", + "Buginese": "bug_Latn", + "Bulgarian": "bul_Cyrl", + "Catalan": "cat_Latn", + "Cebuano": "ceb_Latn", + "Czech": "ces_Latn", + "Chokwe": "cjk_Latn", + 
"Central Kurdish": "ckb_Arab", + "Crimean Tatar": "crh_Latn", + "Welsh": "cym_Latn", + "Danish": "dan_Latn", + "German": "deu_Latn", + "Southwestern Dinka": "dik_Latn", + "Dyula": "dyu_Latn", + "Dzongkha": "dzo_Tibt", + "Greek": "ell_Grek", + "English": "eng_Latn", + "Esperanto": "epo_Latn", + "Estonian": "est_Latn", + "Basque": "eus_Latn", + "Ewe": "ewe_Latn", + "Faroese": "fao_Latn", + "Fijian": "fij_Latn", + "Finnish": "fin_Latn", + "Fon": "fon_Latn", + "French": "fra_Latn", + "Friulian": "fur_Latn", + "Nigerian Fulfulde": "fuv_Latn", + "Scottish Gaelic": "gla_Latn", + "Irish": "gle_Latn", + "Galician": "glg_Latn", + "Guarani": "grn_Latn", + "Gujarati": "guj_Gujr", + "Haitian Creole": "hat_Latn", + "Hausa": "hau_Latn", + "Hebrew": "heb_Hebr", + "Hindi": "hin_Deva", + "Chhattisgarhi": "hne_Deva", + "Croatian": "hrv_Latn", + "Hungarian": "hun_Latn", + "Armenian": "hye_Armn", + "Igbo": "ibo_Latn", + "Ilocano": "ilo_Latn", + "Indonesian": "ind_Latn", + "Icelandic": "isl_Latn", + "Italian": "ita_Latn", + "Javanese": "jav_Latn", + "Japanese": "jpn_Jpan", + "Kabyle": "kab_Latn", + "Jingpho": "kac_Latn", + "Kamba": "kam_Latn", + "Kannada": "kan_Knda", + "Kashmiri (Arabic script)": "kas_Arab", + "Kashmiri (Devanagari script)": "kas_Deva", + "Georgian": "kat_Geor", + "Central Kanuri (Arabic script)": "knc_Arab", + "Central Kanuri (Latin script)": "knc_Latn", + "Kazakh": "kaz_Cyrl", + "Kabiyè": "kbp_Latn", + "Kabuverdianu": "kea_Latn", + "Khmer": "khm_Khmr", + "Kikuyu": "kik_Latn", + "Kinyarwanda": "kin_Latn", + "Kyrgyz": "kir_Cyrl", + "Kimbundu": "kmb_Latn", + "Northern Kurdish": "kmr_Latn", + "Kikongo": "kon_Latn", + "Korean": "kor_Hang", + "Lao": "lao_Laoo", + "Ligurian": "lij_Latn", + "Limburgish": "lim_Latn", + "Lingala": "lin_Latn", + "Lithuanian": "lit_Latn", + "Lombard": "lmo_Latn", + "Latgalian": "ltg_Latn", + "Luxembourgish": "ltz_Latn", + "Luba-Kasai": "lua_Latn", + "Ganda": "lug_Latn", + "Luo": "luo_Latn", + "Mizo": "lus_Latn", + "Standard Latvian": "lvs_Latn", 
+ "Magahi": "mag_Deva", + "Maithili": "mai_Deva", + "Malayalam": "mal_Mlym", + "Marathi": "mar_Deva", + "Minangkabau (Arabic script)": "min_Arab", + "Minangkabau (Latin script)": "min_Latn", + "Macedonian": "mkd_Cyrl", + "Plateau Malagasy": "plt_Latn", + "Maltese": "mlt_Latn", + "Meitei (Bengali script)": "mni_Beng", + "Halh Mongolian": "khk_Cyrl", + "Mossi": "mos_Latn", + "Maori": "mri_Latn", + "Burmese": "mya_Mymr", + "Dutch": "nld_Latn", + "Norwegian Nynorsk": "nno_Latn", + "Norwegian Bokmål": "nob_Latn", + "Nepali": "npi_Deva", + "Northern Sotho": "nso_Latn", + "Nuer": "nus_Latn", + "Nyanja": "nya_Latn", + "Occitan": "oci_Latn", + "West Central Oromo": "gaz_Latn", + "Odia": "ory_Orya", + "Pangasinan": "pag_Latn", + "Eastern Panjabi": "pan_Guru", + "Papiamento": "pap_Latn", + "Western Persian": "pes_Arab", + "Polish": "pol_Latn", + "Portuguese": "por_Latn", + "Dari": "prs_Arab", + "Southern Pashto": "pbt_Arab", + "Ayacucho Quechua": "quy_Latn", + "Romanian": "ron_Latn", + "Rundi": "run_Latn", + "Russian": "rus_Cyrl", + "Sango": "sag_Latn", + "Sanskrit": "san_Deva", + "Santali": "sat_Olck", + "Sicilian": "scn_Latn", + "Shan": "shn_Mymr", + "Sinhala": "sin_Sinh", + "Slovak": "slk_Latn", + "Slovenian": "slv_Latn", + "Samoan": "smo_Latn", + "Shona": "sna_Latn", + "Sindhi": "snd_Arab", + "Somali": "som_Latn", + "Southern Sotho": "sot_Latn", + "Spanish": "spa_Latn", + "Tosk Albanian": "als_Latn", + "Sardinian": "srd_Latn", + "Serbian": "srp_Cyrl", + "Swati": "ssw_Latn", + "Sundanese": "sun_Latn", + "Swedish": "swe_Latn", + "Swahili": "swh_Latn", + "Silesian": "szl_Latn", + "Tamil": "tam_Taml", + "Tatar": "tat_Cyrl", + "Telugu": "tel_Telu", + "Tajik": "tgk_Cyrl", + "Tagalog": "tgl_Latn", + "Thai": "tha_Thai", + "Tigrinya": "tir_Ethi", + "Tamasheq (Latin script)": "taq_Latn", + "Tamasheq (Tifinagh script)": "taq_Tfng", + "Tok Pisin": "tpi_Latn", + "Tswana": "tsn_Latn", + "Tsonga": "tso_Latn", + "Turkmen": "tuk_Latn", + "Tumbuka": "tum_Latn", + "Turkish": "tur_Latn", + 
"Twi": "twi_Latn", + "Central Atlas Tamazight": "tzm_Tfng", + "Uyghur": "uig_Arab", + "Ukrainian": "ukr_Cyrl", + "Umbundu": "umb_Latn", + "Urdu": "urd_Arab", + "Northern Uzbek": "uzn_Latn", + "Venetian": "vec_Latn", + "Vietnamese": "vie_Latn", + "Waray": "war_Latn", + "Wolof": "wol_Latn", + "Xhosa": "xho_Latn", + "Eastern Yiddish": "ydd_Hebr", + "Yoruba": "yor_Latn", + "Yue Chinese": "yue_Hant", + "Chinese Simplified": "zho_Hans", + "Chinese Traditional": "zho_Hant", + "Standard Malay": "zsm_Latn", + "Zulu": "zul_Latn" +} + +translation_lang["nllb-200-distilled-1.3B-ct2-int8"] = { + "source":dict_nllb_languages, + "target":dict_nllb_languages, +} + +translation_lang["nllb-200-3.3B-ct2-int8"] = { + "source":dict_nllb_languages, + "target":dict_nllb_languages, } \ No newline at end of file diff --git a/src-python/models/translation/translation_translator.py b/src-python/models/translation/translation_translator.py index 42eb828e..23a82d04 100644 --- a/src-python/models/translation/translation_translator.py +++ b/src-python/models/translation/translation_translator.py @@ -6,8 +6,15 @@ try: except Exception: ENABLE_TRANSLATORS = False -from .translation_languages import translation_lang -from .translation_utils import ctranslate2_weights +try: + from .translation_languages import translation_lang + from .translation_utils import ctranslate2_weights +except Exception: + import sys + print(os_path.dirname(os_path.dirname(os_path.dirname(os_path.abspath(__file__))))) + sys.path.append(os_path.dirname(os_path.dirname(os_path.dirname(os_path.abspath(__file__))))) + from translation_languages import translation_lang + from translation_utils import ctranslate2_weights import ctranslate2 import transformers @@ -63,13 +70,19 @@ class Translator(): def isLoadedCTranslate2Model(self): return self.is_loaded_ctranslate2_model - def translateCTranslate2(self, message, source_language, target_language): + def translateCTranslate2(self, message, source_language, target_language, 
weight_type): result = False if self.is_loaded_ctranslate2_model is True: try: self.ctranslate2_tokenizer.src_lang = source_language source = self.ctranslate2_tokenizer.convert_ids_to_tokens(self.ctranslate2_tokenizer.encode(message)) - target_prefix = [self.ctranslate2_tokenizer.lang_code_to_token[target_language]] + match weight_type: + case "m2m100_418M-ct2-int8" | "m2m100_1.2B-ct2-int8": + target_prefix = [self.ctranslate2_tokenizer.lang_code_to_token[target_language]] + case "nllb-200-distilled-1.3B-ct2-int8" | "nllb-200-3.3B-ct2-int8": + target_prefix = [target_language] + case _: + return False results = self.ctranslate2_translator.translate_batch([source], target_prefix=[target_prefix]) target = results[0].hypotheses[0][1:] result = self.ctranslate2_tokenizer.decode(self.ctranslate2_tokenizer.convert_tokens_to_ids(target)) @@ -78,7 +91,7 @@ class Translator(): return result @staticmethod - def getLanguageCode(translator_name, target_country, source_language, target_language): + def getLanguageCode(translator_name, weight_type, target_country, source_language, target_language): match translator_name: case "DeepL_API": if target_language == "English": @@ -91,16 +104,18 @@ class Translator(): target_language = "Portuguese European" else: target_language = "Portuguese Brazilian" + case "CTranslate2": + translator_name = weight_type case _: pass source_language=translation_lang[translator_name]["source"][source_language] target_language=translation_lang[translator_name]["target"][target_language] return source_language, target_language - def translate(self, translator_name, source_language, target_language, target_country, message): + def translate(self, translator_name, weight_type, source_language, target_language, target_country, message): try: result = "" - source_language, target_language = self.getLanguageCode(translator_name, target_country, source_language, target_language) + source_language, target_language = self.getLanguageCode(translator_name, 
weight_type, target_country, source_language, target_language) match translator_name: case "DeepL": if self.is_enable_translators is True: @@ -149,8 +164,23 @@ class Translator(): message=message, source_language=source_language, target_language=target_language, + weight_type=weight_type, ) except Exception: errorLogging() result = False - return result \ No newline at end of file + return result + +if __name__ == "__main__": + translator = Translator() + # test CTranslate2 model nllb-200-distilled-1.3B-ct2-int8 + translator.changeCTranslate2Model(path=".", model_type="nllb-200-distilled-1.3B-ct2-int8", device="cpu", device_index=0) + result = translator.translate( + translator_name="CTranslate2", + weight_type="nllb-200-distilled-1.3B-ct2-int8", + source_language="English", + target_language="Japanese", + target_country="Japan", + message="Hello, world!" + ) + print(result) \ No newline at end of file diff --git a/src-python/models/translation/translation_utils.py b/src-python/models/translation/translation_utils.py index ef0e5592..1bc8fed6 100644 --- a/src-python/models/translation/translation_utils.py +++ b/src-python/models/translation/translation_utils.py @@ -5,50 +5,50 @@ import transformers import ctranslate2 from huggingface_hub import hf_hub_url, list_repo_files from requests import get as requests_get + try: - from utils import errorLogging + from utils import errorLogging, getBestComputeType except Exception: - import traceback - def errorLogging(): - print(traceback.format_exc()) + import sys + print(os_path.dirname(os_path.dirname(os_path.dirname(os_path.abspath(__file__))))) + sys.path.append(os_path.dirname(os_path.dirname(os_path.dirname(os_path.abspath(__file__))))) + from utils import errorLogging, getBestComputeType ctranslate2_weights = { - "small": { - # "hf_repo": "jncraton/m2m100_418M-ct2-int8", - # "directory_name": "m2m100_418M-ct2-int8", - # "tokenizer": "facebook/m2m100_418M", - "hf_repo": "OpenNMT/nllb-200-distilled-1.3B-ct2-int8", - 
"directory_name": "nllb-200-distilled-1.3B-ct2-int8", - "tokenizer": "facebook/nllb-200-distilled-1.3B", + "m2m100_418M-ct2-int8": { + "hf_repo": "jncraton/m2m100_418M-ct2-int8", + "directory_name": "m2m100_418M-ct2-int8", + "tokenizer": "facebook/m2m100_418M", }, - "large": { + "m2m100_1.2B-ct2-int8": { "hf_repo": "jncraton/m2m100_1.2B-ct2-int8", "directory_name": "m2m100_1.2B-ct2-int8", "tokenizer": "facebook/m2m100_1.2B", }, + "nllb-200-distilled-1.3B-ct2-int8": { + "hf_repo": "OpenNMT/nllb-200-distilled-1.3B-ct2-int8", + "directory_name": "nllb-200-distilled-1.3B-ct2-int8", + "tokenizer": "facebook/nllb-200-distilled-1.3B", + }, "nllb-200-3.3B-ct2-int8": { "hf_repo": "OpenNMT/nllb-200-3.3B-ct2-int8", "directory_name": "nllb-200-3.3B-ct2-int8", "tokenizer": "facebook/nllb-200-3.3B", }, - "nllb-200-distilled-1.3B": { - "hf_repo": "OpenNMT/nllb-200-distilled-1.3B-ct2-int8", - "directory_name": "nllb-200-distilled-1.3B-ct2-int8", - "tokenizer": "facebook/nllb-200-distilled-1.3B", - }, } -def checkCTranslate2Weight(root: str, weight_type: str = "small"): +def checkCTranslate2Weight(root: str, weight_type: str = "m2m100_418M-ct2-int8"): weight_directory_name = ctranslate2_weights[weight_type]["directory_name"] path = os_path.join(root, "weights", "ctranslate2", weight_directory_name) try: # モデルロード可能かどうかで判定 - ctranslate2.Translator(path) + compute_type = getBestComputeType("cpu", 0) + ctranslate2.Translator(path, compute_type=compute_type) return True except Exception: return False -def downloadCTranslate2Weight(root: str, weight_type: str = "small", callback: Callable = None, end_callback: Callable = None): +def downloadCTranslate2Weight(root: str, weight_type: str = "m2m100_418M-ct2-int8", callback: Callable = None, end_callback: Callable = None): hf_repo = ctranslate2_weights[weight_type]["hf_repo"] files = list_repo_files(repo_id=hf_repo) path = os_path.join(root, "weights", "ctranslate2", ctranslate2_weights[weight_type]["directory_name"]) @@ -79,7 +79,7 @@ def 
downloadCTranslate2Weight(root: str, weight_type: str = "small", callback: C if end_callback is not None: end_callback() -def downloadCTranslate2Tokenizer(path: str, weight_type: str = "small"): +def downloadCTranslate2Tokenizer(path: str, weight_type: str = "m2m100_418M-ct2-int8"): directory_name = ctranslate2_weights[weight_type]["directory_name"] tokenizer = ctranslate2_weights[weight_type]["tokenizer"] tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer") @@ -106,4 +106,9 @@ if __name__ == "__main__": # result = checkCTranslate2Weight(root, weight_type) # print(f"Model loadable: {result}") # break - downloadCTranslate2Tokenizer(root, "small") + # downloadCTranslate2Tokenizer(root, "m2m100_418M-ct2-int8") + + # model download test (weight_type must be an existing key of ctranslate2_weights) + downloadCTranslate2Weight(root, "nllb-200-distilled-1.3B-ct2-int8", callback=progress_callback, end_callback=end_callback) + result = checkCTranslate2Weight(root, "nllb-200-distilled-1.3B-ct2-int8") + print(f"Model loadable: {result}") \ No newline at end of file