diff --git a/config.py b/config.py index e9afe1d9..c1cc4759 100644 --- a/config.py +++ b/config.py @@ -8,6 +8,7 @@ from tkinter import font from languages import selectable_languages from models.translation.translation_languages import translatorEngine from models.transcription.transcription_utils import getInputDevices, getDefaultInputDevice +from models.translation.utils import ctranslate2_weights from utils import generatePercentageStringsList, isUniqueStrings json_serializable_vars = {} @@ -552,18 +553,7 @@ class Config: self._GITHUB_URL = "https://api.github.com/repos/misyaguziya/VRCT/releases/latest" self._BOOTH_URL = "https://misyaguziya.booth.pm/" self._DOCUMENTS_URL = "https://mzsoftware.notion.site/VRCT-Documents-be79b7a165f64442ad8f326d86c22246" - self._CTRANSLATE2_WIGHTS = { - "small": { # M2M-100 418M-parameter model - "url": "https://bit.ly/33fM1AO", - "directory_name": "m2m100_418m", - "tokenizer": "facebook/m2m100_418M" - }, - "large": { # M2M-100 1.2B-parameter model - "url": "https://bit.ly/3GYiaed", - "directory_name": "m2m100_12b", - "tokenizer": "facebook/m2m100_12b" - }, - } + self._CTRANSLATE2_WEIGHTS = ctranslate2_weights self._MAX_MIC_ENERGY_THRESHOLD = 2000 self._MAX_SPEAKER_ENERGY_THRESHOLD = 4000 diff --git a/main.py b/main.py index 14c97ac4..ff0aa6a6 100644 --- a/main.py +++ b/main.py @@ -9,7 +9,7 @@ if __name__ == "__main__": from config import config from models.translation.utils import downloadCTranslate2Weight - downloadCTranslate2Weight(config.PATH_LOCAL, config.WEIGHT_TYPE, config.CTRANSLATE2_WIGHTS) + downloadCTranslate2Weight(config.PATH_LOCAL, config.WEIGHT_TYPE, config.CTRANSLATE2_WEIGHTS, print) import controller controller.createMainWindow() diff --git a/models/translation/utils.py b/models/translation/utils.py new file mode 100644 index 00000000..5b15f16a --- /dev/null +++ b/models/translation/utils.py @@ -0,0 +1,54 @@ +import tempfile +from zipfile import ZipFile +from os import path as os_path +from os import makedirs as os_makedirs +from requests import get as requests_get, head as requests_head +from tqdm import tqdm +from typing import Callable + +ctranslate2_weights = { + "small": { # M2M-100 418M-parameter model + "url": "https://bit.ly/33fM1AO", + "directory_name": "m2m100_418m", + "tokenizer": "facebook/m2m100_418M" + }, + "large": { # M2M-100 1.2B-parameter model + "url": "https://bit.ly/3GYiaed", + "directory_name": "m2m100_12b", + "tokenizer": "facebook/m2m100_12b" + }, +} + +def downloadCTranslate2Weight(path, weight_type="small", ctranslate2_weights=ctranslate2_weights, func=None): + url = ctranslate2_weights[weight_type]["url"] + filename = 'weight.zip' + directory_name = 'weight' + current_directory = path + weight_directory_name = ctranslate2_weights[weight_type]["directory_name"] + files = ["model.bin", "sentencepiece.model", "shared_vocabulary.txt"] + + # check already downloaded + if all(os_path.exists(os_path.join(current_directory, directory_name, weight_directory_name, file)) for file in files): + return + + try: + os_makedirs(os_path.join(current_directory, directory_name), exist_ok=True) + print(os_path.join(current_directory, directory_name)) + with tempfile.TemporaryDirectory() as tmp_path: + res = requests_get(url, stream=True) + file_size = int(res.headers.get('content-length', 0)) + pbar = tqdm(total=file_size, unit="B", unit_scale=True) + total_chunk = 0 + with open(os_path.join(tmp_path, filename), 'wb') as file: + for chunk in res.iter_content(chunk_size=1024): + file.write(chunk) + pbar.update(len(chunk)) + if isinstance(func, Callable): + total_chunk += len(chunk) + func(total_chunk/file_size) + pbar.close() + + with ZipFile(os_path.join(tmp_path, filename)) as zf: + zf.extractall(os_path.join(current_directory, directory_name)) + except Exception as e: + print("error:downloadCTranslate2Weight()", e) \ No newline at end of file