From 1de239549f7dc3c00b55fefb1dc46da35aec2b24 Mon Sep 17 00:00:00 2001 From: misyaguziya Date: Thu, 1 Feb 2024 15:49:17 +0900 Subject: [PATCH] =?UTF-8?q?[WIP/TEST]=20Model=20:=20=E3=83=A2=E3=83=87?= =?UTF-8?q?=E3=83=AB=E3=81=AE=E4=BF=9D=E5=AD=98=E4=BD=8D=E7=BD=AE=E3=81=AE?= =?UTF-8?q?=E5=A4=89=E6=9B=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - speakerの文字起こし処理のバグを修正 --- .gitignore | 2 +- main.py | 5 ++-- model.py | 4 ++-- models/transcription/transcription_whisper.py | 7 ++---- models/translation/translation_translator.py | 8 +++---- .../{utils.py => translation_utils.py} | 24 +++++++++---------- 6 files changed, 23 insertions(+), 27 deletions(-) rename models/translation/{utils.py => translation_utils.py} (78%) diff --git a/.gitignore b/.gitignore index 75c28a41..52825c27 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ VRCT.spec *.pyc logs/ .venv/ -weight/ +weights/ .vscode error.log *.exe diff --git a/main.py b/main.py index 37bc53af..0df15326 100644 --- a/main.py +++ b/main.py @@ -8,14 +8,13 @@ if __name__ == "__main__": splash.showSplash() from config import config - from models.translation.utils import downloadCTranslate2Weight + from models.translation.translation_utils import downloadCTranslate2Weight if config.USE_TRANSLATION_FEATURE is True: downloadCTranslate2Weight(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE, splash.updateDownloadProgress) - splash.toProgress(0) + from models.transcription.transcription_whisper import downloadWhisperWeight # whisperのダウンロードの説明に変更する必要あり if config.USE_WHISPER_FEATURE is True: - from models.transcription.transcription_whisper import downloadWhisperWeight downloadWhisperWeight(config.PATH_LOCAL, config.WHISPER_WEIGHT_TYPE, splash.updateDownloadProgress) splash.toProgress(0) diff --git a/model.py b/model.py index 98d0a896..2c29d4c7 100644 --- a/model.py +++ b/model.py @@ -23,7 +23,7 @@ from models.transcription.transcription_transcriber import AudioTranscriber from models.xsoverlay.notification import xsoverlayForVRCT from models.translation.translation_languages import translation_lang from models.transcription.transcription_languages import transcription_lang -from models.translation.utils import checkCTranslate2Weight +from models.translation.translation_utils import checkCTranslate2Weight from config import config class threadFnc(Thread): @@ -424,7 +424,7 @@ class Model: root=config.PATH_LOCAL, ) def sendSpeakerTranscript(): - speaker_transcriber.transcribeAudioQueue(speaker_audio_queue, config.TARGET_LANGUAGE, config.TARGET_COUNTRY) + speaker_transcriber.transcribeAudioQueue(config.SELECTED_RECOGNIZER, speaker_audio_queue, config.TARGET_LANGUAGE, config.TARGET_COUNTRY) message = speaker_transcriber.getTranscript() try: fnc(message) diff --git a/models/transcription/transcription_whisper.py b/models/transcription/transcription_whisper.py index 67ad61f0..e30fee2d 100644 --- a/models/transcription/transcription_whisper.py +++ b/models/transcription/transcription_whisper.py @@ -60,10 +60,9 @@ def checkWhisperWeight(path): return result def downloadWhisperWeight(root, weight_type, callbackFunc): - path = os_path.join(root, "weight", "whisper", weight_type) + path = os_path.join(root, "weights", "whisper", weight_type) os_makedirs(path, exist_ok=True) if checkWhisperWeight(path) is True: - print("weight_type:", weight_type, checkWhisperWeight(path)) return for filename in _FILENAMES: @@ -72,10 +71,8 @@ def downloadWhisperWeight(root, weight_type, callbackFunc): url = huggingface_hub.hf_hub_url(_MODELS[weight_type], filename) downloadFile(url, file_path, func=callbackFunc) - print("weight_type:", weight_type, checkWhisperWeight(path)) - def getWhisperModel(root, weight_type): - path = os_path.join(root, "weight", "whisper", weight_type) + path = os_path.join(root, "weights", "whisper", weight_type) return WhisperModel( path, device="cpu", diff --git a/models/translation/translation_translator.py b/models/translation/translation_translator.py index ea02e490..c966c672 100644 --- a/models/translation/translation_translator.py +++ b/models/translation/translation_translator.py @@ -2,7 +2,7 @@ import os from deepl import Translator as deepl_Translator from translators import translate_text as other_web_Translator from .translation_languages import translation_lang -from .utils import ctranslate2_weights +from .translation_utils import ctranslate2_weights import ctranslate2 import transformers @@ -27,8 +27,8 @@ class Translator(): def changeCTranslate2Model(self, path, model_type): directory_name = ctranslate2_weights[model_type]["directory_name"] tokenizer = ctranslate2_weights[model_type]["tokenizer"] - weight_path = os.path.join(path, "weight", directory_name) - tokenizer_path = os.path.join(path, "weight", directory_name, "tokenizer") + weight_path = os.path.join(path, "weights", "ctranslate2", directory_name) + tokenizer_path = os.path.join(path, "weights", "ctranslate2", directory_name, "tokenizer") self.ctranslate2_translator = ctranslate2.Translator( weight_path, device="cpu", @@ -41,7 +41,7 @@ class Translator(): self.ctranslate2_tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer, cache_dir=tokenizer_path) except Exception as e: print("Error: changeCTranslate2Model()", e) - tokenizer_path = os.path.join("./weight", directory_name, "tokenizer") + tokenizer_path = os.path.join("./weights", "ctranslate2", directory_name, "tokenizer") self.ctranslate2_tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer, cache_dir=tokenizer_path) @staticmethod diff --git a/models/translation/utils.py b/models/translation/translation_utils.py similarity index 78% rename from models/translation/utils.py rename to models/translation/translation_utils.py index d47401cf..73805cdc 100644 --- a/models/translation/utils.py +++ b/models/translation/translation_utils.py @@ -39,36 +39,36 @@ def calculate_file_hash(file_path, block_size=65536): return hash_object.hexdigest() def checkCTranslate2Weight(path, weight_type="Small"): - directory_name = 'weight' - current_directory = path weight_directory_name = ctranslate2_weights[weight_type]["directory_name"] hash_data = ctranslate2_weights[weight_type]["hash"] - files = ["model.bin", "sentencepiece.model", "shared_vocabulary.txt"] + files = [ + "model.bin", + "sentencepiece.model", + "shared_vocabulary.txt" + ] # check already downloaded already_downloaded = False - if all(os_path.exists(os_path.join(current_directory, directory_name, weight_directory_name, file)) for file in files): + if all(os_path.exists(os_path.join(path, weight_directory_name, file)) for file in files): # check hash for file in files: original_hash = hash_data[file] - current_hash = calculate_file_hash(os_path.join(current_directory, directory_name, weight_directory_name, file)) + current_hash = calculate_file_hash(os_path.join(path, weight_directory_name, file)) if original_hash != current_hash: break already_downloaded = True return already_downloaded -def downloadCTranslate2Weight(path, weight_type="Small", func=None): +def downloadCTranslate2Weight(root, weight_type="Small", func=None): url = ctranslate2_weights[weight_type]["url"] - filename = 'weight.zip' - directory_name = 'weight' - current_directory = path + filename = "weight.zip" + path = os_path.join(root, "weights", "ctranslate2") + os_makedirs(path, exist_ok=True) if checkCTranslate2Weight(path, weight_type): return try: - os_makedirs(os_path.join(current_directory, directory_name), exist_ok=True) - print(os_path.join(current_directory, directory_name)) with tempfile.TemporaryDirectory() as tmp_path: res = requests_get(url, stream=True) file_size = int(res.headers.get('content-length', 0)) @@ -81,6 +81,6 @@ def downloadCTranslate2Weight(path, weight_type="Small", func=None): func(total_chunk/file_size) with ZipFile(os_path.join(tmp_path, filename)) as zf: - zf.extractall(os_path.join(current_directory, directory_name)) + zf.extractall(path) except Exception as e: print("error:downloadCTranslate2Weight()", e) \ No newline at end of file