[Update] translation: Add support for new translation models and improve weight handling
This commit is contained in:
@@ -1188,7 +1188,7 @@ class Config:
|
||||
self._USE_EXCLUDE_WORDS = True
|
||||
self._SELECTED_TRANSLATION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0])
|
||||
self._SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0])
|
||||
self._CTRANSLATE2_WEIGHT_TYPE = "small"
|
||||
self._CTRANSLATE2_WEIGHT_TYPE = "m2m100_418M-ct2-int8"
|
||||
self._WHISPER_WEIGHT_TYPE = "base"
|
||||
self._AUTO_CLEAR_MESSAGE_BOX = True
|
||||
self._SEND_ONLY_TRANSLATED_MESSAGES = False
|
||||
|
||||
@@ -190,6 +190,7 @@ class Model:
|
||||
success_flag = False
|
||||
translation = self.translator.translate(
|
||||
translator_name=translator_name,
|
||||
weight_type=config.CTRANSLATE2_WEIGHT_TYPE,
|
||||
source_language=source_language,
|
||||
target_language=target_language,
|
||||
target_country=target_country,
|
||||
@@ -203,6 +204,7 @@ class Model:
|
||||
while True:
|
||||
translation = self.translator.translate(
|
||||
translator_name="CTranslate2",
|
||||
weight_type=config.CTRANSLATE2_WEIGHT_TYPE,
|
||||
source_language=source_language,
|
||||
target_language=target_language,
|
||||
target_country=target_country,
|
||||
|
||||
@@ -275,7 +275,7 @@ translation_lang["Papago"] = {
|
||||
"target":dict_papago_languages,
|
||||
}
|
||||
|
||||
dict_ctranslate2_languages = {
|
||||
dict_m2m100_languages = {
|
||||
"English": "en",
|
||||
"Chinese Simplified": "zh",
|
||||
"Chinese Traditional":"zh",
|
||||
@@ -378,7 +378,229 @@ dict_ctranslate2_languages = {
|
||||
"Sundanese": "su"
|
||||
}
|
||||
|
||||
translation_lang["CTranslate2"] = {
|
||||
"source":dict_ctranslate2_languages,
|
||||
"target":dict_ctranslate2_languages,
|
||||
translation_lang["m2m100_418M-ct2-int8"] = {
|
||||
"source":dict_m2m100_languages,
|
||||
"target":dict_m2m100_languages,
|
||||
}
|
||||
|
||||
translation_lang["m2m100_1.2B-ct2-int8"] = {
|
||||
"source":dict_m2m100_languages,
|
||||
"target":dict_m2m100_languages,
|
||||
}
|
||||
|
||||
dict_nllb_languages = {
|
||||
"Acehnese (Arabic script)": "ace_Arab",
|
||||
"Acehnese (Latin script)": "ace_Latn",
|
||||
"Mesopotamian Arabic": "acm_Arab",
|
||||
"Ta’izzi-Adeni Arabic": "acq_Arab",
|
||||
"Tunisian Arabic": "aeb_Arab",
|
||||
"Afrikaans": "afr_Latn",
|
||||
"South Levantine Arabic": "ajp_Arab",
|
||||
"Akan": "aka_Latn",
|
||||
"Amharic": "amh_Ethi",
|
||||
"North Levantine Arabic": "apc_Arab",
|
||||
"Modern Standard Arabic": "arb_Arab",
|
||||
"Modern Standard Arabic (Romanized)": "arb_Latn",
|
||||
"Najdi Arabic": "ars_Arab",
|
||||
"Moroccan Arabic": "ary_Arab",
|
||||
"Egyptian Arabic": "arz_Arab",
|
||||
"Assamese": "asm_Beng",
|
||||
"Asturian": "ast_Latn",
|
||||
"Awadhi": "awa_Deva",
|
||||
"Central Aymara": "ayr_Latn",
|
||||
"South Azerbaijani": "azb_Arab",
|
||||
"North Azerbaijani": "azj_Latn",
|
||||
"Bashkir": "bak_Cyrl",
|
||||
"Bambara": "bam_Latn",
|
||||
"Balinese": "ban_Latn",
|
||||
"Belarusian": "bel_Cyrl",
|
||||
"Bemba": "bem_Latn",
|
||||
"Bengali": "ben_Beng",
|
||||
"Bhojpuri": "bho_Deva",
|
||||
"Banjar (Arabic script)": "bjn_Arab",
|
||||
"Banjar (Latin script)": "bjn_Latn",
|
||||
"Standard Tibetan": "bod_Tibt",
|
||||
"Bosnian": "bos_Latn",
|
||||
"Buginese": "bug_Latn",
|
||||
"Bulgarian": "bul_Cyrl",
|
||||
"Catalan": "cat_Latn",
|
||||
"Cebuano": "ceb_Latn",
|
||||
"Czech": "ces_Latn",
|
||||
"Chokwe": "cjk_Latn",
|
||||
"Central Kurdish": "ckb_Arab",
|
||||
"Crimean Tatar": "crh_Latn",
|
||||
"Welsh": "cym_Latn",
|
||||
"Danish": "dan_Latn",
|
||||
"German": "deu_Latn",
|
||||
"Southwestern Dinka": "dik_Latn",
|
||||
"Dyula": "dyu_Latn",
|
||||
"Dzongkha": "dzo_Tibt",
|
||||
"Greek": "ell_Grek",
|
||||
"English": "eng_Latn",
|
||||
"Esperanto": "epo_Latn",
|
||||
"Estonian": "est_Latn",
|
||||
"Basque": "eus_Latn",
|
||||
"Ewe": "ewe_Latn",
|
||||
"Faroese": "fao_Latn",
|
||||
"Fijian": "fij_Latn",
|
||||
"Finnish": "fin_Latn",
|
||||
"Fon": "fon_Latn",
|
||||
"French": "fra_Latn",
|
||||
"Friulian": "fur_Latn",
|
||||
"Nigerian Fulfulde": "fuv_Latn",
|
||||
"Scottish Gaelic": "gla_Latn",
|
||||
"Irish": "gle_Latn",
|
||||
"Galician": "glg_Latn",
|
||||
"Guarani": "grn_Latn",
|
||||
"Gujarati": "guj_Gujr",
|
||||
"Haitian Creole": "hat_Latn",
|
||||
"Hausa": "hau_Latn",
|
||||
"Hebrew": "heb_Hebr",
|
||||
"Hindi": "hin_Deva",
|
||||
"Chhattisgarhi": "hne_Deva",
|
||||
"Croatian": "hrv_Latn",
|
||||
"Hungarian": "hun_Latn",
|
||||
"Armenian": "hye_Armn",
|
||||
"Igbo": "ibo_Latn",
|
||||
"Ilocano": "ilo_Latn",
|
||||
"Indonesian": "ind_Latn",
|
||||
"Icelandic": "isl_Latn",
|
||||
"Italian": "ita_Latn",
|
||||
"Javanese": "jav_Latn",
|
||||
"Japanese": "jpn_Jpan",
|
||||
"Kabyle": "kab_Latn",
|
||||
"Jingpho": "kac_Latn",
|
||||
"Kamba": "kam_Latn",
|
||||
"Kannada": "kan_Knda",
|
||||
"Kashmiri (Arabic script)": "kas_Arab",
|
||||
"Kashmiri (Devanagari script)": "kas_Deva",
|
||||
"Georgian": "kat_Geor",
|
||||
"Central Kanuri (Arabic script)": "knc_Arab",
|
||||
"Central Kanuri (Latin script)": "knc_Latn",
|
||||
"Kazakh": "kaz_Cyrl",
|
||||
"Kabiyè": "kbp_Latn",
|
||||
"Kabuverdianu": "kea_Latn",
|
||||
"Khmer": "khm_Khmr",
|
||||
"Kikuyu": "kik_Latn",
|
||||
"Kinyarwanda": "kin_Latn",
|
||||
"Kyrgyz": "kir_Cyrl",
|
||||
"Kimbundu": "kmb_Latn",
|
||||
"Northern Kurdish": "kmr_Latn",
|
||||
"Kikongo": "kon_Latn",
|
||||
"Korean": "kor_Hang",
|
||||
"Lao": "lao_Laoo",
|
||||
"Ligurian": "lij_Latn",
|
||||
"Limburgish": "lim_Latn",
|
||||
"Lingala": "lin_Latn",
|
||||
"Lithuanian": "lit_Latn",
|
||||
"Lombard": "lmo_Latn",
|
||||
"Latgalian": "ltg_Latn",
|
||||
"Luxembourgish": "ltz_Latn",
|
||||
"Luba-Kasai": "lua_Latn",
|
||||
"Ganda": "lug_Latn",
|
||||
"Luo": "luo_Latn",
|
||||
"Mizo": "lus_Latn",
|
||||
"Standard Latvian": "lvs_Latn",
|
||||
"Magahi": "mag_Deva",
|
||||
"Maithili": "mai_Deva",
|
||||
"Malayalam": "mal_Mlym",
|
||||
"Marathi": "mar_Deva",
|
||||
"Minangkabau (Arabic script)": "min_Arab",
|
||||
"Minangkabau (Latin script)": "min_Latn",
|
||||
"Macedonian": "mkd_Cyrl",
|
||||
"Plateau Malagasy": "plt_Latn",
|
||||
"Maltese": "mlt_Latn",
|
||||
"Meitei (Bengali script)": "mni_Beng",
|
||||
"Halh Mongolian": "khk_Cyrl",
|
||||
"Mossi": "mos_Latn",
|
||||
"Maori": "mri_Latn",
|
||||
"Burmese": "mya_Mymr",
|
||||
"Dutch": "nld_Latn",
|
||||
"Norwegian Nynorsk": "nno_Latn",
|
||||
"Norwegian Bokmål": "nob_Latn",
|
||||
"Nepali": "npi_Deva",
|
||||
"Northern Sotho": "nso_Latn",
|
||||
"Nuer": "nus_Latn",
|
||||
"Nyanja": "nya_Latn",
|
||||
"Occitan": "oci_Latn",
|
||||
"West Central Oromo": "gaz_Latn",
|
||||
"Odia": "ory_Orya",
|
||||
"Pangasinan": "pag_Latn",
|
||||
"Eastern Panjabi": "pan_Guru",
|
||||
"Papiamento": "pap_Latn",
|
||||
"Western Persian": "pes_Arab",
|
||||
"Polish": "pol_Latn",
|
||||
"Portuguese": "por_Latn",
|
||||
"Dari": "prs_Arab",
|
||||
"Southern Pashto": "pbt_Arab",
|
||||
"Ayacucho Quechua": "quy_Latn",
|
||||
"Romanian": "ron_Latn",
|
||||
"Rundi": "run_Latn",
|
||||
"Russian": "rus_Cyrl",
|
||||
"Sango": "sag_Latn",
|
||||
"Sanskrit": "san_Deva",
|
||||
"Santali": "sat_Olck",
|
||||
"Sicilian": "scn_Latn",
|
||||
"Shan": "shn_Mymr",
|
||||
"Sinhala": "sin_Sinh",
|
||||
"Slovak": "slk_Latn",
|
||||
"Slovenian": "slv_Latn",
|
||||
"Samoan": "smo_Latn",
|
||||
"Shona": "sna_Latn",
|
||||
"Sindhi": "snd_Arab",
|
||||
"Somali": "som_Latn",
|
||||
"Southern Sotho": "sot_Latn",
|
||||
"Spanish": "spa_Latn",
|
||||
"Tosk Albanian": "als_Latn",
|
||||
"Sardinian": "srd_Latn",
|
||||
"Serbian": "srp_Cyrl",
|
||||
"Swati": "ssw_Latn",
|
||||
"Sundanese": "sun_Latn",
|
||||
"Swedish": "swe_Latn",
|
||||
"Swahili": "swh_Latn",
|
||||
"Silesian": "szl_Latn",
|
||||
"Tamil": "tam_Taml",
|
||||
"Tatar": "tat_Cyrl",
|
||||
"Telugu": "tel_Telu",
|
||||
"Tajik": "tgk_Cyrl",
|
||||
"Tagalog": "tgl_Latn",
|
||||
"Thai": "tha_Thai",
|
||||
"Tigrinya": "tir_Ethi",
|
||||
"Tamasheq (Latin script)": "taq_Latn",
|
||||
"Tamasheq (Tifinagh script)": "taq_Tfng",
|
||||
"Tok Pisin": "tpi_Latn",
|
||||
"Tswana": "tsn_Latn",
|
||||
"Tsonga": "tso_Latn",
|
||||
"Turkmen": "tuk_Latn",
|
||||
"Tumbuka": "tum_Latn",
|
||||
"Turkish": "tur_Latn",
|
||||
"Twi": "twi_Latn",
|
||||
"Central Atlas Tamazight": "tzm_Tfng",
|
||||
"Uyghur": "uig_Arab",
|
||||
"Ukrainian": "ukr_Cyrl",
|
||||
"Umbundu": "umb_Latn",
|
||||
"Urdu": "urd_Arab",
|
||||
"Northern Uzbek": "uzn_Latn",
|
||||
"Venetian": "vec_Latn",
|
||||
"Vietnamese": "vie_Latn",
|
||||
"Waray": "war_Latn",
|
||||
"Wolof": "wol_Latn",
|
||||
"Xhosa": "xho_Latn",
|
||||
"Eastern Yiddish": "ydd_Hebr",
|
||||
"Yoruba": "yor_Latn",
|
||||
"Yue Chinese": "yue_Hant",
|
||||
"Chinese Simplified": "zho_Hans",
|
||||
"Chinese Traditional": "zho_Hant",
|
||||
"Standard Malay": "zsm_Latn",
|
||||
"Zulu": "zul_Latn"
|
||||
}
|
||||
|
||||
translation_lang["nllb-200-distilled-1.3B-ct2-int8"] = {
|
||||
"source":dict_nllb_languages,
|
||||
"target":dict_nllb_languages,
|
||||
}
|
||||
|
||||
translation_lang["nllb-200-3.3B-ct2-int8"] = {
|
||||
"source":dict_nllb_languages,
|
||||
"target":dict_nllb_languages,
|
||||
}
|
||||
@@ -6,8 +6,15 @@ try:
|
||||
except Exception:
|
||||
ENABLE_TRANSLATORS = False
|
||||
|
||||
from .translation_languages import translation_lang
|
||||
from .translation_utils import ctranslate2_weights
|
||||
try:
|
||||
from .translation_languages import translation_lang
|
||||
from .translation_utils import ctranslate2_weights
|
||||
except Exception:
|
||||
import sys
|
||||
print(os_path.dirname(os_path.dirname(os_path.dirname(os_path.abspath(__file__)))))
|
||||
sys.path.append(os_path.dirname(os_path.dirname(os_path.dirname(os_path.abspath(__file__)))))
|
||||
from translation_languages import translation_lang
|
||||
from translation_utils import ctranslate2_weights
|
||||
|
||||
import ctranslate2
|
||||
import transformers
|
||||
@@ -63,13 +70,19 @@ class Translator():
|
||||
def isLoadedCTranslate2Model(self):
|
||||
return self.is_loaded_ctranslate2_model
|
||||
|
||||
def translateCTranslate2(self, message, source_language, target_language):
|
||||
def translateCTranslate2(self, message, source_language, target_language, weight_type):
|
||||
result = False
|
||||
if self.is_loaded_ctranslate2_model is True:
|
||||
try:
|
||||
self.ctranslate2_tokenizer.src_lang = source_language
|
||||
source = self.ctranslate2_tokenizer.convert_ids_to_tokens(self.ctranslate2_tokenizer.encode(message))
|
||||
target_prefix = [self.ctranslate2_tokenizer.lang_code_to_token[target_language]]
|
||||
match weight_type:
|
||||
case "m2m100_418M-ct2-int8" | "m2m100_1.2B-ct2-int8":
|
||||
target_prefix = [self.ctranslate2_tokenizer.lang_code_to_token[target_language]]
|
||||
case "nllb-200-distilled-1.3B-ct2-int8" | "nllb-200-3.3B-ct2-int8":
|
||||
target_prefix = [target_language]
|
||||
case _:
|
||||
return False
|
||||
results = self.ctranslate2_translator.translate_batch([source], target_prefix=[target_prefix])
|
||||
target = results[0].hypotheses[0][1:]
|
||||
result = self.ctranslate2_tokenizer.decode(self.ctranslate2_tokenizer.convert_tokens_to_ids(target))
|
||||
@@ -78,7 +91,7 @@ class Translator():
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def getLanguageCode(translator_name, target_country, source_language, target_language):
|
||||
def getLanguageCode(translator_name, weight_type, target_country, source_language, target_language):
|
||||
match translator_name:
|
||||
case "DeepL_API":
|
||||
if target_language == "English":
|
||||
@@ -91,16 +104,18 @@ class Translator():
|
||||
target_language = "Portuguese European"
|
||||
else:
|
||||
target_language = "Portuguese Brazilian"
|
||||
case "CTranslate2":
|
||||
translator_name = weight_type
|
||||
case _:
|
||||
pass
|
||||
source_language=translation_lang[translator_name]["source"][source_language]
|
||||
target_language=translation_lang[translator_name]["target"][target_language]
|
||||
return source_language, target_language
|
||||
|
||||
def translate(self, translator_name, source_language, target_language, target_country, message):
|
||||
def translate(self, translator_name, weight_type, source_language, target_language, target_country, message):
|
||||
try:
|
||||
result = ""
|
||||
source_language, target_language = self.getLanguageCode(translator_name, target_country, source_language, target_language)
|
||||
source_language, target_language = self.getLanguageCode(translator_name, weight_type, target_country, source_language, target_language)
|
||||
match translator_name:
|
||||
case "DeepL":
|
||||
if self.is_enable_translators is True:
|
||||
@@ -149,8 +164,23 @@ class Translator():
|
||||
message=message,
|
||||
source_language=source_language,
|
||||
target_language=target_language,
|
||||
weight_type=weight_type,
|
||||
)
|
||||
except Exception:
|
||||
errorLogging()
|
||||
result = False
|
||||
return result
|
||||
return result
|
||||
|
||||
if __name__ == "__main__":
|
||||
translator = Translator()
|
||||
# test CTranslate2 model nllb-200-distilled-1.3B-ct2-int8
|
||||
translator.changeCTranslate2Model(path=".", model_type="nllb-200-distilled-1.3B-ct2-int8", device="cpu", device_index=0)
|
||||
result = translator.translate(
|
||||
translator_name="CTranslate2",
|
||||
weight_type="nllb-200-distilled-1.3B-ct2-int8",
|
||||
source_language="English",
|
||||
target_language="Japanese",
|
||||
target_country="Japan",
|
||||
message="Hello, world!"
|
||||
)
|
||||
print(result)
|
||||
@@ -5,50 +5,50 @@ import transformers
|
||||
import ctranslate2
|
||||
from huggingface_hub import hf_hub_url, list_repo_files
|
||||
from requests import get as requests_get
|
||||
|
||||
try:
|
||||
from utils import errorLogging
|
||||
from utils import errorLogging, getBestComputeType
|
||||
except Exception:
|
||||
import traceback
|
||||
def errorLogging():
|
||||
print(traceback.format_exc())
|
||||
import sys
|
||||
print(os_path.dirname(os_path.dirname(os_path.dirname(os_path.abspath(__file__)))))
|
||||
sys.path.append(os_path.dirname(os_path.dirname(os_path.dirname(os_path.abspath(__file__)))))
|
||||
from utils import errorLogging, getBestComputeType
|
||||
|
||||
ctranslate2_weights = {
|
||||
"small": {
|
||||
# "hf_repo": "jncraton/m2m100_418M-ct2-int8",
|
||||
# "directory_name": "m2m100_418M-ct2-int8",
|
||||
# "tokenizer": "facebook/m2m100_418M",
|
||||
"hf_repo": "OpenNMT/nllb-200-distilled-1.3B-ct2-int8",
|
||||
"directory_name": "nllb-200-distilled-1.3B-ct2-int8",
|
||||
"tokenizer": "facebook/nllb-200-distilled-1.3B",
|
||||
"m2m100_418M-ct2-int8": {
|
||||
"hf_repo": "jncraton/m2m100_418M-ct2-int8",
|
||||
"directory_name": "m2m100_418M-ct2-int8",
|
||||
"tokenizer": "facebook/m2m100_418M",
|
||||
},
|
||||
"large": {
|
||||
"m2m100_1.2B-ct2-int8": {
|
||||
"hf_repo": "jncraton/m2m100_1.2B-ct2-int8",
|
||||
"directory_name": "m2m100_1.2B-ct2-int8",
|
||||
"tokenizer": "facebook/m2m100_1.2B",
|
||||
},
|
||||
"nllb-200-distilled-1.3B-ct2-int8": {
|
||||
"hf_repo": "OpenNMT/nllb-200-distilled-1.3B-ct2-int8",
|
||||
"directory_name": "nllb-200-distilled-1.3B-ct2-int8",
|
||||
"tokenizer": "facebook/nllb-200-distilled-1.3B",
|
||||
},
|
||||
"nllb-200-3.3B-ct2-int8": {
|
||||
"hf_repo": "OpenNMT/nllb-200-3.3B-ct2-int8",
|
||||
"directory_name": "nllb-200-3.3B-ct2-int8",
|
||||
"tokenizer": "facebook/nllb-200-3.3B",
|
||||
},
|
||||
"nllb-200-distilled-1.3B": {
|
||||
"hf_repo": "OpenNMT/nllb-200-distilled-1.3B-ct2-int8",
|
||||
"directory_name": "nllb-200-distilled-1.3B-ct2-int8",
|
||||
"tokenizer": "facebook/nllb-200-distilled-1.3B",
|
||||
},
|
||||
}
|
||||
|
||||
def checkCTranslate2Weight(root: str, weight_type: str = "small"):
|
||||
def checkCTranslate2Weight(root: str, weight_type: str = "m2m100_418M-ct2-int8"):
|
||||
weight_directory_name = ctranslate2_weights[weight_type]["directory_name"]
|
||||
path = os_path.join(root, "weights", "ctranslate2", weight_directory_name)
|
||||
try:
|
||||
# モデルロード可能かどうかで判定
|
||||
ctranslate2.Translator(path)
|
||||
compute_type = getBestComputeType("cpu", 0)
|
||||
ctranslate2.Translator(path, compute_type=compute_type)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def downloadCTranslate2Weight(root: str, weight_type: str = "small", callback: Callable = None, end_callback: Callable = None):
|
||||
def downloadCTranslate2Weight(root: str, weight_type: str = "m2m100_418M-ct2-int8", callback: Callable = None, end_callback: Callable = None):
|
||||
hf_repo = ctranslate2_weights[weight_type]["hf_repo"]
|
||||
files = list_repo_files(repo_id=hf_repo)
|
||||
path = os_path.join(root, "weights", "ctranslate2", ctranslate2_weights[weight_type]["directory_name"])
|
||||
@@ -79,7 +79,7 @@ def downloadCTranslate2Weight(root: str, weight_type: str = "small", callback: C
|
||||
if end_callback is not None:
|
||||
end_callback()
|
||||
|
||||
def downloadCTranslate2Tokenizer(path: str, weight_type: str = "small"):
|
||||
def downloadCTranslate2Tokenizer(path: str, weight_type: str = "m2m100_418M-ct2-int8"):
|
||||
directory_name = ctranslate2_weights[weight_type]["directory_name"]
|
||||
tokenizer = ctranslate2_weights[weight_type]["tokenizer"]
|
||||
tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer")
|
||||
@@ -106,4 +106,9 @@ if __name__ == "__main__":
|
||||
# result = checkCTranslate2Weight(root, weight_type)
|
||||
# print(f"Model loadable: {result}")
|
||||
# break
|
||||
downloadCTranslate2Tokenizer(root, "small")
|
||||
# downloadCTranslate2Tokenizer(root, "m2m100_418M-ct2-int8")
|
||||
|
||||
# model download test
|
||||
downloadCTranslate2Weight(root, "nllb-200-distilled-1.3B", callback=progress_callback, end_callback=end_callback)
|
||||
result = checkCTranslate2Weight(root, "nllb-200-distilled-1.3B")
|
||||
print(f"Model loadable: {result}")
|
||||
Reference in New Issue
Block a user