From 9cd1831ecbb4f313347c90d6971eb3c7a075812b Mon Sep 17 00:00:00 2001 From: misyaguziya Date: Tue, 30 Jan 2024 18:21:55 +0900 Subject: [PATCH] =?UTF-8?q?[WIP/TEST]=20faster-whisper=E3=81=8C=E6=9C=80?= =?UTF-8?q?=E4=BD=8E=E9=99=90=E5=8B=95=E3=81=8F=E5=BD=A2=E3=81=A7=E5=AE=9F?= =?UTF-8?q?=E8=A3=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit config.jsonで設定変更で実行可能 --- config.py | 71 ++- controller.py | 4 +- main.py | 2 +- model.py | 14 +- .../transcription/transcription_languages.py | 443 ++++++++++++++---- .../transcription_transcriber.py | 70 +-- models/transcription/transcription_utils.py | 40 +- view.py | 4 +- 8 files changed, 511 insertions(+), 137 deletions(-) diff --git a/config.py b/config.py index 371ec121..6acf5e3f 100644 --- a/config.py +++ b/config.py @@ -98,6 +98,10 @@ class Config: def SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT(self): return self._SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT + @property + def SELECTABLE_WHISPER_WEIGHT_TYPE_DICT(self): + return self._SELECTABLE_WHISPER_WEIGHT_TYPE_DICT + @property def MAX_MIC_ENERGY_THRESHOLD(self): return self._MAX_MIC_ENERGY_THRESHOLD @@ -263,6 +267,17 @@ class Config: self._SELECTED_TAB_TARGET_LANGUAGES = value saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, value) + @property + @json_serializable('SELECTED_RECOGNIZER') + def SELECTED_RECOGNIZER(self): + return self._SELECTED_RECOGNIZER + + @SELECTED_RECOGNIZER.setter + def SELECTED_RECOGNIZER(self, value): + if isinstance(value, str): + self._SELECTED_RECOGNIZER = value + saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, value) + @property @json_serializable('IS_MAIN_WINDOW_SIDEBAR_COMPACT_MODE') def IS_MAIN_WINDOW_SIDEBAR_COMPACT_MODE(self): @@ -569,15 +584,37 @@ class Config: saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, value) @property - @json_serializable('WEIGHT_TYPE') - def WEIGHT_TYPE(self): - return self._WEIGHT_TYPE + @json_serializable('USE_RECOGNIZER_FEATURE') + def USE_RECOGNIZER_FEATURE(self): + return self._USE_RECOGNIZER_FEATURE - @WEIGHT_TYPE.setter - def WEIGHT_TYPE(self, value): + @USE_RECOGNIZER_FEATURE.setter + def USE_RECOGNIZER_FEATURE(self, value): + if isinstance(value, bool): + self._USE_RECOGNIZER_FEATURE = value + saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, value) + + @property + @json_serializable('CTRANSLATE2_WEIGHT_TYPE') + def CTRANSLATE2_WEIGHT_TYPE(self): + return self._CTRANSLATE2_WEIGHT_TYPE + + @CTRANSLATE2_WEIGHT_TYPE.setter + def CTRANSLATE2_WEIGHT_TYPE(self, value): # if isinstance(value, str) and value in self.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT: if isinstance(value, str): - self._WEIGHT_TYPE = value + self._CTRANSLATE2_WEIGHT_TYPE = value + saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, value) + + @property + @json_serializable('WHISPER_WEIGHT_TYPE') + def WHISPER_WEIGHT_TYPE(self): + return self._WHISPER_WEIGHT_TYPE + + @WHISPER_WEIGHT_TYPE.setter + def WHISPER_WEIGHT_TYPE(self, value): + if isinstance(value, str): + self._WHISPER_WEIGHT_TYPE = value saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, value) @property @@ -756,6 +793,23 @@ class Config: "Small": "Small", "Large": "Large", } + + self._SELECTABLE_WHISPER_WEIGHT_TYPE_DICT = { + # {Save json str}: {i18n_placeholder} pairs + "tiny": "tiny", + "tiny.en": "tiny.en", + "base": "base", + "base.en": "base.en", + "small": "small", + "small.en": "small.en", + "medium": "medium", + "medium.en": "medium.en", + "large-v1": "large-v1", + "large-v2": "large-v2", + "large-v3": "large-v3", + "large": "large", + } + self._MAX_MIC_ENERGY_THRESHOLD = 2000 self._MAX_SPEAKER_ENERGY_THRESHOLD = 4000 @@ -795,6 +849,7 @@ class Config: "2":"English\n(United States)", "3":"English\n(United States)", } + self._SELECTED_RECOGNIZER = "Google" self._IS_MAIN_WINDOW_SIDEBAR_COMPACT_MODE = False ## Config Window @@ -831,7 +886,9 @@ class Config: "DeepL_API": None, } self._USE_TRANSLATION_FEATURE = True - self._WEIGHT_TYPE = "Small" + self._CTRANSLATE2_WEIGHT_TYPE = "Small" + self._USE_RECOGNIZER_FEATURE = True + self._WHISPER_WEIGHT_TYPE = "base" self._SEND_MESSAGE_FORMAT = "[message]" self._SEND_MESSAGE_FORMAT_WITH_T = "[message]([translation])" self._RECEIVED_MESSAGE_FORMAT = "[message]" diff --git a/controller.py b/controller.py index f9e9a5b3..9d44b491 100644 --- a/controller.py +++ b/controller.py @@ -505,8 +505,8 @@ def callbackSetUseTranslationFeature(value): def callbackSetCtranslate2WeightType(value): print("callbackSetCtranslate2WeightType", value) - config.WEIGHT_TYPE = str(value) - view.updateSelectedCtranslate2WeightType(config.WEIGHT_TYPE) + config.CTRANSLATE2_WEIGHT_TYPE = str(value) + view.updateSelectedCtranslate2WeightType(config.CTRANSLATE2_WEIGHT_TYPE) view.setWidgetsStatus_changeWeightType_Pending() if model.checkCTranslatorCTranslate2ModelWeight(): config.IS_RESET_BUTTON_DISPLAYED_FOR_TRANSLATION = False diff --git a/main.py b/main.py index 4810cbe5..cf80e289 100644 --- a/main.py +++ b/main.py @@ -10,7 +10,7 @@ if __name__ == "__main__": from config import config from models.translation.utils import downloadCTranslate2Weight if config.USE_TRANSLATION_FEATURE is True: - downloadCTranslate2Weight(config.PATH_LOCAL, config.WEIGHT_TYPE, splash.updateDownloadProgress) + downloadCTranslate2Weight(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE, splash.updateDownloadProgress) splash.toProgress(0) import controller diff --git a/model.py b/model.py index 573659a7..61ff24d7 100644 --- a/model.py +++ b/model.py @@ -65,14 +65,14 @@ class Model: self.speaker_energy_plot_progressbar = None self.translator = Translator() if config.USE_TRANSLATION_FEATURE is True: - self.translator.changeCTranslate2Model(config.PATH_LOCAL, config.WEIGHT_TYPE) + self.translator.changeCTranslate2Model(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE) self.keyword_processor = KeywordProcessor() def checkCTranslatorCTranslate2ModelWeight(self): - return checkCTranslate2Weight(config.PATH_LOCAL, config.WEIGHT_TYPE) + return checkCTranslate2Weight(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE) def changeTranslatorCTranslate2Model(self): - self.translator.changeCTranslate2Model(config.PATH_LOCAL, config.WEIGHT_TYPE) + self.translator.changeCTranslate2Model(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE) def resetKeywordProcessor(self): del self.keyword_processor @@ -335,9 +335,12 @@ class Model: source=self.mic_audio_recorder.source, phrase_timeout=phase_timeout, max_phrases=config.INPUT_MIC_MAX_PHRASES, + whisper_enabled=config.USE_RECOGNIZER_FEATURE, + whisper_weight_type=config.WHISPER_WEIGHT_TYPE, + whisper_weight_path=os_path.join(config.PATH_LOCAL, "weight", "whisper"), ) def sendMicTranscript(): - mic_transcriber.transcribeAudioQueue(mic_audio_queue, config.SOURCE_LANGUAGE, config.SOURCE_COUNTRY) + mic_transcriber.transcribeAudioQueue(config.SELECTED_RECOGNIZER, mic_audio_queue, config.SOURCE_LANGUAGE, config.SOURCE_COUNTRY) message = mic_transcriber.getTranscript() try: fnc(message) @@ -416,6 +419,9 @@ class Model: source=self.speaker_audio_recorder.source, phrase_timeout=phase_timeout, max_phrases=config.INPUT_SPEAKER_MAX_PHRASES, + whisper_enabled=config.USE_RECOGNIZER_FEATURE, + whisper_weight_type=config.WHISPER_WEIGHT_TYPE, + whisper_weight_path=os_path.join(config.PATH_LOCAL, "weight", "whisper"), ) def sendSpeakerTranscript(): speaker_transcriber.transcribeAudioQueue(speaker_audio_queue, config.TARGET_LANGUAGE, config.TARGET_COUNTRY) diff --git a/models/transcription/transcription_languages.py b/models/transcription/transcription_languages.py index 26f2c3f6..63d92568 100644 --- a/models/transcription/transcription_languages.py +++ b/models/transcription/transcription_languages.py @@ -1,177 +1,438 @@ transcription_lang = { "Afrikaans":{ - "South Africa":"af-ZA", + "South Africa":{ + "Google": "af-ZA", + "Whisper": "af", + }, }, "Arabic":{ - "Algeria":"ar-DZ", - "Bahrain":"ar-BH", - "Egypt":"ar-EG", - "Israel":"ar-IL", - "Iraq":"ar-IQ", - "Jordan":"ar-JO", - "Kuwait":"ar-KW", - "Lebanon":"ar-LB", - "Morocco":"ar-MA", - "Oman":"ar-OM", - "State of Palestine":"ar-PS", - "Qatar":"ar-QA", - "Saudi Arabia":"ar-SA", - "Tunisia":"ar-TN", - "United Arab Emirates":"ar-AE", + "Algeria":{ + "Google": "ar-DZ", + "Whisper": "ar", + }, + "Bahrain":{ + "Google": "ar-BH", + "Whisper": "ar", + }, + "Egypt":{ + "Google": "ar-EG", + "Whisper": "ar", + }, + "Israel":{ + "Google": "ar-IL", + "Whisper": "ar", + }, + "Iraq":{ + "Google": "ar-IQ", + "Whisper": "ar", + }, + "Jordan":{ + "Google": "ar-JO", + "Whisper": "ar", + }, + "Kuwait":{ + "Google": "ar-KW", + "Whisper": "ar", + }, + "Lebanon":{ + "Google": "ar-LB", + "Whisper": "ar", + }, + "Morocco":{ + "Google": "ar-MA", + "Whisper": "ar", + }, + "Oman":{ + "Google": "ar-OM", + "Whisper": "ar", + }, + "State of Palestine":{ + "Google": "ar-PS", + "Whisper": "ar", + }, + "Qatar":{ + "Google": "ar-QA", + "Whisper": "ar", + }, + "Saudi Arabia":{ + "Google": "ar-SA", + "Whisper": "ar", + }, + "Tunisia":{ + "Google": "ar-TN", + "Whisper": "ar", + }, + "United Arab Emirates":{ + "Google": "ar-AE", + "Whisper": "ar", + }, }, "Basque":{ - "Spain":"eu-ES", + "Spain":{ + "Google": "eu-ES", + "Whisper": "eu", + }, }, "Bulgarian":{ - "Bulgaria":"bg-BG", + "Bulgaria":{ + "Google": "bg-BG", + "Whisper": "bg", + }, }, "Catalan":{ - "Spain":"ca-ES", + "Spain":{ + "Google": "ca-ES", + "Whisper": "ca", + }, }, "Chinese":{ - "Mandarin (Simplified, China)":"cmn-Hans-CN", - "Mandarin (Simplified, Hong Kong)":"cmn-Hans-HK", - "Mandarin (Traditional, Taiwan)":"cmn-Hant-TW", - "Cantonese (Traditional Hong Kong)":"yue-Hant-HK", + "Mandarin (Simplified, China)":{ + "Google": "cmn-Hans-CN", + "Whisper": "zh", + }, + "Mandarin (Simplified, Hong Kong)":{ + "Google": "cmn-Hans-HK", + "Whisper": "zh", + }, + "Mandarin (Traditional, Taiwan)":{ + "Google": "cmn-Hant-TW", + "Whisper": "zh", + }, + "Cantonese (Traditional Hong Kong)":{ + "Google": "yue-Hant-HK", + "Whisper": "yue", + }, }, "Croatian":{ - "Croatia":"hr-HR", + "Croatia":{ + "Google": "hr-HR", + "Whisper": "hr", + }, }, "Czech":{ - "Czech Republic":"cs-CZ", + "Czech Republic":{ + "Google": "cs-CZ", + "Whisper": "cs", + }, }, "Danish":{ - "Denmark":"da-DK", + "Denmark":{ + "Google": "da-DK", + "Whisper": "da", + }, }, "Dutch":{ - "Netherlands":"nl-NL", + "Netherlands":{ + "Google": "nl-NL", + "Whisper": "nl", + }, }, "English": { - "United States":"en-US", - "United Kingdom":"en-GB", - "Australia":"en-AU", - "Canada":"en-CA", - "India":"en-IN", - "Ireland":"en-IE", - "New Zealand":"en-NZ", - "Philippines":"en-PH", - "South Africa":"en-ZA", + "United States":{ + "Google": "en-US", + "Whisper": "en", + }, + "United Kingdom":{ + "Google": "en-GB", + "Whisper": "en", + }, + "Australia":{ + "Google": "en-AU", + "Whisper": "en", + }, + "Canada":{ + "Google": "en-CA", + "Whisper": "en", + }, + "India":{ + "Google": "en-IN", + "Whisper": "en", + }, + "Ireland":{ + "Google": "en-IE", + "Whisper": "en", + }, + "New Zealand":{ + "Google": "en-NZ", + "Whisper": "en", + }, + "Philippines":{ + "Google": "en-PH", + "Whisper": "en", + }, + "South Africa":{ + "Google": "en-ZA", + "Whisper": "en", + }, }, "Filipino":{ - "Philippines":"fil-PH", + "Philippines":{ + "Google": "fil-PH", + "Whisper": "tl", + }, }, "Finnish":{ - "Finland":"fi-FI", + "Finland":{ + "Google": "fi-FI", + "Whisper": "fi", + }, }, "French":{ - "France":"fr-FR", + "France":{ + "Google": "fr-FR", + "Whisper": "fr", + }, }, "Galician":{ - "Spain":"gl-ES", + "Spain":{ + "Google": "gl-ES", + "Whisper": "gl", + }, }, "German":{ - "Germany":"de-DE", + "Germany":{ + "Google": "de-DE", + "Whisper": "de", + }, }, "Greek":{ - "Greece":"el-GR", + "Greece":{ + "Google": "el-GR", + "Whisper": "el", + }, }, "Hebrew":{ - "Israel":"he-IL", + "Israel":{ + "Google": "he-IL", + "Whisper": "he", + }, }, "Hindi": { - "India":"hi-IN", + "India":{ + "Google": "hi-IN", + "Whisper": "hi", + }, }, "Hungarian":{ - "Hungary":"hu-HU", + "Hungary":{ + "Google": "hu-HU", + "Whisper": "hu", + }, }, "Indonesian":{ - "Indonesia":"id-ID", + "Indonesia":{ + "Google": "id-ID", + "Whisper": "id", + }, }, "Icelandic":{ - "Iceland":"is-IS", + "Iceland":{ + "Google": "is-IS", + "Whisper": "is", + }, }, "Italian":{ - "Italy":"it-IT", - "Switzerland":"it-CH", + "Italy":{ + "Google": "it-IT", + "Whisper": "it", + }, + "Switzerland":{ + "Google": "it-CH", + "Whisper": "it", + }, }, "Japanese":{ - "Japan":"ja-JP", + "Japan":{ + "Google": "ja-JP", + "Whisper": "ja", + }, }, "Korean":{ - "South Korea":"ko-KR", + "South Korea":{ + "Google": "ko-KR", + "Whisper": "ko", + }, }, "Lithuanian":{ - "Lithuania":"lt-LT", + "Lithuania":{ + "Google": "lt-LT", + "Whisper": "lt", + }, }, "Malay":{ - "Malaysia":"ms-MY", + "Malaysia":{ + "Google": "ms-MY", + "Whisper": "ms", + }, }, "Norwegian":{ - "Norway":"nb-NO", + "Norway":{ + "Google": "nb-NO", + "Whisper": "no", + }, }, "Persian":{ - "Iran":"fa-IR", + "Iran":{ + "Google": "fa-IR", + "Whisper": "fa", + }, }, "Polish":{ - "Poland":"pl-PL", + "Poland":{ + "Google": "pl-PL", + "Whisper": "pl", + }, }, "Portuguese":{ - "Brazil":"pt-BR", - "Portugal":"pt-PT", + "Brazil":{ + "Google": "pt-BR", + "Whisper": "pt", + }, + "Portugal":{ + "Google": "pt-PT", + "Whisper": "pt", + }, }, "Romanian":{ - "Romania":"ro-RO", + "Romania":{ + "Google": "ro-RO", + "Whisper": "ro", + }, }, "Russian":{ - "Russia":"ru-RU", + "Russia":{ + "Google": "ru-RU", + "Whisper": "ru", + }, }, "Serbian":{ - "Serbia":"sr-RS", + "Serbia":{ + "Google": "sr-RS", + "Whisper": "sr", + }, }, "Slovak":{ - "Slovakia":"sk-SK", + "Slovakia":{ + "Google": "sk-SK", + "Whisper": "sk", + }, }, "Slovenian":{ - "Slovenia":"sl-SI", + "Slovenia":{ + "Google": "sl-SI", + "Whisper": "sl", + }, }, "Spanish":{ - "Argentina":"es-AR", - "Bolivia":"es-BO", - "Chile":"es-CL", - "Colombia":"es-CO", - "Costa Rica":"es-CR", - "Dominican Republic":"es-DO", - "Ecuador":"es-EC", - "El Salvador":"es-SV", - "Guatemala":"es-GT", - "Honduras":"es-HN", - "Mexico":"es-MX", - "Nicaragua":"es-NI", - "Panama":"es-PA", - "Paraguay":"es-PY", - "Peru":"es-PE", - "Puerto Rico":"es-PR", - "Spain":"es-ES", - "Uruguay":"es-UY", - "United States":"es-US", - "Venezuela":"es-VE", + "Argentina":{ + "Google": "es-AR", + "Whisper": "es", + }, + "Bolivia":{ + "Google": "es-BO", + "Whisper": "es", + }, + "Chile":{ + "Google": "es-CL", + "Whisper": "es", + }, + "Colombia":{ + "Google": "es-CO", + "Whisper": "es", + }, + "Costa Rica":{ + "Google": "es-CR", + "Whisper": "es", + }, + "Dominican Republic":{ + "Google": "es-DO", + "Whisper": "es", + }, + "Ecuador":{ + "Google": "es-EC", + "Whisper": "es", + }, + "El Salvador":{ + "Google": "es-SV", + "Whisper": "es", + }, + "Guatemala":{ + "Google": "es-GT", + "Whisper": "es", + }, + "Honduras":{ + "Google": "es-HN", + "Whisper": "es", + }, + "Mexico":{ + "Google": "es-MX", + "Whisper": "es", + }, + "Nicaragua":{ + "Google": "es-NI", + "Whisper": "es", + }, + "Panama":{ + "Google": "es-PA", + "Whisper": "es", + }, + "Paraguay":{ + "Google": "es-PY", + "Whisper": "es", + }, + "Peru":{ + "Google": "es-PE", + "Whisper": "es", + }, + "Puerto Rico":{ + "Google": "es-PR", + "Whisper": "es", + }, + "Spain":{ + "Google": "es-ES", + "Whisper": "es", + }, + "Uruguay":{ + "Google": "es-UY", + "Whisper": "es", + }, + "United States":{ + "Google": "es-US", + "Whisper": "es", + }, + "Venezuela":{ + "Google": "es-VE", + "Whisper": "es", + }, }, "Swedish":{ - "Sweden":"sv-SE", + "Sweden":{ + "Google": "sv-SE", + "Whisper": "sv", + }, }, "Thai":{ - "Thailand":"th-TH", + "Thailand":{ + "Google": "th-TH", + "Whisper": "th", + }, }, "Turkish":{ - "Turkey":"tr-TR", + "Turkey":{ + "Google": "tr-TR", + "Whisper": "tr", + }, }, "Ukrainian":{ - "Ukraine":"uk-UA", + "Ukraine":{ + "Google": "uk-UA", + "Whisper": "uk", + }, }, "Vietnamese":{ - "Vietnam":"vi-VN", - }, - "Zulu":{ - "South Africa":"zu-ZA" + "Vietnam":{ + "Google": "vi-VN", + "Whisper": "vi", + }, }, } \ No newline at end of file diff --git a/models/transcription/transcription_transcriber.py b/models/transcription/transcription_transcriber.py index fbea0e74..526c12dc 100644 --- a/models/transcription/transcription_transcriber.py +++ b/models/transcription/transcription_transcriber.py @@ -14,7 +14,7 @@ PHRASE_TIMEOUT = 3 MAX_PHRASES = 10 class AudioTranscriber: - def __init__(self, speaker, source, phrase_timeout, max_phrases): + def __init__(self, speaker, source, phrase_timeout, max_phrases, whisper_enabled, whisper_weight_type, whisper_weight_path): self.speaker = speaker self.phrase_timeout = phrase_timeout self.max_phrases = max_phrases @@ -30,47 +30,59 @@ class AudioTranscriber: "new_phrase": True, "process_data_func": self.processSpeakerData if speaker else self.processSpeakerData } - self.whisper_model = WhisperModel("base", device="cpu", device_index=0, compute_type="int8", cpu_threads=4, num_workers=1) + if whisper_enabled is True: + self.whisper_model = WhisperModel( + model_size_or_path=whisper_weight_type, + device="cpu", + device_index=0, + compute_type="int8", + cpu_threads=4, + num_workers=1, + download_root=whisper_weight_path) + else: + self.whisper_model = None - def transcribeAudioQueue(self, audio_queue, language, country): + def transcribeAudioQueue(self, recognizer, audio_queue, language, country): # while True: audio, time_spoken = audio_queue.get() self.updateLastSampleAndPhraseStatus(audio, time_spoken) text = '' try: - # fd, path = tempfile.mkstemp(suffix=".wav") - # os.close(fd) - audio_data = self.audio_sources["process_data_func"]() - text = self.audio_recognizer.recognize_google(audio_data, language=transcription_lang[language][country]) + # Whisperが使用できない場合はGoogle Speech-to-Textを使用する + if recognizer == "Whisper": + if self.whisper_model is None: + recognizer = "Google" - audio_data = np.frombuffer(audio_data.get_raw_data(convert_rate=16000, convert_width=2), np.int16).flatten().astype(np.float32) / 32768.0 - if isinstance(audio_data, torch.Tensor): - audio_data = audio_data.detach().numpy() - segments, _ = self.whisper_model.transcribe( - audio_data, - beam_size=5, - temperature=0.0, - log_prob_threshold=-0.8, - no_speech_threshold=0.6, - language="ja", - word_timestamps=False, - without_timestamps=True, - task="transcribe", - vad_filter=False, - ) - _text = "" - for s in segments: - if s.avg_logprob < -0.8 or s.no_speech_prob > 0.6: - continue - _text += s.text - print(_text) + audio_data = self.audio_sources["process_data_func"]() + match recognizer: + case "Google": + text = self.audio_recognizer.recognize_google(audio_data, language=transcription_lang[language][country][recognizer]) + case "Whisper": + audio_data = np.frombuffer(audio_data.get_raw_data(convert_rate=16000, convert_width=2), np.int16).flatten().astype(np.float32) / 32768.0 + if isinstance(audio_data, torch.Tensor): + audio_data = audio_data.detach().numpy() + segments, _ = self.whisper_model.transcribe( + audio_data, + beam_size=5, + temperature=0.0, + log_prob_threshold=-0.8, + no_speech_threshold=0.6, + language=transcription_lang[language][country][recognizer], + word_timestamps=False, + without_timestamps=True, + task="transcribe", + vad_filter=False, + ) + for s in segments: + if s.avg_logprob < -0.8 or s.no_speech_prob > 0.6: + continue + text += s.text except Exception: pass finally: pass - # os.unlink(path) if text != '': self.updateTranscript(text) diff --git a/models/transcription/transcription_utils.py b/models/transcription/transcription_utils.py index f40defeb..8de17e7e 100644 --- a/models/transcription/transcription_utils.py +++ b/models/transcription/transcription_utils.py @@ -1,4 +1,8 @@ from pyaudiowpatch import PyAudio, paWASAPI +from faster_whisper.utils import download_model +import logging +logger = logging.getLogger('faster_whisper') +logger.setLevel(logging.CRITICAL) def getInputDevices(): devices = {} @@ -44,4 +48,38 @@ def getDefaultOutputDevice(): if default_speakers["name"] in loopback["name"]: default_device = loopback return default_device - return {"name":"NoDevice"} \ No newline at end of file + return {"name":"NoDevice"} + +def downloadWhisperWeight(weight_type, path): + result = False + try: + download_model( + weight_type, + cache_dir=path) + result = True + except Exception: + pass + return result + +def checkWhisperWeight(weight_type, path): + result = False + try: + result = download_model( + weight_type, + local_files_only=True, + cache_dir=path) + result = True + except Exception: + pass + return result + +if __name__ == "__main__": + + + downloadWhisperWeight("base", "./weight/whisper/") + + from faster_whisper import WhisperModel + whisper_model = WhisperModel("base", device="cpu", device_index=0, compute_type="int8", cpu_threads=4, num_workers=1, download_root="./weight/whisper/") + + print(checkWhisperWeight("base", "./weight/whisper/")) + print(checkWhisperWeight("tiny", "./weight/whisper/")) \ No newline at end of file diff --git a/view.py b/view.py index 34711688..cf90dcfa 100644 --- a/view.py +++ b/view.py @@ -280,7 +280,7 @@ class View(): VAR_DESC_CTRANSLATE2_WEIGHT_TYPE=StringVar(value=i18n.t("config_window.ctranslate2_weight_type.desc")), DICT_CTRANSLATE2_WEIGHT_TYPE=self.getSelectableCtranslate2WeightTypeDict(), CALLBACK_SET_CTRANSLATE2_WEIGHT_TYPE=None, - VAR_CTRANSLATE2_WEIGHT_TYPE=StringVar(value=self.getSelectableCtranslate2WeightTypeDict()[config.WEIGHT_TYPE]), + VAR_CTRANSLATE2_WEIGHT_TYPE=StringVar(value=self.getSelectableCtranslate2WeightTypeDict()[config.CTRANSLATE2_WEIGHT_TYPE]), VAR_LABEL_DEEPL_AUTH_KEY=StringVar(value=i18n.t( "config_window.deepl_auth_key.label")), VAR_DESC_DEEPL_AUTH_KEY=StringVar( @@ -1069,7 +1069,7 @@ class View(): self.view_variable.VAR_CTRANSLATE2_WEIGHT_TYPE.set(self.getSelectableCtranslate2WeightTypeDict()[selected_weight_type]) def setLatestCTranslate2WeightType(self): - selected_weight_type = self.getSelectableCtranslate2WeightTypeDict()[config.WEIGHT_TYPE] + selected_weight_type = self.getSelectableCtranslate2WeightTypeDict()[config.CTRANSLATE2_WEIGHT_TYPE] self.view_variable.VAR_CTRANSLATE2_WEIGHT_TYPE.set(selected_weight_type)