👍️[Update] Model: cpu/cudaをtranslationもしくはtranscriptionで選択できるように実装
This commit is contained in:
@@ -724,6 +724,28 @@ class Config:
|
|||||||
self._USE_WHISPER_FEATURE = value
|
self._USE_WHISPER_FEATURE = value
|
||||||
saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, value)
|
saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, value)
|
||||||
|
|
||||||
|
@property
@json_serializable('SELECTED_TRANSLATION_COMPUTE_DEVICE')
def SELECTED_TRANSLATION_COMPUTE_DEVICE(self):
    """Return the compute-device entry currently selected for translation."""
    return self._SELECTED_TRANSLATION_COMPUTE_DEVICE

@SELECTED_TRANSLATION_COMPUTE_DEVICE.setter
def SELECTED_TRANSLATION_COMPUTE_DEVICE(self, value):
    """Store a new translation compute device and pass it to ``saveJson``.

    Values that are not ``dict`` instances are ignored.
    """
    # NOTE(review): the diff view lost indentation; the save call is assumed
    # to belong to the dict-only branch -- confirm against the real file.
    if not isinstance(value, dict):
        return
    self._SELECTED_TRANSLATION_COMPUTE_DEVICE = value
    # co_name evaluates to this setter's own name, which is used as the key.
    saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, value)
|
||||||
|
|
||||||
|
@property
@json_serializable('SELECTED_TRANSCRIPTION_COMPUTE_DEVICE')
def SELECTED_TRANSCRIPTION_COMPUTE_DEVICE(self):
    """Return the compute-device entry currently selected for transcription."""
    return self._SELECTED_TRANSCRIPTION_COMPUTE_DEVICE

@SELECTED_TRANSCRIPTION_COMPUTE_DEVICE.setter
def SELECTED_TRANSCRIPTION_COMPUTE_DEVICE(self, value):
    """Store a new transcription compute device and pass it to ``saveJson``.

    Values that are not ``dict`` instances are ignored.
    """
    # NOTE(review): the diff view lost indentation; the save call is assumed
    # to belong to the dict-only branch -- confirm against the real file.
    if not isinstance(value, dict):
        return
    self._SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = value
    # co_name evaluates to this setter's own name, which is used as the key.
    saveJson(self.PATH_CONFIG, inspect.currentframe().f_code.co_name, value)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@json_serializable('CTRANSLATE2_WEIGHT_TYPE')
|
@json_serializable('CTRANSLATE2_WEIGHT_TYPE')
|
||||||
def CTRANSLATE2_WEIGHT_TYPE(self):
|
def CTRANSLATE2_WEIGHT_TYPE(self):
|
||||||
@@ -1105,8 +1127,10 @@ class Config:
|
|||||||
}
|
}
|
||||||
self._USE_EXCLUDE_WORDS = True
|
self._USE_EXCLUDE_WORDS = True
|
||||||
self._USE_TRANSLATION_FEATURE = True
|
self._USE_TRANSLATION_FEATURE = True
|
||||||
self._CTRANSLATE2_WEIGHT_TYPE = "Small"
|
|
||||||
self._USE_WHISPER_FEATURE = False
|
self._USE_WHISPER_FEATURE = False
|
||||||
|
self._SELECTED_TRANSLATION_COMPUTE_DEVICE = {"type": "cpu", "index": 0, "name":"cpu"}
|
||||||
|
self._SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = {"type": "cpu", "index": 0, "name":"cpu"}
|
||||||
|
self._CTRANSLATE2_WEIGHT_TYPE = "Small"
|
||||||
self._WHISPER_WEIGHT_TYPE = "base"
|
self._WHISPER_WEIGHT_TYPE = "base"
|
||||||
self._SEND_MESSAGE_FORMAT = "[message]"
|
self._SEND_MESSAGE_FORMAT = "[message]"
|
||||||
self._SEND_MESSAGE_FORMAT_WITH_T = "[message]([translation])"
|
self._SEND_MESSAGE_FORMAT_WITH_T = "[message]([translation])"
|
||||||
|
|||||||
@@ -109,7 +109,11 @@ class Model:
|
|||||||
return checkCTranslate2Weight(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE)
|
return checkCTranslate2Weight(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE)
|
||||||
|
|
||||||
def changeTranslatorCTranslate2Model(self):
    """Switch the CTranslate2 model using the configured weight and device.

    Forwards the local path, the CTranslate2 weight type, and the type and
    index of the selected translation compute device to
    ``self.translator.changeCTranslate2Model``.
    """
    compute_device = config.SELECTED_TRANSLATION_COMPUTE_DEVICE
    # The default entry created in Config stores the index under "index",
    # whereas entries built by the controller use "device_index". Accept
    # either so the default configuration does not raise KeyError.
    device_index = compute_device.get(
        "device_index", compute_device.get("index", 0)
    )
    self.translator.changeCTranslate2Model(
        config.PATH_LOCAL,
        config.CTRANSLATE2_WEIGHT_TYPE,
        compute_device["type"],
        device_index,
    )
|
||||||
|
|
||||||
def downloadCTranslate2ModelWeight(self, callbackFunc=None):
|
def downloadCTranslate2ModelWeight(self, callbackFunc=None):
|
||||||
return downloadCTranslate2Weight(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE, callbackFunc)
|
return downloadCTranslate2Weight(config.PATH_LOCAL, config.CTRANSLATE2_WEIGHT_TYPE, callbackFunc)
|
||||||
@@ -425,6 +429,8 @@ class Model:
|
|||||||
transcription_engine=config.SELECTED_TRANSCRIPTION_ENGINE,
|
transcription_engine=config.SELECTED_TRANSCRIPTION_ENGINE,
|
||||||
root=config.PATH_LOCAL,
|
root=config.PATH_LOCAL,
|
||||||
whisper_weight_type=config.WHISPER_WEIGHT_TYPE,
|
whisper_weight_type=config.WHISPER_WEIGHT_TYPE,
|
||||||
|
device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["type"],
|
||||||
|
device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"],
|
||||||
)
|
)
|
||||||
def sendMicTranscript():
|
def sendMicTranscript():
|
||||||
try:
|
try:
|
||||||
@@ -587,6 +593,8 @@ class Model:
|
|||||||
transcription_engine=config.SELECTED_TRANSCRIPTION_ENGINE,
|
transcription_engine=config.SELECTED_TRANSCRIPTION_ENGINE,
|
||||||
root=config.PATH_LOCAL,
|
root=config.PATH_LOCAL,
|
||||||
whisper_weight_type=config.WHISPER_WEIGHT_TYPE,
|
whisper_weight_type=config.WHISPER_WEIGHT_TYPE,
|
||||||
|
device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["type"],
|
||||||
|
device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"],
|
||||||
)
|
)
|
||||||
def sendSpeakerTranscript():
|
def sendSpeakerTranscript():
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ PHRASE_TIMEOUT = 3
|
|||||||
MAX_PHRASES = 10
|
MAX_PHRASES = 10
|
||||||
|
|
||||||
class AudioTranscriber:
|
class AudioTranscriber:
|
||||||
def __init__(self, speaker, source, phrase_timeout, max_phrases, transcription_engine, root=None, whisper_weight_type=None):
|
def __init__(self, speaker, source, phrase_timeout, max_phrases, transcription_engine, root=None, whisper_weight_type=None, device="cpu", device_index=0):
|
||||||
self.speaker = speaker
|
self.speaker = speaker
|
||||||
self.phrase_timeout = phrase_timeout
|
self.phrase_timeout = phrase_timeout
|
||||||
self.max_phrases = max_phrases
|
self.max_phrases = max_phrases
|
||||||
@@ -38,7 +38,7 @@ class AudioTranscriber:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if transcription_engine == "Whisper" and checkWhisperWeight(root, whisper_weight_type) is True:
|
if transcription_engine == "Whisper" and checkWhisperWeight(root, whisper_weight_type) is True:
|
||||||
self.whisper_model = getWhisperModel(root, whisper_weight_type)
|
self.whisper_model = getWhisperModel(root, whisper_weight_type, device=device, device_index=device_index)
|
||||||
self.transcription_engine = "Whisper"
|
self.transcription_engine = "Whisper"
|
||||||
|
|
||||||
def transcribeAudioQueue(self, audio_queue, language, country, avg_logprob=-0.8, no_speech_prob=0.6):
|
def transcribeAudioQueue(self, audio_queue, language, country, avg_logprob=-0.8, no_speech_prob=0.6):
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from os import path as os_path, makedirs as os_makedirs
|
from os import path as os_path, makedirs as os_makedirs
|
||||||
from requests import get as requests_get
|
from requests import get as requests_get
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
import torch
|
||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
from faster_whisper import WhisperModel
|
from faster_whisper import WhisperModel
|
||||||
import logging
|
import logging
|
||||||
@@ -51,7 +52,7 @@ def checkWhisperWeight(root, weight_type):
|
|||||||
try:
|
try:
|
||||||
WhisperModel(
|
WhisperModel(
|
||||||
path,
|
path,
|
||||||
device="cuda",
|
device="cpu",
|
||||||
device_index=0,
|
device_index=0,
|
||||||
compute_type="int8",
|
compute_type="int8",
|
||||||
cpu_threads=4,
|
cpu_threads=4,
|
||||||
@@ -75,13 +76,14 @@ def downloadWhisperWeight(root, weight_type, callbackFunc):
|
|||||||
url = huggingface_hub.hf_hub_url(_MODELS[weight_type], filename)
|
url = huggingface_hub.hf_hub_url(_MODELS[weight_type], filename)
|
||||||
downloadFile(url, file_path, func=callbackFunc)
|
downloadFile(url, file_path, func=callbackFunc)
|
||||||
|
|
||||||
def getWhisperModel(root, weight_type):
|
def getWhisperModel(root, weight_type, device="cpu", device_index=0):
|
||||||
path = os_path.join(root, "weights", "whisper", weight_type)
|
path = os_path.join(root, "weights", "whisper", weight_type)
|
||||||
|
compute_type = "int8" if device == "cpu" else "float16"
|
||||||
return WhisperModel(
|
return WhisperModel(
|
||||||
path,
|
path,
|
||||||
device="cuda",
|
device=device,
|
||||||
device_index=0,
|
device_index=device_index,
|
||||||
compute_type="int8",
|
compute_type=compute_type,
|
||||||
cpu_threads=4,
|
cpu_threads=4,
|
||||||
num_workers=1,
|
num_workers=1,
|
||||||
local_files_only=True,
|
local_files_only=True,
|
||||||
|
|||||||
@@ -29,17 +29,19 @@ class Translator():
|
|||||||
result = False
|
result = False
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def changeCTranslate2Model(self, path, model_type):
|
def changeCTranslate2Model(self, path, model_type, device="cpu", device_index=0):
|
||||||
self.is_loaded_ctranslate2_model = False
|
self.is_loaded_ctranslate2_model = False
|
||||||
directory_name = ctranslate2_weights[model_type]["directory_name"]
|
directory_name = ctranslate2_weights[model_type]["directory_name"]
|
||||||
tokenizer = ctranslate2_weights[model_type]["tokenizer"]
|
tokenizer = ctranslate2_weights[model_type]["tokenizer"]
|
||||||
weight_path = os_path.join(path, "weights", "ctranslate2", directory_name)
|
weight_path = os_path.join(path, "weights", "ctranslate2", directory_name)
|
||||||
tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer")
|
tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer")
|
||||||
|
|
||||||
|
compute_type = "int8" if device == "cpu" else "float16"
|
||||||
self.ctranslate2_translator = ctranslate2.Translator(
|
self.ctranslate2_translator = ctranslate2.Translator(
|
||||||
weight_path,
|
weight_path,
|
||||||
device="cuda",
|
device=device,
|
||||||
device_index=0,
|
device_index=device_index,
|
||||||
compute_type="int8",
|
compute_type=compute_type,
|
||||||
inter_threads=1,
|
inter_threads=1,
|
||||||
intra_threads=4
|
intra_threads=4
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from device_manager import device_manager
|
|||||||
from config import config
|
from config import config
|
||||||
from model import model
|
from model import model
|
||||||
from utils import isUniqueStrings, printLog
|
from utils import isUniqueStrings, printLog
|
||||||
|
import torch
|
||||||
|
|
||||||
class Controller:
|
class Controller:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
@@ -363,10 +364,36 @@ class Controller:
|
|||||||
def getMessageBoxRatioRange(*args, **kwargs) -> dict:
|
def getMessageBoxRatioRange(*args, **kwargs) -> dict:
|
||||||
return {"status":200, "result":config.MESSAGE_BOX_RATIO_RANGE}
|
return {"status":200, "result":config.MESSAGE_BOX_RATIO_RANGE}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def getComputeDeviceList(*args, **kwargs) -> dict:
|
||||||
|
device_list = [{"type":"cuda", "device_index": i, "name": torch.cuda.get_device_name(i)} for i in range(torch.cuda.device_count())]
|
||||||
|
device_list.append({"type":"cpu", "device_index": 0, "name": "cpu"})
|
||||||
|
return {"status":200, "result":device_list}
|
||||||
|
|
||||||
|
@staticmethod
def getSelectedTranslationComputeDevice(*args, **kwargs) -> dict:
    """Report the compute device currently selected for translation."""
    selected_device = config.SELECTED_TRANSLATION_COMPUTE_DEVICE
    return {"status": 200, "result": selected_device}
|
||||||
|
|
||||||
|
@staticmethod
def setSelectedTranslationComputeDevice(device: dict, *args, **kwargs) -> dict:
    """Select the compute device used for translation.

    Args:
        device: Device entry to select. The config setter only applies
            ``dict`` values; anything else leaves the selection unchanged.

    Returns:
        A status-200 response carrying the selection held by ``config``
        after the assignment.
    """
    printLog("setSelectedTranslationComputeDevice", device)
    config.SELECTED_TRANSLATION_COMPUTE_DEVICE = device
    return {"status":200,"result":config.SELECTED_TRANSLATION_COMPUTE_DEVICE}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def getSelectableCtranslate2WeightTypeDict(*args, **kwargs) -> dict:
|
def getSelectableCtranslate2WeightTypeDict(*args, **kwargs) -> dict:
|
||||||
return {"status":200, "result":config.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT}
|
return {"status":200, "result":config.SELECTABLE_CTRANSLATE2_WEIGHT_TYPE_DICT}
|
||||||
|
|
||||||
|
@staticmethod
def getSelectedTranscriptionComputeDevice(*args, **kwargs) -> dict:
    """Report the compute device currently selected for transcription."""
    selected_device = config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE
    return {"status": 200, "result": selected_device}
|
||||||
|
|
||||||
|
@staticmethod
def setSelectedTranscriptionComputeDevice(device: dict, *args, **kwargs) -> dict:
    """Select the compute device used for transcription.

    Args:
        device: Device entry to select. The config setter only applies
            ``dict`` values; anything else leaves the selection unchanged.

    Returns:
        A status-200 response carrying the selection held by ``config``
        after the assignment.
    """
    printLog("setSelectedTranscriptionComputeDevice", device)
    config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = device
    return {"status":200,"result":config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def getSelectableWhisperModelTypeDict(*args, **kwargs) -> dict:
|
def getSelectableWhisperModelTypeDict(*args, **kwargs) -> dict:
|
||||||
return {"status":200, "result":config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT}
|
return {"status":200, "result":config.SELECTABLE_WHISPER_WEIGHT_TYPE_DICT}
|
||||||
|
|||||||
@@ -128,6 +128,10 @@ mapping = {
|
|||||||
"/set/enable/use_translation_feature": {"status": True, "variable":controller.setEnableUseTranslationFeature},
|
"/set/enable/use_translation_feature": {"status": True, "variable":controller.setEnableUseTranslationFeature},
|
||||||
"/set/disable/use_translation_feature": {"status": True, "variable":controller.setDisableUseTranslationFeature},
|
"/set/disable/use_translation_feature": {"status": True, "variable":controller.setDisableUseTranslationFeature},
|
||||||
|
|
||||||
|
"/get/data/translation_compute_device_dict": {"status": True, "variable":controller.getComputeDeviceList},
|
||||||
|
"/get/data/selected_translation_compute_device": {"status": True, "variable":controller.getSelectedTranslationComputeDevice},
|
||||||
|
"/set/data/selected_translation_compute_device": {"status": True, "variable":controller.setSelectedTranslationComputeDevice},
|
||||||
|
|
||||||
"/get/data/selectable_ctranslate2_weight_type_dict": {"status": True, "variable":controller.getSelectableCtranslate2WeightTypeDict},
|
"/get/data/selectable_ctranslate2_weight_type_dict": {"status": True, "variable":controller.getSelectableCtranslate2WeightTypeDict},
|
||||||
|
|
||||||
"/get/data/ctranslate2_weight_type": {"status": True, "variable":controller.getCtranslate2WeightType},
|
"/get/data/ctranslate2_weight_type": {"status": True, "variable":controller.getCtranslate2WeightType},
|
||||||
@@ -229,6 +233,10 @@ mapping = {
|
|||||||
"/set/enable/check_speaker_threshold": {"status": True, "variable":controller.setEnableCheckSpeakerThreshold},
|
"/set/enable/check_speaker_threshold": {"status": True, "variable":controller.setEnableCheckSpeakerThreshold},
|
||||||
"/set/disable/check_speaker_threshold": {"status": True, "variable":controller.setDisableCheckSpeakerThreshold},
|
"/set/disable/check_speaker_threshold": {"status": True, "variable":controller.setDisableCheckSpeakerThreshold},
|
||||||
|
|
||||||
|
"/get/data/transcription_compute_device_dict": {"status": True, "variable":controller.getComputeDeviceList},
|
||||||
|
"/get/data/selected_transcription_compute_device": {"status": True, "variable":controller.getSelectedTranscriptionComputeDevice},
|
||||||
|
"/set/data/selected_transcription_compute_device": {"status": True, "variable":controller.setSelectedTranscriptionComputeDevice},
|
||||||
|
|
||||||
"/get/data/selectable_whisper_weight_type_dict": {"status": True, "variable":controller.getSelectableWhisperModelTypeDict},
|
"/get/data/selectable_whisper_weight_type_dict": {"status": True, "variable":controller.getSelectableWhisperModelTypeDict},
|
||||||
|
|
||||||
"/get/data/whisper_weight_type": {"status": True, "variable":controller.getWhisperWeightType},
|
"/get/data/whisper_weight_type": {"status": True, "variable":controller.getWhisperWeightType},
|
||||||
|
|||||||
Reference in New Issue
Block a user