👍️[Update] Model: cpu/cudaをtranslationもしくはtranscriptionで選択できるように実装

This commit is contained in:
misyaguziya
2024-10-23 13:41:34 +09:00
parent 2136865493
commit af3fe1f0f9
7 changed files with 84 additions and 13 deletions

View File

@@ -724,6 +724,28 @@ class Config:
self._USE_WHISPER_FEATURE = value self._USE_WHISPER_FEATURE = value
saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, value) saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, value)
@property
@json_serializable('SELECTED_TRANSLATION_COMPUTE_DEVICE')
def SELECTED_TRANSLATION_COMPUTE_DEVICE(self):
return self._SELECTED_TRANSLATION_COMPUTE_DEVICE
@SELECTED_TRANSLATION_COMPUTE_DEVICE.setter
def SELECTED_TRANSLATION_COMPUTE_DEVICE(self, value):
if isinstance(value, dict):
self._SELECTED_TRANSLATION_COMPUTE_DEVICE = value
saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, value)
@property
@json_serializable('SELECTED_TRANSCRIPTION_COMPUTE_DEVICE')
def SELECTED_TRANSCRIPTION_COMPUTE_DEVICE(self):
return self._SELECTED_TRANSCRIPTION_COMPUTE_DEVICE
@SELECTED_TRANSCRIPTION_COMPUTE_DEVICE.setter
def SELECTED_TRANSCRIPTION_COMPUTE_DEVICE(self, value):
if isinstance(value, dict):
self._SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = value
saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, value)
@property @property
@json_serializable('CTRANSLATE2_WEIGHT_TYPE') @json_serializable('CTRANSLATE2_WEIGHT_TYPE')
def CTRANSLATE2_WEIGHT_TYPE(self): def CTRANSLATE2_WEIGHT_TYPE(self):
@@ -1105,8 +1127,10 @@ class Config:
} }
self._USE_EXCLUDE_WORDS = True self._USE_EXCLUDE_WORDS = True
self._USE_TRANSLATION_FEATURE = True self._USE_TRANSLATION_FEATURE = True
self._CTRANSLATE2_WEIGHT_TYPE = "Small"
self._USE_WHISPER_FEATURE = False self._USE_WHISPER_FEATURE = False
self._SELECTED_TRANSLATION_COMPUTE_DEVICE = {"type": "cpu", "index": 0, "name":"cpu"}
self._SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = {"type": "cpu", "index": 0, "name":"cpu"}
self._CTRANSLATE2_WEIGHT_TYPE = "Small"
self._WHISPER_WEIGHT_TYPE = "base" self._WHISPER_WEIGHT_TYPE = "base"
self._SEND_MESSAGE_FORMAT = "[message]" self._SEND_MESSAGE_FORMAT = "[message]"
self._SEND_MESSAGE_FORMAT_WITH_T = "[message]([translation])" self._SEND_MESSAGE_FORMAT_WITH_T = "[message]([translation])"

View File

@@ -109,7 +109,11 @@ class Model:
return checkCTranslate2Weight(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE) return checkCTranslate2Weight(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE)
def changeTranslatorCTranslate2Model(self): def changeTranslatorCTranslate2Model(self):
self.translator.changeCTranslate2Model(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE) self.translator.changeCTranslate2Model(
config.PATH_LOCAL,
config.CTRANSLATE2_WEIGHT_TYPE,
config.SELECTED_TRANSLATION_COMPUTE_DEVICE["type"],
config.SELECTED_TRANSLATION_COMPUTE_DEVICE["device_index"])
def downloadCTranslate2ModelWeight(self, callbackFunc=None): def downloadCTranslate2ModelWeight(self, callbackFunc=None):
return downloadCTranslate2Weight(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE, callbackFunc) return downloadCTranslate2Weight(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE, callbackFunc)
@@ -425,6 +429,8 @@ class Model:
transcription_engine=config.SELECTED_TRANSCRIPTION_ENGINE, transcription_engine=config.SELECTED_TRANSCRIPTION_ENGINE,
root=config.PATH_LOCAL, root=config.PATH_LOCAL,
whisper_weight_type=config.WHISPER_WEIGHT_TYPE, whisper_weight_type=config.WHISPER_WEIGHT_TYPE,
device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["type"],
device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"],
) )
def sendMicTranscript(): def sendMicTranscript():
try: try:
@@ -587,6 +593,8 @@ class Model:
transcription_engine=config.SELECTED_TRANSCRIPTION_ENGINE, transcription_engine=config.SELECTED_TRANSCRIPTION_ENGINE,
root=config.PATH_LOCAL, root=config.PATH_LOCAL,
whisper_weight_type=config.WHISPER_WEIGHT_TYPE, whisper_weight_type=config.WHISPER_WEIGHT_TYPE,
device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["type"],
device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"],
) )
def sendSpeakerTranscript(): def sendSpeakerTranscript():
try: try:

View File

@@ -18,7 +18,7 @@ PHRASE_TIMEOUT = 3
MAX_PHRASES = 10 MAX_PHRASES = 10
class AudioTranscriber: class AudioTranscriber:
def __init__(self, speaker, source, phrase_timeout, max_phrases, transcription_engine, root=None, whisper_weight_type=None): def __init__(self, speaker, source, phrase_timeout, max_phrases, transcription_engine, root=None, whisper_weight_type=None, device="cpu", device_index=0):
self.speaker = speaker self.speaker = speaker
self.phrase_timeout = phrase_timeout self.phrase_timeout = phrase_timeout
self.max_phrases = max_phrases self.max_phrases = max_phrases
@@ -38,7 +38,7 @@ class AudioTranscriber:
} }
if transcription_engine == "Whisper" and checkWhisperWeight(root, whisper_weight_type) is True: if transcription_engine == "Whisper" and checkWhisperWeight(root, whisper_weight_type) is True:
self.whisper_model = getWhisperModel(root, whisper_weight_type) self.whisper_model = getWhisperModel(root, whisper_weight_type, device=device, device_index=device_index)
self.transcription_engine = "Whisper" self.transcription_engine = "Whisper"
def transcribeAudioQueue(self, audio_queue, language, country, avg_logprob=-0.8, no_speech_prob=0.6): def transcribeAudioQueue(self, audio_queue, language, country, avg_logprob=-0.8, no_speech_prob=0.6):

View File

@@ -1,6 +1,7 @@
from os import path as os_path, makedirs as os_makedirs from os import path as os_path, makedirs as os_makedirs
from requests import get as requests_get from requests import get as requests_get
from typing import Callable from typing import Callable
import torch
import huggingface_hub import huggingface_hub
from faster_whisper import WhisperModel from faster_whisper import WhisperModel
import logging import logging
@@ -51,7 +52,7 @@ def checkWhisperWeight(root, weight_type):
try: try:
WhisperModel( WhisperModel(
path, path,
device="cuda", device="cpu",
device_index=0, device_index=0,
compute_type="int8", compute_type="int8",
cpu_threads=4, cpu_threads=4,
@@ -75,13 +76,14 @@ def downloadWhisperWeight(root, weight_type, callbackFunc):
url = huggingface_hub.hf_hub_url(_MODELS[weight_type], filename) url = huggingface_hub.hf_hub_url(_MODELS[weight_type], filename)
downloadFile(url, file_path, func=callbackFunc) downloadFile(url, file_path, func=callbackFunc)
def getWhisperModel(root, weight_type): def getWhisperModel(root, weight_type, device="cpu", device_index=0):
path = os_path.join(root, "weights", "whisper", weight_type) path = os_path.join(root, "weights", "whisper", weight_type)
compute_type = "int8" if device == "cpu" else "float16"
return WhisperModel( return WhisperModel(
path, path,
device="cuda", device=device,
device_index=0, device_index=device_index,
compute_type="int8", compute_type=compute_type,
cpu_threads=4, cpu_threads=4,
num_workers=1, num_workers=1,
local_files_only=True, local_files_only=True,

View File

@@ -29,17 +29,19 @@ class Translator():
result = False result = False
return result return result
def changeCTranslate2Model(self, path, model_type): def changeCTranslate2Model(self, path, model_type, device="cpu", device_index=0):
self.is_loaded_ctranslate2_model = False self.is_loaded_ctranslate2_model = False
directory_name = ctranslate2_weights[model_type]["directory_name"] directory_name = ctranslate2_weights[model_type]["directory_name"]
tokenizer = ctranslate2_weights[model_type]["tokenizer"] tokenizer = ctranslate2_weights[model_type]["tokenizer"]
weight_path = os_path.join(path, "weights", "ctranslate2", directory_name) weight_path = os_path.join(path, "weights", "ctranslate2", directory_name)
tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer") tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer")
compute_type = "int8" if device == "cpu" else "float16"
self.ctranslate2_translator = ctranslate2.Translator( self.ctranslate2_translator = ctranslate2.Translator(
weight_path, weight_path,
device="cuda", device=device,
device_index=0, device_index=device_index,
compute_type="int8", compute_type=compute_type,
inter_threads=1, inter_threads=1,
intra_threads=4 intra_threads=4
) )

View File

@@ -7,6 +7,7 @@ from device_manager import device_manager
from config import config from config import config
from model import model from model import model
from utils import isUniqueStrings, printLog from utils import isUniqueStrings, printLog
import torch
class Controller: class Controller:
def __init__(self) -> None: def __init__(self) -> None:
@@ -363,10 +364,36 @@ class Controller:
def getMessageBoxRatioRange(*args, **kwargs) -> dict: def getMessageBoxRatioRange(*args, **kwargs) -> dict:
return {"status":200, "result":config.MESSAGE_BOX_RATIO_RANGE} return {"status":200, "result":config.MESSAGE_BOX_RATIO_RANGE}
@staticmethod
def getComputeDeviceList(*args, **kwargs) -> dict:
device_list = [{"type":"cuda", "device_index": i, "name": torch.cuda.get_device_name(i)} for i in range(torch.cuda.device_count())]
device_list.append({"type":"cpu", "device_index": 0, "name": "cpu"})
return {"status":200, "result":device_list}
@staticmethod
def getSelectedTranslationComputeDevice(*args, **kwargs) -> dict:
return {"status":200, "result":config.SELECTED_TRANSLATION_COMPUTE_DEVICE}
@staticmethod
def setSelectedTranslationComputeDevice(device:str, *args, **kwargs) -> dict:
printLog("setSelectedTranslationComputeDevice", device)
config.SELECTED_TRANSLATION_COMPUTE_DEVICE = device
return {"status":200,"result":config.SELECTED_TRANSLATION_COMPUTE_DEVICE}
@staticmethod @staticmethod
def getSelectableCtranslate2WeightTypeDict(*args, **kwargs) -> dict: def getSelectableCtranslate2WeightTypeDict(*args, **kwargs) -> dict:
return {"status":200, "result":config.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT} return {"status":200, "result":config.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT}
@staticmethod
def getSelectedTranscriptionComputeDevice(*args, **kwargs) -> dict:
return {"status":200, "result":config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE}
@staticmethod
def setSelectedTranscriptionComputeDevice(device:str, *args, **kwargs) -> dict:
printLog("setSelectedTranscriptionComputeDevice", device)
config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = device
return {"status":200,"result":config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE}
@staticmethod @staticmethod
def getSelectableWhisperModelTypeDict(*args, **kwargs) -> dict: def getSelectableWhisperModelTypeDict(*args, **kwargs) -> dict:
return {"status":200, "result":config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT} return {"status":200, "result":config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT}

View File

@@ -128,6 +128,10 @@ mapping = {
"/set/enable/use_translation_feature": {"status": True, "variable":controller.setEnableUseTranslationFeature}, "/set/enable/use_translation_feature": {"status": True, "variable":controller.setEnableUseTranslationFeature},
"/set/disable/use_translation_feature": {"status": True, "variable":controller.setDisableUseTranslationFeature}, "/set/disable/use_translation_feature": {"status": True, "variable":controller.setDisableUseTranslationFeature},
"/get/data/translation_compute_device_dict": {"status": True, "variable":controller.getComputeDeviceList},
"/get/data/selected_translation_compute_device": {"status": True, "variable":controller.getSelectedTranslationComputeDevice},
"/set/data/selected_translation_compute_device": {"status": True, "variable":controller.setSelectedTranslationComputeDevice},
"/get/data/selectable_ctranslate2_weight_type_dict": {"status": True, "variable":controller.getSelectableCtranslate2WeightTypeDict}, "/get/data/selectable_ctranslate2_weight_type_dict": {"status": True, "variable":controller.getSelectableCtranslate2WeightTypeDict},
"/get/data/ctranslate2_weight_type": {"status": True, "variable":controller.getCtranslate2WeightType}, "/get/data/ctranslate2_weight_type": {"status": True, "variable":controller.getCtranslate2WeightType},
@@ -229,6 +233,10 @@ mapping = {
"/set/enable/check_speaker_threshold": {"status": True, "variable":controller.setEnableCheckSpeakerThreshold}, "/set/enable/check_speaker_threshold": {"status": True, "variable":controller.setEnableCheckSpeakerThreshold},
"/set/disable/check_speaker_threshold": {"status": True, "variable":controller.setDisableCheckSpeakerThreshold}, "/set/disable/check_speaker_threshold": {"status": True, "variable":controller.setDisableCheckSpeakerThreshold},
"/get/data/transcription_compute_device_dict": {"status": True, "variable":controller.getComputeDeviceList},
"/get/data/selected_transcription_compute_device": {"status": True, "variable":controller.getSelectedTranscriptionComputeDevice},
"/set/data/selected_transcription_compute_device": {"status": True, "variable":controller.setSelectedTranscriptionComputeDevice},
"/get/data/selectable_whisper_weight_type_dict": {"status": True, "variable":controller.getSelectableWhisperModelTypeDict}, "/get/data/selectable_whisper_weight_type_dict": {"status": True, "variable":controller.getSelectableWhisperModelTypeDict},
"/get/data/whisper_weight_type": {"status": True, "variable":controller.getWhisperWeightType}, "/get/data/whisper_weight_type": {"status": True, "variable":controller.getWhisperWeightType},