👍[Update] Model : CTranslate2のweightデータを起動時に取得するように変更

This commit is contained in:
misyaguziya
2023-11-19 00:03:57 +09:00
parent dcb6c07eee
commit e6e62cf350
4 changed files with 72 additions and 7 deletions

1
.gitignore vendored
View File

@@ -6,3 +6,4 @@ VRCT.spec
*.pyc
logs/
.venv/
weight/

View File

@@ -63,6 +63,10 @@ class Config:
def DOCUMENTS_URL(self):
return self._DOCUMENTS_URL
@property
def CTRANSLATE2_WIGHTS(self):
return self._CTRANSLATE2_WIGHTS
@property
def MAX_MIC_ENERGY_THRESHOLD(self):
return self._MAX_MIC_ENERGY_THRESHOLD
@@ -447,6 +451,17 @@ class Config:
self._AUTH_KEYS[key] = value
saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, self.AUTH_KEYS)
@property
@json_serializable('WEIGHT_TYPE')
def WEIGHT_TYPE(self):
return self._WEIGHT_TYPE
@WEIGHT_TYPE.setter
def WEIGHT_TYPE(self, value):
if isinstance(value, str):
self._WEIGHT_TYPE = value
saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, value)
@property
@json_serializable('MESSAGE_FORMAT')
def MESSAGE_FORMAT(self):
@@ -537,6 +552,18 @@ class Config:
self._GITHUB_URL = "https://api.github.com/repos/misyaguziya/VRCT/releases/latest"
self._BOOTH_URL = "https://misyaguziya.booth.pm/"
self._DOCUMENTS_URL = "https://mzsoftware.notion.site/VRCT-Documents-be79b7a165f64442ad8f326d86c22246"
self._CTRANSLATE2_WIGHTS = {
"small": { # M2M-100 418M-parameter model
"url": "https://bit.ly/33fM1AO",
"directory_name": "m2m100_418m",
"tokenizer": "facebook/m2m100_418M"
},
"large": { # M2M-100 1.2B-parameter model
"url": "https://bit.ly/3GYiaed",
"directory_name": "m2m100_12b",
"tokenizer": "facebook/m2m100_12b"
},
}
self._MAX_MIC_ENERGY_THRESHOLD = 2000
self._MAX_SPEAKER_ENERGY_THRESHOLD = 4000
@@ -594,6 +621,7 @@ class Config:
"Bing": None,
"Google": None,
}
self.WEIGHT_TYPE = "small"
self._MESSAGE_FORMAT = "[message]([translation])"
self._ENABLE_AUTO_CLEAR_MESSAGE_BOX = True
self._ENABLE_NOTICE_XSOVERLAY = False

View File

@@ -1,3 +1,4 @@
import tempfile
from zipfile import ZipFile
from subprocess import Popen
from os import makedirs as os_makedirs
@@ -9,9 +10,10 @@ from logging import getLogger, FileHandler, Formatter, INFO
from time import sleep
from queue import Queue
from threading import Thread, Event
from requests import get as requests_get
from requests import get as requests_get, head as requests_head
import webbrowser
from tqdm import tqdm
from flashtext import KeywordProcessor
from models.translation.translation_translator import Translator
from models.transcription.transcription_utils import getInputDevices, getDefaultOutputDevice
@@ -70,7 +72,8 @@ class Model:
self.speaker_audio_recorder = None
self.speaker_energy_recorder = None
self.speaker_energy_plot_progressbar = None
self.translator = Translator(config.PATH_LOCAL)
self.downloadCTranslate2Weight()
self.translator = Translator(config.PATH_LOCAL, config.CTRANSLATE2_WIGHTS[config.WEIGHT_TYPE])
self.keyword_processor = KeywordProcessor()
def resetTranslator(self):
@@ -106,6 +109,38 @@ class Model:
self.logger.disabled = True
self.logger = None
@staticmethod
def downloadCTranslate2Weight():
weight_type = config.WEIGHT_TYPE
url = config.CTRANSLATE2_WIGHTS[weight_type]["url"]
filename = 'weight.zip'
directory_name = 'weight'
current_directory = config.PATH_LOCAL
weight_directory_name = config.CTRANSLATE2_WIGHTS[weight_type]["directory_name"]
files = ["model.bin", "sentencepiece.model", "shared_vocabulary.txt"]
# check already downloaded
if all(os_path.exists(os_path.join(current_directory, directory_name, weight_directory_name, file)) for file in files):
return
try:
os_makedirs(os_path.join(current_directory, directory_name), exist_ok=True)
print(os_path.join(current_directory, directory_name))
with tempfile.TemporaryDirectory() as tmp_path:
file_size = int(requests_head(url).headers["content-length"])
res = requests_get(url, stream=True)
pbar = tqdm(total=file_size, unit="B", unit_scale=True)
with open(os_path.join(tmp_path, filename), 'wb') as file:
for chunk in res.iter_content(chunk_size=1024):
file.write(chunk)
pbar.update(len(chunk))
pbar.close()
with ZipFile(os_path.join(tmp_path, filename)) as zf:
zf.extractall(os_path.join(current_directory, directory_name))
except Exception as e:
print("error:downloadCTranslate2Weight()", e)
@staticmethod
def getListLanguageAndCountry():
langs = []

View File

@@ -4,7 +4,6 @@ from deepl_translate import translate as deepl_web_Translator
from translators import translate_text as other_web_Translator
from .translation_languages import translation_lang
from ctranslate2.converters import TransformersConverter
import ctranslate2
import transformers
@@ -15,11 +14,13 @@ TRANSLATE_MODELS = {
# Translator
class Translator():
def __init__(self, path):
def __init__(self, path, weight_config):
self.translator_status = {}
self.weight_path = os.path.join(path, "weight")
directory_name = weight_config["directory_name"]
tokenizer = weight_config["tokenizer"]
self.weight_path = os.path.join(path, "weight", directory_name)
self.translator = ctranslate2.Translator(self.weight_path, device="cpu", device_index=0, compute_type="int8", inter_threads=1, intra_threads=4)
self.tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/m2m100_418M")
self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer)
def authentication(self, translator_name, authkey=None):
result = True