👍[Update] Model : CTranslate2のweightデータを起動時に取得するように変更
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -5,4 +5,5 @@ memo.txt
|
|||||||
VRCT.spec
|
VRCT.spec
|
||||||
*.pyc
|
*.pyc
|
||||||
logs/
|
logs/
|
||||||
.venv/
|
.venv/
|
||||||
|
weight/
|
||||||
28
config.py
28
config.py
@@ -63,6 +63,10 @@ class Config:
|
|||||||
def DOCUMENTS_URL(self):
|
def DOCUMENTS_URL(self):
|
||||||
return self._DOCUMENTS_URL
|
return self._DOCUMENTS_URL
|
||||||
|
|
||||||
|
@property
|
||||||
|
def CTRANSLATE2_WIGHTS(self):
|
||||||
|
return self._CTRANSLATE2_WIGHTS
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def MAX_MIC_ENERGY_THRESHOLD(self):
|
def MAX_MIC_ENERGY_THRESHOLD(self):
|
||||||
return self._MAX_MIC_ENERGY_THRESHOLD
|
return self._MAX_MIC_ENERGY_THRESHOLD
|
||||||
@@ -447,6 +451,17 @@ class Config:
|
|||||||
self._AUTH_KEYS[key] = value
|
self._AUTH_KEYS[key] = value
|
||||||
saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, self.AUTH_KEYS)
|
saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, self.AUTH_KEYS)
|
||||||
|
|
||||||
|
@property
|
||||||
|
@json_serializable('WEIGHT_TYPE')
|
||||||
|
def WEIGHT_TYPE(self):
|
||||||
|
return self._WEIGHT_TYPE
|
||||||
|
|
||||||
|
@WEIGHT_TYPE.setter
|
||||||
|
def WEIGHT_TYPE(self, value):
|
||||||
|
if isinstance(value, str):
|
||||||
|
self._WEIGHT_TYPE = value
|
||||||
|
saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, value)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@json_serializable('MESSAGE_FORMAT')
|
@json_serializable('MESSAGE_FORMAT')
|
||||||
def MESSAGE_FORMAT(self):
|
def MESSAGE_FORMAT(self):
|
||||||
@@ -537,6 +552,18 @@ class Config:
|
|||||||
self._GITHUB_URL = "https://api.github.com/repos/misyaguziya/VRCT/releases/latest"
|
self._GITHUB_URL = "https://api.github.com/repos/misyaguziya/VRCT/releases/latest"
|
||||||
self._BOOTH_URL = "https://misyaguziya.booth.pm/"
|
self._BOOTH_URL = "https://misyaguziya.booth.pm/"
|
||||||
self._DOCUMENTS_URL = "https://mzsoftware.notion.site/VRCT-Documents-be79b7a165f64442ad8f326d86c22246"
|
self._DOCUMENTS_URL = "https://mzsoftware.notion.site/VRCT-Documents-be79b7a165f64442ad8f326d86c22246"
|
||||||
|
self._CTRANSLATE2_WIGHTS = {
|
||||||
|
"small": { # M2M-100 418M-parameter model
|
||||||
|
"url": "https://bit.ly/33fM1AO",
|
||||||
|
"directory_name": "m2m100_418m",
|
||||||
|
"tokenizer": "facebook/m2m100_418M"
|
||||||
|
},
|
||||||
|
"large": { # M2M-100 1.2B-parameter model
|
||||||
|
"url": "https://bit.ly/3GYiaed",
|
||||||
|
"directory_name": "m2m100_12b",
|
||||||
|
"tokenizer": "facebook/m2m100_12b"
|
||||||
|
},
|
||||||
|
}
|
||||||
self._MAX_MIC_ENERGY_THRESHOLD = 2000
|
self._MAX_MIC_ENERGY_THRESHOLD = 2000
|
||||||
self._MAX_SPEAKER_ENERGY_THRESHOLD = 4000
|
self._MAX_SPEAKER_ENERGY_THRESHOLD = 4000
|
||||||
|
|
||||||
@@ -594,6 +621,7 @@ class Config:
|
|||||||
"Bing": None,
|
"Bing": None,
|
||||||
"Google": None,
|
"Google": None,
|
||||||
}
|
}
|
||||||
|
self.WEIGHT_TYPE = "small"
|
||||||
self._MESSAGE_FORMAT = "[message]([translation])"
|
self._MESSAGE_FORMAT = "[message]([translation])"
|
||||||
self._ENABLE_AUTO_CLEAR_MESSAGE_BOX = True
|
self._ENABLE_AUTO_CLEAR_MESSAGE_BOX = True
|
||||||
self._ENABLE_NOTICE_XSOVERLAY = False
|
self._ENABLE_NOTICE_XSOVERLAY = False
|
||||||
|
|||||||
39
model.py
39
model.py
@@ -1,3 +1,4 @@
|
|||||||
|
import tempfile
|
||||||
from zipfile import ZipFile
|
from zipfile import ZipFile
|
||||||
from subprocess import Popen
|
from subprocess import Popen
|
||||||
from os import makedirs as os_makedirs
|
from os import makedirs as os_makedirs
|
||||||
@@ -9,9 +10,10 @@ from logging import getLogger, FileHandler, Formatter, INFO
|
|||||||
from time import sleep
|
from time import sleep
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from threading import Thread, Event
|
from threading import Thread, Event
|
||||||
from requests import get as requests_get
|
from requests import get as requests_get, head as requests_head
|
||||||
import webbrowser
|
import webbrowser
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
from flashtext import KeywordProcessor
|
from flashtext import KeywordProcessor
|
||||||
from models.translation.translation_translator import Translator
|
from models.translation.translation_translator import Translator
|
||||||
from models.transcription.transcription_utils import getInputDevices, getDefaultOutputDevice
|
from models.transcription.transcription_utils import getInputDevices, getDefaultOutputDevice
|
||||||
@@ -70,7 +72,8 @@ class Model:
|
|||||||
self.speaker_audio_recorder = None
|
self.speaker_audio_recorder = None
|
||||||
self.speaker_energy_recorder = None
|
self.speaker_energy_recorder = None
|
||||||
self.speaker_energy_plot_progressbar = None
|
self.speaker_energy_plot_progressbar = None
|
||||||
self.translator = Translator(config.PATH_LOCAL)
|
self.downloadCTranslate2Weight()
|
||||||
|
self.translator = Translator(config.PATH_LOCAL, config.CTRANSLATE2_WIGHTS[config.WEIGHT_TYPE])
|
||||||
self.keyword_processor = KeywordProcessor()
|
self.keyword_processor = KeywordProcessor()
|
||||||
|
|
||||||
def resetTranslator(self):
|
def resetTranslator(self):
|
||||||
@@ -106,6 +109,38 @@ class Model:
|
|||||||
self.logger.disabled = True
|
self.logger.disabled = True
|
||||||
self.logger = None
|
self.logger = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def downloadCTranslate2Weight():
|
||||||
|
weight_type = config.WEIGHT_TYPE
|
||||||
|
url = config.CTRANSLATE2_WIGHTS[weight_type]["url"]
|
||||||
|
filename = 'weight.zip'
|
||||||
|
directory_name = 'weight'
|
||||||
|
current_directory = config.PATH_LOCAL
|
||||||
|
weight_directory_name = config.CTRANSLATE2_WIGHTS[weight_type]["directory_name"]
|
||||||
|
files = ["model.bin", "sentencepiece.model", "shared_vocabulary.txt"]
|
||||||
|
|
||||||
|
# check already downloaded
|
||||||
|
if all(os_path.exists(os_path.join(current_directory, directory_name, weight_directory_name, file)) for file in files):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
os_makedirs(os_path.join(current_directory, directory_name), exist_ok=True)
|
||||||
|
print(os_path.join(current_directory, directory_name))
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_path:
|
||||||
|
file_size = int(requests_head(url).headers["content-length"])
|
||||||
|
res = requests_get(url, stream=True)
|
||||||
|
pbar = tqdm(total=file_size, unit="B", unit_scale=True)
|
||||||
|
with open(os_path.join(tmp_path, filename), 'wb') as file:
|
||||||
|
for chunk in res.iter_content(chunk_size=1024):
|
||||||
|
file.write(chunk)
|
||||||
|
pbar.update(len(chunk))
|
||||||
|
pbar.close()
|
||||||
|
|
||||||
|
with ZipFile(os_path.join(tmp_path, filename)) as zf:
|
||||||
|
zf.extractall(os_path.join(current_directory, directory_name))
|
||||||
|
except Exception as e:
|
||||||
|
print("error:downloadCTranslate2Weight()", e)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def getListLanguageAndCountry():
|
def getListLanguageAndCountry():
|
||||||
langs = []
|
langs = []
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from deepl_translate import translate as deepl_web_Translator
|
|||||||
from translators import translate_text as other_web_Translator
|
from translators import translate_text as other_web_Translator
|
||||||
from .translation_languages import translation_lang
|
from .translation_languages import translation_lang
|
||||||
|
|
||||||
from ctranslate2.converters import TransformersConverter
|
|
||||||
import ctranslate2
|
import ctranslate2
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
@@ -15,11 +14,13 @@ TRANSLATE_MODELS = {
|
|||||||
|
|
||||||
# Translator
|
# Translator
|
||||||
class Translator():
|
class Translator():
|
||||||
def __init__(self, path):
|
def __init__(self, path, weight_config):
|
||||||
self.translator_status = {}
|
self.translator_status = {}
|
||||||
self.weight_path = os.path.join(path, "weight")
|
directory_name = weight_config["directory_name"]
|
||||||
|
tokenizer = weight_config["tokenizer"]
|
||||||
|
self.weight_path = os.path.join(path, "weight", directory_name)
|
||||||
self.translator = ctranslate2.Translator(self.weight_path, device="cpu", device_index=0, compute_type="int8", inter_threads=1, intra_threads=4)
|
self.translator = ctranslate2.Translator(self.weight_path, device="cpu", device_index=0, compute_type="int8", inter_threads=1, intra_threads=4)
|
||||||
self.tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/m2m100_418M")
|
self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer)
|
||||||
|
|
||||||
def authentication(self, translator_name, authkey=None):
|
def authentication(self, translator_name, authkey=None):
|
||||||
result = True
|
result = True
|
||||||
|
|||||||
Reference in New Issue
Block a user