👍[Update] Model : CTranslate2のweightデータを起動時に取得するように変更

This commit is contained in:
misyaguziya
2023-11-19 00:03:57 +09:00
parent dcb6c07eee
commit e6e62cf350
4 changed files with 72 additions and 7 deletions

3
.gitignore vendored
View File

@@ -5,4 +5,5 @@ memo.txt
VRCT.spec VRCT.spec
*.pyc *.pyc
logs/ logs/
.venv/ .venv/
weight/

View File

@@ -63,6 +63,10 @@ class Config:
def DOCUMENTS_URL(self): def DOCUMENTS_URL(self):
return self._DOCUMENTS_URL return self._DOCUMENTS_URL
@property
def CTRANSLATE2_WIGHTS(self):
    """Return the CTranslate2 weight table stored in ``_CTRANSLATE2_WIGHTS``.

    Read-only: no setter is defined for this property in the visible code.
    """
    weight_table = self._CTRANSLATE2_WIGHTS
    return weight_table
@property @property
def MAX_MIC_ENERGY_THRESHOLD(self): def MAX_MIC_ENERGY_THRESHOLD(self):
return self._MAX_MIC_ENERGY_THRESHOLD return self._MAX_MIC_ENERGY_THRESHOLD
@@ -447,6 +451,17 @@ class Config:
self._AUTH_KEYS[key] = value self._AUTH_KEYS[key] = value
saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, self.AUTH_KEYS) saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, self.AUTH_KEYS)
@property
@json_serializable('WEIGHT_TYPE')
def WEIGHT_TYPE(self):
    """Return the currently selected weight type held in ``_WEIGHT_TYPE``."""
    return self._WEIGHT_TYPE

@WEIGHT_TYPE.setter
def WEIGHT_TYPE(self, value):
    """Store ``value`` as the weight type and persist it with ``saveJson``.

    Values that are not ``str`` are ignored: nothing is stored or saved.
    """
    if not isinstance(value, str):
        return
    self._WEIGHT_TYPE = value
    # The running frame's co_name is this function's own name ("WEIGHT_TYPE"),
    # which is passed to saveJson as the entry name.
    setting_name = inspect.currentframe().f_code.co_name
    saveJson(self.PATH_CONFIG, setting_name, value)
@property @property
@json_serializable('MESSAGE_FORMAT') @json_serializable('MESSAGE_FORMAT')
def MESSAGE_FORMAT(self): def MESSAGE_FORMAT(self):
@@ -537,6 +552,18 @@ class Config:
self._GITHUB_URL = "https://api.github.com/repos/misyaguziya/VRCT/releases/latest" self._GITHUB_URL = "https://api.github.com/repos/misyaguziya/VRCT/releases/latest"
self._BOOTH_URL = "https://misyaguziya.booth.pm/" self._BOOTH_URL = "https://misyaguziya.booth.pm/"
self._DOCUMENTS_URL = "https://mzsoftware.notion.site/VRCT-Documents-be79b7a165f64442ad8f326d86c22246" self._DOCUMENTS_URL = "https://mzsoftware.notion.site/VRCT-Documents-be79b7a165f64442ad8f326d86c22246"
self._CTRANSLATE2_WIGHTS = {
"small": { # M2M-100 418M-parameter model
"url": "https://bit.ly/33fM1AO",
"directory_name": "m2m100_418m",
"tokenizer": "facebook/m2m100_418M"
},
"large": { # M2M-100 1.2B-parameter model
"url": "https://bit.ly/3GYiaed",
"directory_name": "m2m100_12b",
"tokenizer": "facebook/m2m100_12b"
},
}
self._MAX_MIC_ENERGY_THRESHOLD = 2000 self._MAX_MIC_ENERGY_THRESHOLD = 2000
self._MAX_SPEAKER_ENERGY_THRESHOLD = 4000 self._MAX_SPEAKER_ENERGY_THRESHOLD = 4000
@@ -594,6 +621,7 @@ class Config:
"Bing": None, "Bing": None,
"Google": None, "Google": None,
} }
self.WEIGHT_TYPE = "small"
self._MESSAGE_FORMAT = "[message]([translation])" self._MESSAGE_FORMAT = "[message]([translation])"
self._ENABLE_AUTO_CLEAR_MESSAGE_BOX = True self._ENABLE_AUTO_CLEAR_MESSAGE_BOX = True
self._ENABLE_NOTICE_XSOVERLAY = False self._ENABLE_NOTICE_XSOVERLAY = False

View File

@@ -1,3 +1,4 @@
import tempfile
from zipfile import ZipFile from zipfile import ZipFile
from subprocess import Popen from subprocess import Popen
from os import makedirs as os_makedirs from os import makedirs as os_makedirs
@@ -9,9 +10,10 @@ from logging import getLogger, FileHandler, Formatter, INFO
from time import sleep from time import sleep
from queue import Queue from queue import Queue
from threading import Thread, Event from threading import Thread, Event
from requests import get as requests_get from requests import get as requests_get, head as requests_head
import webbrowser import webbrowser
from tqdm import tqdm
from flashtext import KeywordProcessor from flashtext import KeywordProcessor
from models.translation.translation_translator import Translator from models.translation.translation_translator import Translator
from models.transcription.transcription_utils import getInputDevices, getDefaultOutputDevice from models.transcription.transcription_utils import getInputDevices, getDefaultOutputDevice
@@ -70,7 +72,8 @@ class Model:
self.speaker_audio_recorder = None self.speaker_audio_recorder = None
self.speaker_energy_recorder = None self.speaker_energy_recorder = None
self.speaker_energy_plot_progressbar = None self.speaker_energy_plot_progressbar = None
self.translator = Translator(config.PATH_LOCAL) self.downloadCTranslate2Weight()
self.translator = Translator(config.PATH_LOCAL, config.CTRANSLATE2_WIGHTS[config.WEIGHT_TYPE])
self.keyword_processor = KeywordProcessor() self.keyword_processor = KeywordProcessor()
def resetTranslator(self): def resetTranslator(self):
@@ -106,6 +109,38 @@ class Model:
self.logger.disabled = True self.logger.disabled = True
self.logger = None self.logger = None
@staticmethod
def downloadCTranslate2Weight():
    """Fetch and unpack the CTranslate2 weights selected by ``config.WEIGHT_TYPE``.

    Returns immediately when ``model.bin``, ``sentencepiece.model`` and
    ``shared_vocabulary.txt`` all exist under
    ``<config.PATH_LOCAL>/weight/<directory_name>``. Otherwise the archive at
    the configured URL is streamed into a temporary directory and extracted
    into ``<config.PATH_LOCAL>/weight``.

    Best effort: any failure during download or extraction is printed and
    swallowed, so this function never raises for those errors.
    """
    weight_config = config.CTRANSLATE2_WIGHTS[config.WEIGHT_TYPE]
    url = weight_config["url"]
    filename = 'weight.zip'
    weight_root = os_path.join(config.PATH_LOCAL, 'weight')
    model_directory = os_path.join(weight_root, weight_config["directory_name"])
    files = ["model.bin", "sentencepiece.model", "shared_vocabulary.txt"]

    # check already downloaded
    if all(os_path.exists(os_path.join(model_directory, file)) for file in files):
        return

    try:
        os_makedirs(weight_root, exist_ok=True)
        print(weight_root)
        with tempfile.TemporaryDirectory() as tmp_path:
            archive_path = os_path.join(tmp_path, filename)
            # Take the size from the GET response itself: it follows the
            # short-link redirect, whereas a HEAD request does not follow
            # redirects by default and would report the redirect's size.
            with requests_get(url, stream=True, timeout=30) as res:
                # Fail early instead of saving an HTTP error page as a zip.
                res.raise_for_status()
                content_length = res.headers.get("content-length")
                if content_length is not None and content_length.isdigit():
                    file_size = int(content_length)
                else:
                    # Unknown size: tqdm then shows a counter without a total.
                    file_size = None
                with open(archive_path, 'wb') as file:
                    with tqdm(total=file_size, unit="B", unit_scale=True) as pbar:
                        for chunk in res.iter_content(chunk_size=1024 * 1024):
                            file.write(chunk)
                            pbar.update(len(chunk))
            with ZipFile(archive_path) as zf:
                zf.extractall(weight_root)
    except Exception as e:
        print("error:downloadCTranslate2Weight()", e)
@staticmethod @staticmethod
def getListLanguageAndCountry(): def getListLanguageAndCountry():
langs = [] langs = []

View File

@@ -4,7 +4,6 @@ from deepl_translate import translate as deepl_web_Translator
from translators import translate_text as other_web_Translator from translators import translate_text as other_web_Translator
from .translation_languages import translation_lang from .translation_languages import translation_lang
from ctranslate2.converters import TransformersConverter
import ctranslate2 import ctranslate2
import transformers import transformers
@@ -15,11 +14,13 @@ TRANSLATE_MODELS = {
# Translator # Translator
class Translator(): class Translator():
def __init__(self, path): def __init__(self, path, weight_config):
self.translator_status = {} self.translator_status = {}
self.weight_path = os.path.join(path, "weight") directory_name = weight_config["directory_name"]
tokenizer = weight_config["tokenizer"]
self.weight_path = os.path.join(path, "weight", directory_name)
self.translator = ctranslate2.Translator(self.weight_path, device="cpu", device_index=0, compute_type="int8", inter_threads=1, intra_threads=4) self.translator = ctranslate2.Translator(self.weight_path, device="cpu", device_index=0, compute_type="int8", inter_threads=1, intra_threads=4)
self.tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/m2m100_418M") self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer)
def authentication(self, translator_name, authkey=None): def authentication(self, translator_name, authkey=None):
result = True result = True