From 10b8d115a118f3cfeaf400af76186115c084950b Mon Sep 17 00:00:00 2001 From: misyaguziya Date: Wed, 31 Jan 2024 22:50:31 +0900 Subject: [PATCH] =?UTF-8?q?[WIP/TEST]=20faster-whisper=20model=20weight=20?= =?UTF-8?q?=E3=81=AE=E3=83=80=E3=82=A6=E3=83=B3=E3=83=AD=E3=83=BC=E3=83=89?= =?UTF-8?q?/=E3=83=99=E3=83=AA=E3=83=95=E3=82=A1=E3=82=A4=E5=87=A6?= =?UTF-8?q?=E7=90=86=E3=82=92=E5=AE=9F=E8=A3=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 8 +- model.py | 4 +- .../transcription_transcriber.py | 13 +-- models/transcription/transcription_utils.py | 40 +------- models/transcription/transcription_whisper.py | 98 +++++++++++++++++++ 5 files changed, 111 insertions(+), 52 deletions(-) create mode 100644 models/transcription/transcription_whisper.py diff --git a/main.py b/main.py index cf80e289..4aaa7232 100644 --- a/main.py +++ b/main.py @@ -11,8 +11,14 @@ if __name__ == "__main__": from models.translation.utils import downloadCTranslate2Weight if config.USE_TRANSLATION_FEATURE is True: downloadCTranslate2Weight(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE, splash.updateDownloadProgress) - splash.toProgress(0) + + # whisperのダウンロードの説明に変更する必要あり + if config.USE_RECOGNIZER_FEATURE is True: + from models.transcription.transcription_whisper import downloadWhisperWeight + downloadWhisperWeight(config.PATH_LOCAL, config.WHISPER_WEIGHT_TYPE, splash.updateDownloadProgress) + splash.toProgress(0) + import controller controller.createMainWindow(splash) splash.destroySplash() diff --git a/model.py b/model.py index 61ff24d7..6b73bece 100644 --- a/model.py +++ b/model.py @@ -337,7 +337,7 @@ class Model: max_phrases=config.INPUT_MIC_MAX_PHRASES, whisper_enabled=config.USE_RECOGNIZER_FEATURE, whisper_weight_type=config.WHISPER_WEIGHT_TYPE, - whisper_weight_path=os_path.join(config.PATH_LOCAL, "weight", "whisper"), + root=config.PATH_LOCAL, ) def sendMicTranscript(): mic_transcriber.transcribeAudioQueue(config.SELECTED_RECOGNIZER, mic_audio_queue, config.SOURCE_LANGUAGE, config.SOURCE_COUNTRY) @@ -421,7 +421,7 @@ class Model: max_phrases=config.INPUT_SPEAKER_MAX_PHRASES, whisper_enabled=config.USE_RECOGNIZER_FEATURE, whisper_weight_type=config.WHISPER_WEIGHT_TYPE, - whisper_weight_path=os_path.join(config.PATH_LOCAL, "weight", "whisper"), + root=config.PATH_LOCAL, ) def sendSpeakerTranscript(): speaker_transcriber.transcribeAudioQueue(speaker_audio_queue, config.TARGET_LANGUAGE, config.TARGET_COUNTRY) diff --git a/models/transcription/transcription_transcriber.py b/models/transcription/transcription_transcriber.py index 526c12dc..0f5b1790 100644 --- a/models/transcription/transcription_transcriber.py +++ b/models/transcription/transcription_transcriber.py @@ -5,16 +5,16 @@ from speech_recognition import Recognizer, AudioData, AudioFile from datetime import timedelta from pyaudiowpatch import get_sample_size, paInt16 from .transcription_languages import transcription_lang +from .transcription_whisper import getWhisperModel import torch import numpy as np -from faster_whisper import WhisperModel PHRASE_TIMEOUT = 3 MAX_PHRASES = 10 class AudioTranscriber: - def __init__(self, speaker, source, phrase_timeout, max_phrases, whisper_enabled, whisper_weight_type, whisper_weight_path): + def __init__(self, speaker, source, phrase_timeout, max_phrases, whisper_enabled, whisper_weight_type, root): self.speaker = speaker self.phrase_timeout = phrase_timeout self.max_phrases = max_phrases @@ -31,14 +31,7 @@ class AudioTranscriber: "process_data_func": self.processSpeakerData if speaker else self.processSpeakerData } if whisper_enabled is True: - self.whisper_model = WhisperModel( - model_size_or_path=whisper_weight_type, - device="cpu", - device_index=0, - compute_type="int8", - cpu_threads=4, - num_workers=1, - download_root=whisper_weight_path) + self.whisper_model = getWhisperModel(root, whisper_weight_type) else: self.whisper_model = None diff --git a/models/transcription/transcription_utils.py b/models/transcription/transcription_utils.py index 8de17e7e..f40defeb 100644 --- a/models/transcription/transcription_utils.py +++ b/models/transcription/transcription_utils.py @@ -1,8 +1,4 @@ from pyaudiowpatch import PyAudio, paWASAPI -from faster_whisper.utils import download_model -import logging -logger = logging.getLogger('faster_whisper') -logger.setLevel(logging.CRITICAL) def getInputDevices(): devices = {} @@ -48,38 +44,4 @@ def getDefaultOutputDevice(): if default_speakers["name"] in loopback["name"]: default_device = loopback return default_device - return {"name":"NoDevice"} - -def downloadWhisperWeight(weight_type, path): - result = False - try: - download_model( - weight_type, - cache_dir=path) - result = True - except Exception: - pass - return result - -def checkWhisperWeight(weight_type, path): - result = False - try: - result = download_model( - weight_type, - local_files_only=True, - cache_dir=path) - result = True - except Exception: - pass - return result - -if __name__ == "__main__": - - - downloadWhisperWeight("base", "./weight/whisper/") - - from faster_whisper import WhisperModel - whisper_model = WhisperModel("base", device="cpu", device_index=0, compute_type="int8", cpu_threads=4, num_workers=1, download_root="./weight/whisper/") - - print(checkWhisperWeight("base", "./weight/whisper/")) - print(checkWhisperWeight("tiny", "./weight/whisper/")) \ No newline at end of file + return {"name":"NoDevice"} \ No newline at end of file diff --git a/models/transcription/transcription_whisper.py b/models/transcription/transcription_whisper.py new file mode 100644 index 00000000..dc606cb7 --- /dev/null +++ b/models/transcription/transcription_whisper.py @@ -0,0 +1,98 @@ +from os import path as os_path, makedirs as os_makedirs +from requests import get as requests_get +from typing import Callable +import huggingface_hub +from faster_whisper import WhisperModel +import logging +logger = logging.getLogger('faster_whisper') +logger.setLevel(logging.CRITICAL) + +_MODELS = { + "tiny.en": "Systran/faster-whisper-tiny.en", + "tiny": "Systran/faster-whisper-tiny", + "base.en": "Systran/faster-whisper-base.en", + "base": "Systran/faster-whisper-base", + "small.en": "Systran/faster-whisper-small.en", + "small": "Systran/faster-whisper-small", + "medium.en": "Systran/faster-whisper-medium.en", + "medium": "Systran/faster-whisper-medium", + "large-v1": "Systran/faster-whisper-large-v1", + "large-v2": "Systran/faster-whisper-large-v2", + "large-v3": "Systran/faster-whisper-large-v3", + "large": "Systran/faster-whisper-large-v3", +} + +_FILENAMES = [ + "config.json", + "preprocessor_config.json", + "model.bin", + "tokenizer.json", + "vocabulary.txt", +] + +def downloadFile(url, path, func=None): + try: + res = requests_get(url, stream=True) + res.raise_for_status() + file_size = int(res.headers.get('content-length', 0)) + total_chunk = 0 + with open(os_path.join(path), 'wb') as file: + for chunk in res.iter_content(chunk_size=1024*5): + file.write(chunk) + if isinstance(func, Callable): + total_chunk += len(chunk) + func(total_chunk/file_size) + + except Exception as e: + print("error:downloadFile()", e) + +def checkWhisperWeight(path): + result = False + try: + WhisperModel( + path, + device="cpu", + device_index=0, + compute_type="int8", + cpu_threads=4, + num_workers=1, + local_files_only=True, + ) + result = True + except Exception: + pass + return result + +def downloadWhisperWeight(root, weight_type, callbackFunc): + path = os_path.join(root, "weight", "whisper", weight_type) + os_makedirs(path, exist_ok=True) + if checkWhisperWeight(path) is True: + return + + for filename in _FILENAMES: + print("Downloading", filename, "...") + file_path = os_path.join(path, filename) + url = huggingface_hub.hf_hub_url(_MODELS[weight_type], filename) + downloadFile(url, file_path, func=callbackFunc) + +def getWhisperModel(root, weight_type): + path = os_path.join(root, "weight", "whisper", weight_type) + return WhisperModel( + path, + device="cpu", + device_index=0, + compute_type="int8", + cpu_threads=4, + num_workers=1, + local_files_only=True, + ) + +if __name__ == "__main__": + def callback(value): + print(value) + + downloadWhisperWeight("./", "tiny", callback) + downloadWhisperWeight("./", "base", callback) + downloadWhisperWeight("./", "small", callback) + downloadWhisperWeight("./", "medium", callback) + downloadWhisperWeight("./", "large", callback) \ No newline at end of file