[Update] translation: Add support for new translation models and improve weight handling
This commit is contained in:
@@ -1188,7 +1188,7 @@ class Config:
|
|||||||
self._USE_EXCLUDE_WORDS = True
|
self._USE_EXCLUDE_WORDS = True
|
||||||
self._SELECTED_TRANSLATION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0])
|
self._SELECTED_TRANSLATION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0])
|
||||||
self._SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0])
|
self._SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0])
|
||||||
self._CTRANSLATE2_WEIGHT_TYPE = "small"
|
self._CTRANSLATE2_WEIGHT_TYPE = "m2m100_418M-ct2-int8"
|
||||||
self._WHISPER_WEIGHT_TYPE = "base"
|
self._WHISPER_WEIGHT_TYPE = "base"
|
||||||
self._AUTO_CLEAR_MESSAGE_BOX = True
|
self._AUTO_CLEAR_MESSAGE_BOX = True
|
||||||
self._SEND_ONLY_TRANSLATED_MESSAGES = False
|
self._SEND_ONLY_TRANSLATED_MESSAGES = False
|
||||||
|
|||||||
@@ -190,6 +190,7 @@ class Model:
|
|||||||
success_flag = False
|
success_flag = False
|
||||||
translation = self.translator.translate(
|
translation = self.translator.translate(
|
||||||
translator_name=translator_name,
|
translator_name=translator_name,
|
||||||
|
weight_type=config.CTRANSLATE2_WEIGHT_TYPE,
|
||||||
source_language=source_language,
|
source_language=source_language,
|
||||||
target_language=target_language,
|
target_language=target_language,
|
||||||
target_country=target_country,
|
target_country=target_country,
|
||||||
@@ -203,6 +204,7 @@ class Model:
|
|||||||
while True:
|
while True:
|
||||||
translation = self.translator.translate(
|
translation = self.translator.translate(
|
||||||
translator_name="CTranslate2",
|
translator_name="CTranslate2",
|
||||||
|
weight_type=config.CTRANSLATE2_WEIGHT_TYPE,
|
||||||
source_language=source_language,
|
source_language=source_language,
|
||||||
target_language=target_language,
|
target_language=target_language,
|
||||||
target_country=target_country,
|
target_country=target_country,
|
||||||
|
|||||||
@@ -275,7 +275,7 @@ translation_lang["Papago"] = {
|
|||||||
"target":dict_papago_languages,
|
"target":dict_papago_languages,
|
||||||
}
|
}
|
||||||
|
|
||||||
dict_ctranslate2_languages = {
|
dict_m2m100_languages = {
|
||||||
"English": "en",
|
"English": "en",
|
||||||
"Chinese Simplified": "zh",
|
"Chinese Simplified": "zh",
|
||||||
"Chinese Traditional":"zh",
|
"Chinese Traditional":"zh",
|
||||||
@@ -378,7 +378,229 @@ dict_ctranslate2_languages = {
|
|||||||
"Sundanese": "su"
|
"Sundanese": "su"
|
||||||
}
|
}
|
||||||
|
|
||||||
translation_lang["CTranslate2"] = {
|
translation_lang["m2m100_418M-ct2-int8"] = {
|
||||||
"source":dict_ctranslate2_languages,
|
"source":dict_m2m100_languages,
|
||||||
"target":dict_ctranslate2_languages,
|
"target":dict_m2m100_languages,
|
||||||
|
}
|
||||||
|
|
||||||
|
translation_lang["m2m100_1.2B-ct2-int8"] = {
|
||||||
|
"source":dict_m2m100_languages,
|
||||||
|
"target":dict_m2m100_languages,
|
||||||
|
}
|
||||||
|
|
||||||
|
dict_nllb_languages = {
|
||||||
|
"Acehnese (Arabic script)": "ace_Arab",
|
||||||
|
"Acehnese (Latin script)": "ace_Latn",
|
||||||
|
"Mesopotamian Arabic": "acm_Arab",
|
||||||
|
"Ta’izzi-Adeni Arabic": "acq_Arab",
|
||||||
|
"Tunisian Arabic": "aeb_Arab",
|
||||||
|
"Afrikaans": "afr_Latn",
|
||||||
|
"South Levantine Arabic": "ajp_Arab",
|
||||||
|
"Akan": "aka_Latn",
|
||||||
|
"Amharic": "amh_Ethi",
|
||||||
|
"North Levantine Arabic": "apc_Arab",
|
||||||
|
"Modern Standard Arabic": "arb_Arab",
|
||||||
|
"Modern Standard Arabic (Romanized)": "arb_Latn",
|
||||||
|
"Najdi Arabic": "ars_Arab",
|
||||||
|
"Moroccan Arabic": "ary_Arab",
|
||||||
|
"Egyptian Arabic": "arz_Arab",
|
||||||
|
"Assamese": "asm_Beng",
|
||||||
|
"Asturian": "ast_Latn",
|
||||||
|
"Awadhi": "awa_Deva",
|
||||||
|
"Central Aymara": "ayr_Latn",
|
||||||
|
"South Azerbaijani": "azb_Arab",
|
||||||
|
"North Azerbaijani": "azj_Latn",
|
||||||
|
"Bashkir": "bak_Cyrl",
|
||||||
|
"Bambara": "bam_Latn",
|
||||||
|
"Balinese": "ban_Latn",
|
||||||
|
"Belarusian": "bel_Cyrl",
|
||||||
|
"Bemba": "bem_Latn",
|
||||||
|
"Bengali": "ben_Beng",
|
||||||
|
"Bhojpuri": "bho_Deva",
|
||||||
|
"Banjar (Arabic script)": "bjn_Arab",
|
||||||
|
"Banjar (Latin script)": "bjn_Latn",
|
||||||
|
"Standard Tibetan": "bod_Tibt",
|
||||||
|
"Bosnian": "bos_Latn",
|
||||||
|
"Buginese": "bug_Latn",
|
||||||
|
"Bulgarian": "bul_Cyrl",
|
||||||
|
"Catalan": "cat_Latn",
|
||||||
|
"Cebuano": "ceb_Latn",
|
||||||
|
"Czech": "ces_Latn",
|
||||||
|
"Chokwe": "cjk_Latn",
|
||||||
|
"Central Kurdish": "ckb_Arab",
|
||||||
|
"Crimean Tatar": "crh_Latn",
|
||||||
|
"Welsh": "cym_Latn",
|
||||||
|
"Danish": "dan_Latn",
|
||||||
|
"German": "deu_Latn",
|
||||||
|
"Southwestern Dinka": "dik_Latn",
|
||||||
|
"Dyula": "dyu_Latn",
|
||||||
|
"Dzongkha": "dzo_Tibt",
|
||||||
|
"Greek": "ell_Grek",
|
||||||
|
"English": "eng_Latn",
|
||||||
|
"Esperanto": "epo_Latn",
|
||||||
|
"Estonian": "est_Latn",
|
||||||
|
"Basque": "eus_Latn",
|
||||||
|
"Ewe": "ewe_Latn",
|
||||||
|
"Faroese": "fao_Latn",
|
||||||
|
"Fijian": "fij_Latn",
|
||||||
|
"Finnish": "fin_Latn",
|
||||||
|
"Fon": "fon_Latn",
|
||||||
|
"French": "fra_Latn",
|
||||||
|
"Friulian": "fur_Latn",
|
||||||
|
"Nigerian Fulfulde": "fuv_Latn",
|
||||||
|
"Scottish Gaelic": "gla_Latn",
|
||||||
|
"Irish": "gle_Latn",
|
||||||
|
"Galician": "glg_Latn",
|
||||||
|
"Guarani": "grn_Latn",
|
||||||
|
"Gujarati": "guj_Gujr",
|
||||||
|
"Haitian Creole": "hat_Latn",
|
||||||
|
"Hausa": "hau_Latn",
|
||||||
|
"Hebrew": "heb_Hebr",
|
||||||
|
"Hindi": "hin_Deva",
|
||||||
|
"Chhattisgarhi": "hne_Deva",
|
||||||
|
"Croatian": "hrv_Latn",
|
||||||
|
"Hungarian": "hun_Latn",
|
||||||
|
"Armenian": "hye_Armn",
|
||||||
|
"Igbo": "ibo_Latn",
|
||||||
|
"Ilocano": "ilo_Latn",
|
||||||
|
"Indonesian": "ind_Latn",
|
||||||
|
"Icelandic": "isl_Latn",
|
||||||
|
"Italian": "ita_Latn",
|
||||||
|
"Javanese": "jav_Latn",
|
||||||
|
"Japanese": "jpn_Jpan",
|
||||||
|
"Kabyle": "kab_Latn",
|
||||||
|
"Jingpho": "kac_Latn",
|
||||||
|
"Kamba": "kam_Latn",
|
||||||
|
"Kannada": "kan_Knda",
|
||||||
|
"Kashmiri (Arabic script)": "kas_Arab",
|
||||||
|
"Kashmiri (Devanagari script)": "kas_Deva",
|
||||||
|
"Georgian": "kat_Geor",
|
||||||
|
"Central Kanuri (Arabic script)": "knc_Arab",
|
||||||
|
"Central Kanuri (Latin script)": "knc_Latn",
|
||||||
|
"Kazakh": "kaz_Cyrl",
|
||||||
|
"Kabiyè": "kbp_Latn",
|
||||||
|
"Kabuverdianu": "kea_Latn",
|
||||||
|
"Khmer": "khm_Khmr",
|
||||||
|
"Kikuyu": "kik_Latn",
|
||||||
|
"Kinyarwanda": "kin_Latn",
|
||||||
|
"Kyrgyz": "kir_Cyrl",
|
||||||
|
"Kimbundu": "kmb_Latn",
|
||||||
|
"Northern Kurdish": "kmr_Latn",
|
||||||
|
"Kikongo": "kon_Latn",
|
||||||
|
"Korean": "kor_Hang",
|
||||||
|
"Lao": "lao_Laoo",
|
||||||
|
"Ligurian": "lij_Latn",
|
||||||
|
"Limburgish": "lim_Latn",
|
||||||
|
"Lingala": "lin_Latn",
|
||||||
|
"Lithuanian": "lit_Latn",
|
||||||
|
"Lombard": "lmo_Latn",
|
||||||
|
"Latgalian": "ltg_Latn",
|
||||||
|
"Luxembourgish": "ltz_Latn",
|
||||||
|
"Luba-Kasai": "lua_Latn",
|
||||||
|
"Ganda": "lug_Latn",
|
||||||
|
"Luo": "luo_Latn",
|
||||||
|
"Mizo": "lus_Latn",
|
||||||
|
"Standard Latvian": "lvs_Latn",
|
||||||
|
"Magahi": "mag_Deva",
|
||||||
|
"Maithili": "mai_Deva",
|
||||||
|
"Malayalam": "mal_Mlym",
|
||||||
|
"Marathi": "mar_Deva",
|
||||||
|
"Minangkabau (Arabic script)": "min_Arab",
|
||||||
|
"Minangkabau (Latin script)": "min_Latn",
|
||||||
|
"Macedonian": "mkd_Cyrl",
|
||||||
|
"Plateau Malagasy": "plt_Latn",
|
||||||
|
"Maltese": "mlt_Latn",
|
||||||
|
"Meitei (Bengali script)": "mni_Beng",
|
||||||
|
"Halh Mongolian": "khk_Cyrl",
|
||||||
|
"Mossi": "mos_Latn",
|
||||||
|
"Maori": "mri_Latn",
|
||||||
|
"Burmese": "mya_Mymr",
|
||||||
|
"Dutch": "nld_Latn",
|
||||||
|
"Norwegian Nynorsk": "nno_Latn",
|
||||||
|
"Norwegian Bokmål": "nob_Latn",
|
||||||
|
"Nepali": "npi_Deva",
|
||||||
|
"Northern Sotho": "nso_Latn",
|
||||||
|
"Nuer": "nus_Latn",
|
||||||
|
"Nyanja": "nya_Latn",
|
||||||
|
"Occitan": "oci_Latn",
|
||||||
|
"West Central Oromo": "gaz_Latn",
|
||||||
|
"Odia": "ory_Orya",
|
||||||
|
"Pangasinan": "pag_Latn",
|
||||||
|
"Eastern Panjabi": "pan_Guru",
|
||||||
|
"Papiamento": "pap_Latn",
|
||||||
|
"Western Persian": "pes_Arab",
|
||||||
|
"Polish": "pol_Latn",
|
||||||
|
"Portuguese": "por_Latn",
|
||||||
|
"Dari": "prs_Arab",
|
||||||
|
"Southern Pashto": "pbt_Arab",
|
||||||
|
"Ayacucho Quechua": "quy_Latn",
|
||||||
|
"Romanian": "ron_Latn",
|
||||||
|
"Rundi": "run_Latn",
|
||||||
|
"Russian": "rus_Cyrl",
|
||||||
|
"Sango": "sag_Latn",
|
||||||
|
"Sanskrit": "san_Deva",
|
||||||
|
"Santali": "sat_Olck",
|
||||||
|
"Sicilian": "scn_Latn",
|
||||||
|
"Shan": "shn_Mymr",
|
||||||
|
"Sinhala": "sin_Sinh",
|
||||||
|
"Slovak": "slk_Latn",
|
||||||
|
"Slovenian": "slv_Latn",
|
||||||
|
"Samoan": "smo_Latn",
|
||||||
|
"Shona": "sna_Latn",
|
||||||
|
"Sindhi": "snd_Arab",
|
||||||
|
"Somali": "som_Latn",
|
||||||
|
"Southern Sotho": "sot_Latn",
|
||||||
|
"Spanish": "spa_Latn",
|
||||||
|
"Tosk Albanian": "als_Latn",
|
||||||
|
"Sardinian": "srd_Latn",
|
||||||
|
"Serbian": "srp_Cyrl",
|
||||||
|
"Swati": "ssw_Latn",
|
||||||
|
"Sundanese": "sun_Latn",
|
||||||
|
"Swedish": "swe_Latn",
|
||||||
|
"Swahili": "swh_Latn",
|
||||||
|
"Silesian": "szl_Latn",
|
||||||
|
"Tamil": "tam_Taml",
|
||||||
|
"Tatar": "tat_Cyrl",
|
||||||
|
"Telugu": "tel_Telu",
|
||||||
|
"Tajik": "tgk_Cyrl",
|
||||||
|
"Tagalog": "tgl_Latn",
|
||||||
|
"Thai": "tha_Thai",
|
||||||
|
"Tigrinya": "tir_Ethi",
|
||||||
|
"Tamasheq (Latin script)": "taq_Latn",
|
||||||
|
"Tamasheq (Tifinagh script)": "taq_Tfng",
|
||||||
|
"Tok Pisin": "tpi_Latn",
|
||||||
|
"Tswana": "tsn_Latn",
|
||||||
|
"Tsonga": "tso_Latn",
|
||||||
|
"Turkmen": "tuk_Latn",
|
||||||
|
"Tumbuka": "tum_Latn",
|
||||||
|
"Turkish": "tur_Latn",
|
||||||
|
"Twi": "twi_Latn",
|
||||||
|
"Central Atlas Tamazight": "tzm_Tfng",
|
||||||
|
"Uyghur": "uig_Arab",
|
||||||
|
"Ukrainian": "ukr_Cyrl",
|
||||||
|
"Umbundu": "umb_Latn",
|
||||||
|
"Urdu": "urd_Arab",
|
||||||
|
"Northern Uzbek": "uzn_Latn",
|
||||||
|
"Venetian": "vec_Latn",
|
||||||
|
"Vietnamese": "vie_Latn",
|
||||||
|
"Waray": "war_Latn",
|
||||||
|
"Wolof": "wol_Latn",
|
||||||
|
"Xhosa": "xho_Latn",
|
||||||
|
"Eastern Yiddish": "ydd_Hebr",
|
||||||
|
"Yoruba": "yor_Latn",
|
||||||
|
"Yue Chinese": "yue_Hant",
|
||||||
|
"Chinese Simplified": "zho_Hans",
|
||||||
|
"Chinese Traditional": "zho_Hant",
|
||||||
|
"Standard Malay": "zsm_Latn",
|
||||||
|
"Zulu": "zul_Latn"
|
||||||
|
}
|
||||||
|
|
||||||
|
translation_lang["nllb-200-distilled-1.3B-ct2-int8"] = {
|
||||||
|
"source":dict_nllb_languages,
|
||||||
|
"target":dict_nllb_languages,
|
||||||
|
}
|
||||||
|
|
||||||
|
translation_lang["nllb-200-3.3B-ct2-int8"] = {
|
||||||
|
"source":dict_nllb_languages,
|
||||||
|
"target":dict_nllb_languages,
|
||||||
}
|
}
|
||||||
@@ -6,8 +6,15 @@ try:
|
|||||||
except Exception:
|
except Exception:
|
||||||
ENABLE_TRANSLATORS = False
|
ENABLE_TRANSLATORS = False
|
||||||
|
|
||||||
from .translation_languages import translation_lang
|
try:
|
||||||
from .translation_utils import ctranslate2_weights
|
from .translation_languages import translation_lang
|
||||||
|
from .translation_utils import ctranslate2_weights
|
||||||
|
except Exception:
|
||||||
|
import sys
|
||||||
|
print(os_path.dirname(os_path.dirname(os_path.dirname(os_path.abspath(__file__)))))
|
||||||
|
sys.path.append(os_path.dirname(os_path.dirname(os_path.dirname(os_path.abspath(__file__)))))
|
||||||
|
from translation_languages import translation_lang
|
||||||
|
from translation_utils import ctranslate2_weights
|
||||||
|
|
||||||
import ctranslate2
|
import ctranslate2
|
||||||
import transformers
|
import transformers
|
||||||
@@ -63,13 +70,19 @@ class Translator():
|
|||||||
def isLoadedCTranslate2Model(self):
|
def isLoadedCTranslate2Model(self):
|
||||||
return self.is_loaded_ctranslate2_model
|
return self.is_loaded_ctranslate2_model
|
||||||
|
|
||||||
def translateCTranslate2(self, message, source_language, target_language):
|
def translateCTranslate2(self, message, source_language, target_language, weight_type):
|
||||||
result = False
|
result = False
|
||||||
if self.is_loaded_ctranslate2_model is True:
|
if self.is_loaded_ctranslate2_model is True:
|
||||||
try:
|
try:
|
||||||
self.ctranslate2_tokenizer.src_lang = source_language
|
self.ctranslate2_tokenizer.src_lang = source_language
|
||||||
source = self.ctranslate2_tokenizer.convert_ids_to_tokens(self.ctranslate2_tokenizer.encode(message))
|
source = self.ctranslate2_tokenizer.convert_ids_to_tokens(self.ctranslate2_tokenizer.encode(message))
|
||||||
target_prefix = [self.ctranslate2_tokenizer.lang_code_to_token[target_language]]
|
match weight_type:
|
||||||
|
case "m2m100_418M-ct2-int8" | "m2m100_1.2B-ct2-int8":
|
||||||
|
target_prefix = [self.ctranslate2_tokenizer.lang_code_to_token[target_language]]
|
||||||
|
case "nllb-200-distilled-1.3B-ct2-int8" | "nllb-200-3.3B-ct2-int8":
|
||||||
|
target_prefix = [target_language]
|
||||||
|
case _:
|
||||||
|
return False
|
||||||
results = self.ctranslate2_translator.translate_batch([source], target_prefix=[target_prefix])
|
results = self.ctranslate2_translator.translate_batch([source], target_prefix=[target_prefix])
|
||||||
target = results[0].hypotheses[0][1:]
|
target = results[0].hypotheses[0][1:]
|
||||||
result = self.ctranslate2_tokenizer.decode(self.ctranslate2_tokenizer.convert_tokens_to_ids(target))
|
result = self.ctranslate2_tokenizer.decode(self.ctranslate2_tokenizer.convert_tokens_to_ids(target))
|
||||||
@@ -78,7 +91,7 @@ class Translator():
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def getLanguageCode(translator_name, target_country, source_language, target_language):
|
def getLanguageCode(translator_name, weight_type, target_country, source_language, target_language):
|
||||||
match translator_name:
|
match translator_name:
|
||||||
case "DeepL_API":
|
case "DeepL_API":
|
||||||
if target_language == "English":
|
if target_language == "English":
|
||||||
@@ -91,16 +104,18 @@ class Translator():
|
|||||||
target_language = "Portuguese European"
|
target_language = "Portuguese European"
|
||||||
else:
|
else:
|
||||||
target_language = "Portuguese Brazilian"
|
target_language = "Portuguese Brazilian"
|
||||||
|
case "CTranslate2":
|
||||||
|
translator_name = weight_type
|
||||||
case _:
|
case _:
|
||||||
pass
|
pass
|
||||||
source_language=translation_lang[translator_name]["source"][source_language]
|
source_language=translation_lang[translator_name]["source"][source_language]
|
||||||
target_language=translation_lang[translator_name]["target"][target_language]
|
target_language=translation_lang[translator_name]["target"][target_language]
|
||||||
return source_language, target_language
|
return source_language, target_language
|
||||||
|
|
||||||
def translate(self, translator_name, source_language, target_language, target_country, message):
|
def translate(self, translator_name, weight_type, source_language, target_language, target_country, message):
|
||||||
try:
|
try:
|
||||||
result = ""
|
result = ""
|
||||||
source_language, target_language = self.getLanguageCode(translator_name, target_country, source_language, target_language)
|
source_language, target_language = self.getLanguageCode(translator_name, weight_type, target_country, source_language, target_language)
|
||||||
match translator_name:
|
match translator_name:
|
||||||
case "DeepL":
|
case "DeepL":
|
||||||
if self.is_enable_translators is True:
|
if self.is_enable_translators is True:
|
||||||
@@ -149,8 +164,23 @@ class Translator():
|
|||||||
message=message,
|
message=message,
|
||||||
source_language=source_language,
|
source_language=source_language,
|
||||||
target_language=target_language,
|
target_language=target_language,
|
||||||
|
weight_type=weight_type,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
errorLogging()
|
errorLogging()
|
||||||
result = False
|
result = False
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
translator = Translator()
|
||||||
|
# test CTranslate2 model nllb-200-distilled-1.3B-ct2-int8
|
||||||
|
translator.changeCTranslate2Model(path=".", model_type="nllb-200-distilled-1.3B-ct2-int8", device="cpu", device_index=0)
|
||||||
|
result = translator.translate(
|
||||||
|
translator_name="CTranslate2",
|
||||||
|
weight_type="nllb-200-distilled-1.3B-ct2-int8",
|
||||||
|
source_language="English",
|
||||||
|
target_language="Japanese",
|
||||||
|
target_country="Japan",
|
||||||
|
message="Hello, world!"
|
||||||
|
)
|
||||||
|
print(result)
|
||||||
@@ -5,50 +5,50 @@ import transformers
|
|||||||
import ctranslate2
|
import ctranslate2
|
||||||
from huggingface_hub import hf_hub_url, list_repo_files
|
from huggingface_hub import hf_hub_url, list_repo_files
|
||||||
from requests import get as requests_get
|
from requests import get as requests_get
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from utils import errorLogging
|
from utils import errorLogging, getBestComputeType
|
||||||
except Exception:
|
except Exception:
|
||||||
import traceback
|
import sys
|
||||||
def errorLogging():
|
print(os_path.dirname(os_path.dirname(os_path.dirname(os_path.abspath(__file__)))))
|
||||||
print(traceback.format_exc())
|
sys.path.append(os_path.dirname(os_path.dirname(os_path.dirname(os_path.abspath(__file__)))))
|
||||||
|
from utils import errorLogging, getBestComputeType
|
||||||
|
|
||||||
ctranslate2_weights = {
|
ctranslate2_weights = {
|
||||||
"small": {
|
"m2m100_418M-ct2-int8": {
|
||||||
# "hf_repo": "jncraton/m2m100_418M-ct2-int8",
|
"hf_repo": "jncraton/m2m100_418M-ct2-int8",
|
||||||
# "directory_name": "m2m100_418M-ct2-int8",
|
"directory_name": "m2m100_418M-ct2-int8",
|
||||||
# "tokenizer": "facebook/m2m100_418M",
|
"tokenizer": "facebook/m2m100_418M",
|
||||||
"hf_repo": "OpenNMT/nllb-200-distilled-1.3B-ct2-int8",
|
|
||||||
"directory_name": "nllb-200-distilled-1.3B-ct2-int8",
|
|
||||||
"tokenizer": "facebook/nllb-200-distilled-1.3B",
|
|
||||||
},
|
},
|
||||||
"large": {
|
"m2m100_1.2B-ct2-int8": {
|
||||||
"hf_repo": "jncraton/m2m100_1.2B-ct2-int8",
|
"hf_repo": "jncraton/m2m100_1.2B-ct2-int8",
|
||||||
"directory_name": "m2m100_1.2B-ct2-int8",
|
"directory_name": "m2m100_1.2B-ct2-int8",
|
||||||
"tokenizer": "facebook/m2m100_1.2B",
|
"tokenizer": "facebook/m2m100_1.2B",
|
||||||
},
|
},
|
||||||
|
"nllb-200-distilled-1.3B-ct2-int8": {
|
||||||
|
"hf_repo": "OpenNMT/nllb-200-distilled-1.3B-ct2-int8",
|
||||||
|
"directory_name": "nllb-200-distilled-1.3B-ct2-int8",
|
||||||
|
"tokenizer": "facebook/nllb-200-distilled-1.3B",
|
||||||
|
},
|
||||||
"nllb-200-3.3B-ct2-int8": {
|
"nllb-200-3.3B-ct2-int8": {
|
||||||
"hf_repo": "OpenNMT/nllb-200-3.3B-ct2-int8",
|
"hf_repo": "OpenNMT/nllb-200-3.3B-ct2-int8",
|
||||||
"directory_name": "nllb-200-3.3B-ct2-int8",
|
"directory_name": "nllb-200-3.3B-ct2-int8",
|
||||||
"tokenizer": "facebook/nllb-200-3.3B",
|
"tokenizer": "facebook/nllb-200-3.3B",
|
||||||
},
|
},
|
||||||
"nllb-200-distilled-1.3B": {
|
|
||||||
"hf_repo": "OpenNMT/nllb-200-distilled-1.3B-ct2-int8",
|
|
||||||
"directory_name": "nllb-200-distilled-1.3B-ct2-int8",
|
|
||||||
"tokenizer": "facebook/nllb-200-distilled-1.3B",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def checkCTranslate2Weight(root: str, weight_type: str = "small"):
|
def checkCTranslate2Weight(root: str, weight_type: str = "m2m100_418M-ct2-int8"):
|
||||||
weight_directory_name = ctranslate2_weights[weight_type]["directory_name"]
|
weight_directory_name = ctranslate2_weights[weight_type]["directory_name"]
|
||||||
path = os_path.join(root, "weights", "ctranslate2", weight_directory_name)
|
path = os_path.join(root, "weights", "ctranslate2", weight_directory_name)
|
||||||
try:
|
try:
|
||||||
# モデルロード可能かどうかで判定
|
# モデルロード可能かどうかで判定
|
||||||
ctranslate2.Translator(path)
|
compute_type = getBestComputeType("cpu", 0)
|
||||||
|
ctranslate2.Translator(path, compute_type=compute_type)
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def downloadCTranslate2Weight(root: str, weight_type: str = "small", callback: Callable = None, end_callback: Callable = None):
|
def downloadCTranslate2Weight(root: str, weight_type: str = "m2m100_418M-ct2-int8", callback: Callable = None, end_callback: Callable = None):
|
||||||
hf_repo = ctranslate2_weights[weight_type]["hf_repo"]
|
hf_repo = ctranslate2_weights[weight_type]["hf_repo"]
|
||||||
files = list_repo_files(repo_id=hf_repo)
|
files = list_repo_files(repo_id=hf_repo)
|
||||||
path = os_path.join(root, "weights", "ctranslate2", ctranslate2_weights[weight_type]["directory_name"])
|
path = os_path.join(root, "weights", "ctranslate2", ctranslate2_weights[weight_type]["directory_name"])
|
||||||
@@ -79,7 +79,7 @@ def downloadCTranslate2Weight(root: str, weight_type: str = "small", callback: C
|
|||||||
if end_callback is not None:
|
if end_callback is not None:
|
||||||
end_callback()
|
end_callback()
|
||||||
|
|
||||||
def downloadCTranslate2Tokenizer(path: str, weight_type: str = "small"):
|
def downloadCTranslate2Tokenizer(path: str, weight_type: str = "m2m100_418M-ct2-int8"):
|
||||||
directory_name = ctranslate2_weights[weight_type]["directory_name"]
|
directory_name = ctranslate2_weights[weight_type]["directory_name"]
|
||||||
tokenizer = ctranslate2_weights[weight_type]["tokenizer"]
|
tokenizer = ctranslate2_weights[weight_type]["tokenizer"]
|
||||||
tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer")
|
tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer")
|
||||||
@@ -106,4 +106,9 @@ if __name__ == "__main__":
|
|||||||
# result = checkCTranslate2Weight(root, weight_type)
|
# result = checkCTranslate2Weight(root, weight_type)
|
||||||
# print(f"Model loadable: {result}")
|
# print(f"Model loadable: {result}")
|
||||||
# break
|
# break
|
||||||
downloadCTranslate2Tokenizer(root, "small")
|
# downloadCTranslate2Tokenizer(root, "m2m100_418M-ct2-int8")
|
||||||
|
|
||||||
|
# model download test
|
||||||
|
downloadCTranslate2Weight(root, "nllb-200-distilled-1.3B", callback=progress_callback, end_callback=end_callback)
|
||||||
|
result = checkCTranslate2Weight(root, "nllb-200-distilled-1.3B")
|
||||||
|
print(f"Model loadable: {result}")
|
||||||
Reference in New Issue
Block a user