From af3fe1f0f9523e73d5b9469486da19f75cbbf452 Mon Sep 17 00:00:00 2001 From: misyaguziya <53165965+misyaguziya@users.noreply.github.com> Date: Wed, 23 Oct 2024 13:41:34 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=91=8D=EF=B8=8F[Update]=20Model:=20cpu/cu?= =?UTF-8?q?da=E3=82=92translation=E3=82=82=E3=81=97=E3=81=8F=E3=81=AFtrans?= =?UTF-8?q?cription=E3=81=A7=E9=81=B8=E6=8A=9E=E3=81=A7=E3=81=8D=E3=82=8B?= =?UTF-8?q?=E3=82=88=E3=81=86=E3=81=AB=E5=AE=9F=E8=A3=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src-python/config.py | 26 +++++++++++++++++- src-python/model.py | 10 ++++++- .../transcription_transcriber.py | 4 +-- .../transcription/transcription_whisper.py | 12 +++++---- .../translation/translation_translator.py | 10 ++++--- src-python/webui_controller.py | 27 +++++++++++++++++++ src-python/webui_mainloop.py | 8 ++++++ 7 files changed, 84 insertions(+), 13 deletions(-) diff --git a/src-python/config.py b/src-python/config.py index 54d8b720..38459a95 100644 --- a/src-python/config.py +++ b/src-python/config.py @@ -724,6 +724,28 @@ class Config: self._USE_WHISPER_FEATURE = value saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, value) + @property + @json_serializable('SELECTED_TRANSLATION_COMPUTE_DEVICE') + def SELECTED_TRANSLATION_COMPUTE_DEVICE(self): + return self._SELECTED_TRANSLATION_COMPUTE_DEVICE + + @SELECTED_TRANSLATION_COMPUTE_DEVICE.setter + def SELECTED_TRANSLATION_COMPUTE_DEVICE(self, value): + if isinstance(value, dict): + self._SELECTED_TRANSLATION_COMPUTE_DEVICE = value + saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, value) + + @property + @json_serializable('SELECTED_TRANSCRIPTION_COMPUTE_DEVICE') + def SELECTED_TRANSCRIPTION_COMPUTE_DEVICE(self): + return self._SELECTED_TRANSCRIPTION_COMPUTE_DEVICE + + @SELECTED_TRANSCRIPTION_COMPUTE_DEVICE.setter + def SELECTED_TRANSCRIPTION_COMPUTE_DEVICE(self, value): + if isinstance(value, dict): + 
self._SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = value + saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, value) + @property @json_serializable('CTRANSLATE2_WEIGHT_TYPE') def CTRANSLATE2_WEIGHT_TYPE(self): @@ -1105,8 +1127,10 @@ class Config: } self._USE_EXCLUDE_WORDS = True self._USE_TRANSLATION_FEATURE = True - self._CTRANSLATE2_WEIGHT_TYPE = "Small" self._USE_WHISPER_FEATURE = False + self._SELECTED_TRANSLATION_COMPUTE_DEVICE = {"type": "cpu", "device_index": 0, "name":"cpu"} + self._SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = {"type": "cpu", "device_index": 0, "name":"cpu"} + self._CTRANSLATE2_WEIGHT_TYPE = "Small" self._WHISPER_WEIGHT_TYPE = "base" self._SEND_MESSAGE_FORMAT = "[message]" self._SEND_MESSAGE_FORMAT_WITH_T = "[message]([translation])" diff --git a/src-python/model.py b/src-python/model.py index fa22b622..0d659d18 100644 --- a/src-python/model.py +++ b/src-python/model.py @@ -109,7 +109,11 @@ class Model: return checkCTranslate2Weight(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE) def changeTranslatorCTranslate2Model(self): - self.translator.changeCTranslate2Model(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE) + self.translator.changeCTranslate2Model( + config.PATH_LOCAL, + config.CTRANSLATE2_WEIGHT_TYPE, + config.SELECTED_TRANSLATION_COMPUTE_DEVICE["type"], + config.SELECTED_TRANSLATION_COMPUTE_DEVICE["device_index"]) def downloadCTranslate2ModelWeight(self, callbackFunc=None): return downloadCTranslate2Weight(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE, callbackFunc) @@ -425,6 +429,8 @@ class Model: transcription_engine=config.SELECTED_TRANSCRIPTION_ENGINE, root=config.PATH_LOCAL, whisper_weight_type=config.WHISPER_WEIGHT_TYPE, + device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["type"], + device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"], ) def sendMicTranscript(): try: @@ -587,6 +593,8 @@ class Model: transcription_engine=config.SELECTED_TRANSCRIPTION_ENGINE, root=config.PATH_LOCAL, 
whisper_weight_type=config.WHISPER_WEIGHT_TYPE, + device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["type"], + device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"], ) def sendSpeakerTranscript(): try: diff --git a/src-python/models/transcription/transcription_transcriber.py b/src-python/models/transcription/transcription_transcriber.py index bfbb24ca..331a02d8 100644 --- a/src-python/models/transcription/transcription_transcriber.py +++ b/src-python/models/transcription/transcription_transcriber.py @@ -18,7 +18,7 @@ PHRASE_TIMEOUT = 3 MAX_PHRASES = 10 class AudioTranscriber: - def __init__(self, speaker, source, phrase_timeout, max_phrases, transcription_engine, root=None, whisper_weight_type=None): + def __init__(self, speaker, source, phrase_timeout, max_phrases, transcription_engine, root=None, whisper_weight_type=None, device="cpu", device_index=0): self.speaker = speaker self.phrase_timeout = phrase_timeout self.max_phrases = max_phrases @@ -38,7 +38,7 @@ class AudioTranscriber: } if transcription_engine == "Whisper" and checkWhisperWeight(root, whisper_weight_type) is True: - self.whisper_model = getWhisperModel(root, whisper_weight_type) + self.whisper_model = getWhisperModel(root, whisper_weight_type, device=device, device_index=device_index) self.transcription_engine = "Whisper" def transcribeAudioQueue(self, audio_queue, language, country, avg_logprob=-0.8, no_speech_prob=0.6): diff --git a/src-python/models/transcription/transcription_whisper.py b/src-python/models/transcription/transcription_whisper.py index 77976ef0..5e6f00bf 100644 --- a/src-python/models/transcription/transcription_whisper.py +++ b/src-python/models/transcription/transcription_whisper.py @@ -1,6 +1,7 @@ from os import path as os_path, makedirs as os_makedirs from requests import get as requests_get from typing import Callable +import torch import huggingface_hub from faster_whisper import WhisperModel import logging @@ -51,7 +52,7 @@ def 
checkWhisperWeight(root, weight_type): try: WhisperModel( path, - device="cuda", + device="cpu", device_index=0, compute_type="int8", cpu_threads=4, @@ -75,13 +76,14 @@ def downloadWhisperWeight(root, weight_type, callbackFunc): url = huggingface_hub.hf_hub_url(_MODELS[weight_type], filename) downloadFile(url, file_path, func=callbackFunc) -def getWhisperModel(root, weight_type): +def getWhisperModel(root, weight_type, device="cpu", device_index=0): path = os_path.join(root, "weights", "whisper", weight_type) + compute_type = "int8" if device == "cpu" else "float16" return WhisperModel( path, - device="cuda", - device_index=0, - compute_type="int8", + device=device, + device_index=device_index, + compute_type=compute_type, cpu_threads=4, num_workers=1, local_files_only=True, diff --git a/src-python/models/translation/translation_translator.py b/src-python/models/translation/translation_translator.py index 6c22a2c2..2e6049c3 100644 --- a/src-python/models/translation/translation_translator.py +++ b/src-python/models/translation/translation_translator.py @@ -29,17 +29,19 @@ class Translator(): result = False return result - def changeCTranslate2Model(self, path, model_type): + def changeCTranslate2Model(self, path, model_type, device="cpu", device_index=0): self.is_loaded_ctranslate2_model = False directory_name = ctranslate2_weights[model_type]["directory_name"] tokenizer = ctranslate2_weights[model_type]["tokenizer"] weight_path = os_path.join(path, "weights", "ctranslate2", directory_name) tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer") + + compute_type = "int8" if device == "cpu" else "float16" self.ctranslate2_translator = ctranslate2.Translator( weight_path, - device="cuda", - device_index=0, - compute_type="int8", + device=device, + device_index=device_index, + compute_type=compute_type, inter_threads=1, intra_threads=4 ) diff --git a/src-python/webui_controller.py b/src-python/webui_controller.py index 
9664cbd4..6621def1 100644 --- a/src-python/webui_controller.py +++ b/src-python/webui_controller.py @@ -7,6 +7,7 @@ from device_manager import device_manager from config import config from model import model from utils import isUniqueStrings, printLog +import torch class Controller: def __init__(self) -> None: @@ -363,10 +364,36 @@ class Controller: def getMessageBoxRatioRange(*args, **kwargs) -> dict: return {"status":200, "result":config.MESSAGE_BOX_RATIO_RANGE} + @staticmethod + def getComputeDeviceList(*args, **kwargs) -> dict: + device_list = [{"type":"cuda", "device_index": i, "name": torch.cuda.get_device_name(i)} for i in range(torch.cuda.device_count())] + device_list.append({"type":"cpu", "device_index": 0, "name": "cpu"}) + return {"status":200, "result":device_list} + + @staticmethod + def getSelectedTranslationComputeDevice(*args, **kwargs) -> dict: + return {"status":200, "result":config.SELECTED_TRANSLATION_COMPUTE_DEVICE} + + @staticmethod + def setSelectedTranslationComputeDevice(device:dict, *args, **kwargs) -> dict: + printLog("setSelectedTranslationComputeDevice", device) + config.SELECTED_TRANSLATION_COMPUTE_DEVICE = device + return {"status":200,"result":config.SELECTED_TRANSLATION_COMPUTE_DEVICE} + @staticmethod def getSelectableCtranslate2WeightTypeDict(*args, **kwargs) -> dict: return {"status":200, "result":config.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT} + @staticmethod + def getSelectedTranscriptionComputeDevice(*args, **kwargs) -> dict: + return {"status":200, "result":config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE} + + @staticmethod + def setSelectedTranscriptionComputeDevice(device:dict, *args, **kwargs) -> dict: + printLog("setSelectedTranscriptionComputeDevice", device) + config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = device + return {"status":200,"result":config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE} + @staticmethod def getSelectableWhisperModelTypeDict(*args, **kwargs) -> dict: return {"status":200, 
"result":config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT} diff --git a/src-python/webui_mainloop.py b/src-python/webui_mainloop.py index ab3f5b88..f49872cf 100644 --- a/src-python/webui_mainloop.py +++ b/src-python/webui_mainloop.py @@ -128,6 +128,10 @@ mapping = { "/set/enable/use_translation_feature": {"status": True, "variable":controller.setEnableUseTranslationFeature}, "/set/disable/use_translation_feature": {"status": True, "variable":controller.setDisableUseTranslationFeature}, + "/get/data/translation_compute_device_dict": {"status": True, "variable":controller.getComputeDeviceList}, + "/get/data/selected_translation_compute_device": {"status": True, "variable":controller.getSelectedTranslationComputeDevice}, + "/set/data/selected_translation_compute_device": {"status": True, "variable":controller.setSelectedTranslationComputeDevice}, + "/get/data/selectable_ctranslate2_weight_type_dict": {"status": True, "variable":controller.getSelectableCtranslate2WeightTypeDict}, "/get/data/ctranslate2_weight_type": {"status": True, "variable":controller.getCtranslate2WeightType}, @@ -229,6 +233,10 @@ mapping = { "/set/enable/check_speaker_threshold": {"status": True, "variable":controller.setEnableCheckSpeakerThreshold}, "/set/disable/check_speaker_threshold": {"status": True, "variable":controller.setDisableCheckSpeakerThreshold}, + "/get/data/transcription_compute_device_dict": {"status": True, "variable":controller.getComputeDeviceList}, + "/get/data/selected_transcription_compute_device": {"status": True, "variable":controller.getSelectedTranscriptionComputeDevice}, + "/set/data/selected_transcription_compute_device": {"status": True, "variable":controller.setSelectedTranscriptionComputeDevice}, + "/get/data/selectable_whisper_weight_type_dict": {"status": True, "variable":controller.getSelectableWhisperModelTypeDict}, "/get/data/whisper_weight_type": {"status": True, "variable":controller.getWhisperWeightType},