From 6fb7ecbc52d9c1e98b2ea119b8ccb4d27826f649 Mon Sep 17 00:00:00 2001 From: misyaguziya <53165965+misyaguziya@users.noreply.github.com> Date: Sun, 7 Sep 2025 23:59:20 +0900 Subject: [PATCH] [Update] translation_utils.py: Refactor weight handling and improve error logging add translate model - jncraton/m2m100_418M-ct2-int8 - jncraton/m2m100_1.2B-ct2-int8 - OpenNMT/nllb-200-3.3B-ct2-int8 - OpenNMT/nllb-200-distilled-1.3B-ct2-int8 --- .../models/translation/translation_utils.py | 154 +++++++++--------- 1 file changed, 80 insertions(+), 74 deletions(-) diff --git a/src-python/models/translation/translation_utils.py b/src-python/models/translation/translation_utils.py index 457a65f1..ef0e5592 100644 --- a/src-python/models/translation/translation_utils.py +++ b/src-python/models/translation/translation_utils.py @@ -1,103 +1,109 @@ -import tempfile -from zipfile import ZipFile from os import path as os_path from os import makedirs as os_makedirs -from requests import get as requests_get from typing import Callable -import hashlib import transformers -from utils import errorLogging +import ctranslate2 +from huggingface_hub import hf_hub_url, list_repo_files +from requests import get as requests_get +try: + from utils import errorLogging +except Exception: + import traceback + def errorLogging(): + print(traceback.format_exc()) ctranslate2_weights = { - "small": { # M2M-100 418M-parameter model - "url": "https://github.com/misyaguziya/VRCT-weights/releases/download/v1.0/m2m100_418m.zip", - "directory_name": "m2m100_418m", - "tokenizer": "facebook/m2m100_418M", - "hash": { - "model.bin": "e7c26a9abb5260abd0268fbe3040714070dec254a990b4d7fd3f74c5230e3acb", - "sentencepiece.model": "d8f7c76ed2a5e0822be39f0a4f95a55eb19c78f4593ce609e2edbc2aea4d380a", - "shared_vocabulary.txt": "bd440aa21b8ca3453fc792a0018a1f3fe68b3464aadddd4d16a4b72f73c86d8c", - } + "small": { + # "hf_repo": "jncraton/m2m100_418M-ct2-int8", + # "directory_name": "m2m100_418M-ct2-int8", + # "tokenizer": "facebook/m2m100_418M", + "hf_repo": "OpenNMT/nllb-200-distilled-1.3B-ct2-int8", + "directory_name": "nllb-200-distilled-1.3B-ct2-int8", + "tokenizer": "facebook/nllb-200-distilled-1.3B", }, - "large": { # M2M-100 1.2B-parameter model - "url": "https://github.com/misyaguziya/VRCT-weights/releases/download/v1.0/m2m100_12b.zip", - "directory_name": "m2m100_12b", - "tokenizer": "facebook/m2m100_1.2b", - "hash": { - "model.bin": "abb7bf4ba7e5e016b6e3ed480c752459b2f783ac8fca372e7587675e5bf3a919", - "sentencepiece.model": "d8f7c76ed2a5e0822be39f0a4f95a55eb19c78f4593ce609e2edbc2aea4d380a", - "shared_vocabulary.txt": "bd440aa21b8ca3453fc792a0018a1f3fe68b3464aadddd4d16a4b72f73c86d8c", - } + "large": { + "hf_repo": "jncraton/m2m100_1.2B-ct2-int8", + "directory_name": "m2m100_1.2B-ct2-int8", + "tokenizer": "facebook/m2m100_1.2B", + }, + "nllb-200-3.3B-ct2-int8": { + "hf_repo": "OpenNMT/nllb-200-3.3B-ct2-int8", + "directory_name": "nllb-200-3.3B-ct2-int8", + "tokenizer": "facebook/nllb-200-3.3B", + }, + "nllb-200-distilled-1.3B": { + "hf_repo": "OpenNMT/nllb-200-distilled-1.3B-ct2-int8", + "directory_name": "nllb-200-distilled-1.3B-ct2-int8", + "tokenizer": "facebook/nllb-200-distilled-1.3B", }, } -def calculate_file_hash(file_path, block_size=65536): - hash_object = hashlib.sha256() - - with open(file_path, 'rb') as file: - for block in iter(lambda: file.read(block_size), b''): - hash_object.update(block) - - return hash_object.hexdigest() - -def checkCTranslate2Weight(root, weight_type="small"): +def checkCTranslate2Weight(root: str, weight_type: str = "small"): weight_directory_name = ctranslate2_weights[weight_type]["directory_name"] - hash_data = ctranslate2_weights[weight_type]["hash"] - files = [ - "model.bin", - "sentencepiece.model", - "shared_vocabulary.txt" - ] - path = os_path.join(root, "weights", "ctranslate2") + path = os_path.join(root, "weights", "ctranslate2", weight_directory_name) + try: + # モデルロード可能かどうかで判定 + ctranslate2.Translator(path) + return True + except Exception: + return False - # check already downloaded - already_downloaded = False - if all(os_path.exists(os_path.join(path, weight_directory_name, file)) for file in files): - # check hash - for file in files: - original_hash = hash_data[file] - current_hash = calculate_file_hash(os_path.join(path, weight_directory_name, file)) - if original_hash != current_hash: - break - already_downloaded = True - return already_downloaded - -def downloadCTranslate2Weight(root, weight_type="small", callback=None, end_callback=None): - url = ctranslate2_weights[weight_type]["url"] - filename = "weight.zip" - path = os_path.join(root, "weights", "ctranslate2") +def downloadCTranslate2Weight(root: str, weight_type: str = "small", callback: Callable = None, end_callback: Callable = None): + hf_repo = ctranslate2_weights[weight_type]["hf_repo"] + files = list_repo_files(repo_id=hf_repo) + path = os_path.join(root, "weights", "ctranslate2", ctranslate2_weights[weight_type]["directory_name"]) + if checkCTranslate2Weight(root, weight_type): + return True os_makedirs(path, exist_ok=True) - if checkCTranslate2Weight(root, weight_type) is False: + def downloadFile(url: str, file_path: str, func: Callable = None): try: - with tempfile.TemporaryDirectory() as tmp_path: - res = requests_get(url, stream=True) - file_size = int(res.headers.get('content-length', 0)) - total_chunk = 0 - with open(os_path.join(tmp_path, filename), 'wb') as file: - for chunk in res.iter_content(chunk_size=1024*2000): - file.write(chunk) - if isinstance(callback, Callable): - total_chunk += len(chunk) - callback(total_chunk/file_size) - - with ZipFile(os_path.join(tmp_path, filename)) as zf: - zf.extractall(path) + res = requests_get(url, stream=True) + res.raise_for_status() + file_size = int(res.headers.get('content-length', 0)) + total_chunk = 0 + with open(file_path, 'wb') as file: + for chunk in res.iter_content(chunk_size=1024*2000): + file.write(chunk) + if func is not None: + total_chunk += len(chunk) + func(total_chunk/file_size) except Exception: errorLogging() - if isinstance(end_callback, Callable): + for filename in files: + file_path = os_path.join(path, filename) + url = hf_hub_url(hf_repo, filename) + downloadFile(url, file_path, func=callback if filename == "model.bin" else None) + + if end_callback is not None: end_callback() -def downloadCTranslate2Tokenizer(path, weight_type="small"): +def downloadCTranslate2Tokenizer(path: str, weight_type: str = "small"): directory_name = ctranslate2_weights[weight_type]["directory_name"] tokenizer = ctranslate2_weights[weight_type]["tokenizer"] tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer") - try: os_makedirs(tokenizer_path, exist_ok=True) transformers.AutoTokenizer.from_pretrained(tokenizer, cache_dir=tokenizer_path) except Exception: errorLogging() tokenizer_path = os_path.join("./weights", "ctranslate2", directory_name, "tokenizer") - transformers.AutoTokenizer.from_pretrained(tokenizer, cache_dir=tokenizer_path) \ No newline at end of file + transformers.AutoTokenizer.from_pretrained(tokenizer, cache_dir=tokenizer_path) + +# テスト用コード(直接実行時のみ) +if __name__ == "__main__": + def progress_callback(percent): + print(f"Download progress: {percent*100:.2f}%") + + def end_callback(): + print("Download finished.") + + root = "./" # 必要に応じてパスを変更 + # for weight_type in ctranslate2_weights.keys(): + # print(f"Testing download for: {weight_type}") + # downloadCTranslate2Weight(root, weight_type, callback=progress_callback, end_callback=end_callback) + # result = checkCTranslate2Weight(root, weight_type) + # print(f"Model loadable: {result}") + # break + downloadCTranslate2Tokenizer(root, "small")