From e6e62cf35034f5787d77042ede1da69e39af0168 Mon Sep 17 00:00:00 2001 From: misyaguziya Date: Sun, 19 Nov 2023 00:03:57 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=91=8D[Update]=20Model=20:=20CTranslate2?= =?UTF-8?q?=E3=81=AEweight=E3=83=87=E3=83=BC=E3=82=BF=E3=82=92=E8=B5=B7?= =?UTF-8?q?=E5=8B=95=E6=99=82=E3=81=AB=E5=8F=96=E5=BE=97=E3=81=99=E3=82=8B?= =?UTF-8?q?=E3=82=88=E3=81=86=E3=81=AB=E5=A4=89=E6=9B=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 +- config.py | 28 ++++++++++++++ model.py | 39 +++++++++++++++++++- models/translation/translation_translator.py | 9 +++-- 4 files changed, 72 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index 0d95c681..2223cde5 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ memo.txt VRCT.spec *.pyc logs/ -.venv/ \ No newline at end of file +.venv/ +weight/ \ No newline at end of file diff --git a/config.py b/config.py index fb0406d2..e9afe1d9 100644 --- a/config.py +++ b/config.py @@ -63,6 +63,10 @@ class Config: def DOCUMENTS_URL(self): return self._DOCUMENTS_URL + @property + def CTRANSLATE2_WIGHTS(self): + return self._CTRANSLATE2_WIGHTS + @property def MAX_MIC_ENERGY_THRESHOLD(self): return self._MAX_MIC_ENERGY_THRESHOLD @@ -447,6 +451,17 @@ class Config: self._AUTH_KEYS[key] = value saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, self.AUTH_KEYS) + @property + @json_serializable('WEIGHT_TYPE') + def WEIGHT_TYPE(self): + return self._WEIGHT_TYPE + + @WEIGHT_TYPE.setter + def WEIGHT_TYPE(self, value): + if isinstance(value, str): + self._WEIGHT_TYPE = value + saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, value) + @property @json_serializable('MESSAGE_FORMAT') def MESSAGE_FORMAT(self): @@ -537,6 +552,18 @@ class Config: self._GITHUB_URL = "https://api.github.com/repos/misyaguziya/VRCT/releases/latest" self._BOOTH_URL = "https://misyaguziya.booth.pm/" self._DOCUMENTS_URL = "https://mzsoftware.notion.site/VRCT-Documents-be79b7a165f64442ad8f326d86c22246" + self._CTRANSLATE2_WIGHTS = { + "small": { # M2M-100 418M-parameter model + "url": "https://bit.ly/33fM1AO", + "directory_name": "m2m100_418m", + "tokenizer": "facebook/m2m100_418M" + }, + "large": { # M2M-100 1.2B-parameter model + "url": "https://bit.ly/3GYiaed", + "directory_name": "m2m100_12b", + "tokenizer": "facebook/m2m100_12b" + }, + } self._MAX_MIC_ENERGY_THRESHOLD = 2000 self._MAX_SPEAKER_ENERGY_THRESHOLD = 4000 @@ -594,6 +621,7 @@ class Config: "Bing": None, "Google": None, } + self.WEIGHT_TYPE = "small" self._MESSAGE_FORMAT = "[message]([translation])" self._ENABLE_AUTO_CLEAR_MESSAGE_BOX = True self._ENABLE_NOTICE_XSOVERLAY = False diff --git a/model.py b/model.py index 071a5ff0..21ea5fd7 100644 --- a/model.py +++ b/model.py @@ -1,3 +1,4 @@ +import tempfile from zipfile import ZipFile from subprocess import Popen from os import makedirs as os_makedirs @@ -9,9 +10,10 @@ from logging import getLogger, FileHandler, Formatter, INFO from time import sleep from queue import Queue from threading import Thread, Event -from requests import get as requests_get +from requests import get as requests_get, head as requests_head import webbrowser +from tqdm import tqdm from flashtext import KeywordProcessor from models.translation.translation_translator import Translator from models.transcription.transcription_utils import getInputDevices, getDefaultOutputDevice @@ -70,7 +72,8 @@ class Model: self.speaker_audio_recorder = None self.speaker_energy_recorder = None self.speaker_energy_plot_progressbar = None - self.translator = Translator(config.PATH_LOCAL) + self.downloadCTranslate2Weight() + self.translator = Translator(config.PATH_LOCAL, config.CTRANSLATE2_WIGHTS[config.WEIGHT_TYPE]) self.keyword_processor = KeywordProcessor() def resetTranslator(self): @@ -106,6 +109,38 @@ class Model: self.logger.disabled = True self.logger = None + @staticmethod + def downloadCTranslate2Weight(): + weight_type = config.WEIGHT_TYPE + url = config.CTRANSLATE2_WIGHTS[weight_type]["url"] + filename = 'weight.zip' + directory_name = 'weight' + current_directory = config.PATH_LOCAL + weight_directory_name = config.CTRANSLATE2_WIGHTS[weight_type]["directory_name"] + files = ["model.bin", "sentencepiece.model", "shared_vocabulary.txt"] + + # check already downloaded + if all(os_path.exists(os_path.join(current_directory, directory_name, weight_directory_name, file)) for file in files): + return + + try: + os_makedirs(os_path.join(current_directory, directory_name), exist_ok=True) + print(os_path.join(current_directory, directory_name)) + with tempfile.TemporaryDirectory() as tmp_path: + file_size = int(requests_head(url).headers["content-length"]) + res = requests_get(url, stream=True) + pbar = tqdm(total=file_size, unit="B", unit_scale=True) + with open(os_path.join(tmp_path, filename), 'wb') as file: + for chunk in res.iter_content(chunk_size=1024): + file.write(chunk) + pbar.update(len(chunk)) + pbar.close() + + with ZipFile(os_path.join(tmp_path, filename)) as zf: + zf.extractall(os_path.join(current_directory, directory_name)) + except Exception as e: + print("error:downloadCTranslate2Weight()", e) + @staticmethod def getListLanguageAndCountry(): langs = [] diff --git a/models/translation/translation_translator.py b/models/translation/translation_translator.py index e78c803d..aa9c6d5c 100644 --- a/models/translation/translation_translator.py +++ b/models/translation/translation_translator.py @@ -4,7 +4,6 @@ from deepl_translate import translate as deepl_web_Translator from translators import translate_text as other_web_Translator from .translation_languages import translation_lang -from ctranslate2.converters import TransformersConverter import ctranslate2 import transformers @@ -15,11 +14,13 @@ TRANSLATE_MODELS = { # Translator class Translator(): - def __init__(self, path): + def __init__(self, path, weight_config): self.translator_status = {} - self.weight_path = os.path.join(path, "weight") + directory_name = weight_config["directory_name"] + tokenizer = weight_config["tokenizer"] + self.weight_path = os.path.join(path, "weight", directory_name) self.translator = ctranslate2.Translator(self.weight_path, device="cpu", device_index=0, compute_type="int8", inter_threads=1, intra_threads=4) - self.tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/m2m100_418M") + self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer) def authentication(self, translator_name, authkey=None): result = True