[Update] Add compute type management for CTranslate2 and Whisper models
This commit is contained in:
@@ -11,7 +11,7 @@ from models.translation.translation_languages import translation_lang
|
|||||||
from models.translation.translation_utils import ctranslate2_weights
|
from models.translation.translation_utils import ctranslate2_weights
|
||||||
from models.transcription.transcription_languages import transcription_lang
|
from models.transcription.transcription_languages import transcription_lang
|
||||||
from models.transcription.transcription_whisper import _MODELS as whisper_models
|
from models.transcription.transcription_whisper import _MODELS as whisper_models
|
||||||
from utils import errorLogging, validateDictStructure
|
from utils import errorLogging, validateDictStructure, getComputeTypeList
|
||||||
|
|
||||||
json_serializable_vars = {}
|
json_serializable_vars = {}
|
||||||
def json_serializable(var_name):
|
def json_serializable(var_name):
|
||||||
@@ -135,6 +135,14 @@ class Config:
|
|||||||
def SELECTABLE_COMPUTE_DEVICE_LIST(self):
|
def SELECTABLE_COMPUTE_DEVICE_LIST(self):
|
||||||
return self._SELECTABLE_COMPUTE_DEVICE_LIST
|
return self._SELECTABLE_COMPUTE_DEVICE_LIST
|
||||||
|
|
||||||
|
@property
|
||||||
|
def SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST(self):
|
||||||
|
return self._SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST
|
||||||
|
|
||||||
|
@property
|
||||||
|
def SELECTABLE_WHISPER_COMPUTE_TYPE_LIST(self):
|
||||||
|
return self._SELECTABLE_WHISPER_COMPUTE_TYPE_LIST
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def SEND_MESSAGE_BUTTON_TYPE_LIST(self):
|
def SEND_MESSAGE_BUTTON_TYPE_LIST(self):
|
||||||
return self._SEND_MESSAGE_BUTTON_TYPE_LIST
|
return self._SEND_MESSAGE_BUTTON_TYPE_LIST
|
||||||
@@ -814,6 +822,18 @@ class Config:
|
|||||||
self._CTRANSLATE2_WEIGHT_TYPE = value
|
self._CTRANSLATE2_WEIGHT_TYPE = value
|
||||||
self.saveConfig(inspect.currentframe().f_code.co_name, value)
|
self.saveConfig(inspect.currentframe().f_code.co_name, value)
|
||||||
|
|
||||||
|
@property
|
||||||
|
@json_serializable('CTRANSLATE2_COMPUTE_TYPE')
|
||||||
|
def CTRANSLATE2_COMPUTE_TYPE(self):
|
||||||
|
return self._CTRANSLATE2_COMPUTE_TYPE
|
||||||
|
|
||||||
|
@CTRANSLATE2_COMPUTE_TYPE.setter
|
||||||
|
def CTRANSLATE2_COMPUTE_TYPE(self, value):
|
||||||
|
if isinstance(value, str):
|
||||||
|
if value in self.SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST:
|
||||||
|
self._CTRANSLATE2_COMPUTE_TYPE = value
|
||||||
|
self.saveConfig(inspect.currentframe().f_code.co_name, value)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@json_serializable('WHISPER_WEIGHT_TYPE')
|
@json_serializable('WHISPER_WEIGHT_TYPE')
|
||||||
def WHISPER_WEIGHT_TYPE(self):
|
def WHISPER_WEIGHT_TYPE(self):
|
||||||
@@ -826,6 +846,18 @@ class Config:
|
|||||||
self._WHISPER_WEIGHT_TYPE = value
|
self._WHISPER_WEIGHT_TYPE = value
|
||||||
self.saveConfig(inspect.currentframe().f_code.co_name, value)
|
self.saveConfig(inspect.currentframe().f_code.co_name, value)
|
||||||
|
|
||||||
|
@property
|
||||||
|
@json_serializable('WHISPER_COMPUTE_TYPE')
|
||||||
|
def WHISPER_COMPUTE_TYPE(self):
|
||||||
|
return self._WHISPER_COMPUTE_TYPE
|
||||||
|
|
||||||
|
@WHISPER_COMPUTE_TYPE.setter
|
||||||
|
def WHISPER_COMPUTE_TYPE(self, value):
|
||||||
|
if isinstance(value, str):
|
||||||
|
if value in self.SELECTABLE_WHISPER_COMPUTE_TYPE_LIST:
|
||||||
|
self._WHISPER_COMPUTE_TYPE = value
|
||||||
|
self.saveConfig(inspect.currentframe().f_code.co_name, value)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@json_serializable('AUTO_CLEAR_MESSAGE_BOX')
|
@json_serializable('AUTO_CLEAR_MESSAGE_BOX')
|
||||||
def AUTO_CLEAR_MESSAGE_BOX(self):
|
def AUTO_CLEAR_MESSAGE_BOX(self):
|
||||||
@@ -1051,6 +1083,8 @@ class Config:
|
|||||||
for i in range(torch.cuda.device_count()):
|
for i in range(torch.cuda.device_count()):
|
||||||
self._SELECTABLE_COMPUTE_DEVICE_LIST.append({"device":"cuda", "device_index": i, "device_name": torch.cuda.get_device_name(i)})
|
self._SELECTABLE_COMPUTE_DEVICE_LIST.append({"device":"cuda", "device_index": i, "device_name": torch.cuda.get_device_name(i)})
|
||||||
self._SELECTABLE_COMPUTE_DEVICE_LIST.append({"device":"cpu", "device_index": 0, "device_name": "cpu"})
|
self._SELECTABLE_COMPUTE_DEVICE_LIST.append({"device":"cpu", "device_index": 0, "device_name": "cpu"})
|
||||||
|
self._SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST = ["auto"] + getComputeTypeList()
|
||||||
|
self._SELECTABLE_WHISPER_COMPUTE_TYPE_LIST = ["auto"] + getComputeTypeList()
|
||||||
self._SEND_MESSAGE_BUTTON_TYPE_LIST = ["show", "hide", "show_and_disable_enter_key"]
|
self._SEND_MESSAGE_BUTTON_TYPE_LIST = ["show", "hide", "show_and_disable_enter_key"]
|
||||||
self._SEND_MESSAGE_FORMAT_PARTS = {
|
self._SEND_MESSAGE_FORMAT_PARTS = {
|
||||||
"message": {
|
"message": {
|
||||||
@@ -1189,7 +1223,9 @@ class Config:
|
|||||||
self._SELECTED_TRANSLATION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0])
|
self._SELECTED_TRANSLATION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0])
|
||||||
self._SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0])
|
self._SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0])
|
||||||
self._CTRANSLATE2_WEIGHT_TYPE = "small"
|
self._CTRANSLATE2_WEIGHT_TYPE = "small"
|
||||||
|
self._CTRANSLATE2_COMPUTE_TYPE = "auto"
|
||||||
self._WHISPER_WEIGHT_TYPE = "base"
|
self._WHISPER_WEIGHT_TYPE = "base"
|
||||||
|
self._WHISPER_COMPUTE_TYPE = "auto"
|
||||||
self._AUTO_CLEAR_MESSAGE_BOX = True
|
self._AUTO_CLEAR_MESSAGE_BOX = True
|
||||||
self._SEND_ONLY_TRANSLATED_MESSAGES = False
|
self._SEND_ONLY_TRANSLATED_MESSAGES = False
|
||||||
self._OVERLAY_SMALL_LOG = False
|
self._OVERLAY_SMALL_LOG = False
|
||||||
|
|||||||
@@ -652,6 +652,14 @@ class Controller:
|
|||||||
def getComputeDeviceList(*args, **kwargs) -> dict:
|
def getComputeDeviceList(*args, **kwargs) -> dict:
|
||||||
return {"status":200, "result":config.SELECTABLE_COMPUTE_DEVICE_LIST}
|
return {"status":200, "result":config.SELECTABLE_COMPUTE_DEVICE_LIST}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def getCTranslate2ComputeTypeList(*args, **kwargs) -> dict:
|
||||||
|
return {"status":200, "result":config.SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def getWhisperComputeTypeList(*args, **kwargs) -> dict:
|
||||||
|
return {"status":200, "result":config.SELECTABLE_WHISPER_COMPUTE_TYPE_LIST}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def getSelectedTranslationComputeDevice(*args, **kwargs) -> dict:
|
def getSelectedTranslationComputeDevice(*args, **kwargs) -> dict:
|
||||||
return {"status":200, "result":config.SELECTED_TRANSLATION_COMPUTE_DEVICE}
|
return {"status":200, "result":config.SELECTED_TRANSLATION_COMPUTE_DEVICE}
|
||||||
@@ -1447,6 +1455,22 @@ class Controller:
|
|||||||
th_callback.join()
|
th_callback.join()
|
||||||
return {"status":200, "result":config.CTRANSLATE2_WEIGHT_TYPE}
|
return {"status":200, "result":config.CTRANSLATE2_WEIGHT_TYPE}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def getCtranslateComputeType(*args, **kwargs) -> dict:
|
||||||
|
return {"status":200, "result":config.CTRANSLATE2_COMPUTE_TYPE}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def setCtranslateComputeType(data, *args, **kwargs) -> dict:
|
||||||
|
config.CTRANSLATE2_COMPUTE_TYPE = str(data)
|
||||||
|
if model.checkTranslatorCTranslate2ModelWeight(config.CTRANSLATE2_WEIGHT_TYPE):
|
||||||
|
def callback():
|
||||||
|
model.changeTranslatorCTranslate2Model()
|
||||||
|
th_callback = Thread(target=callback)
|
||||||
|
th_callback.daemon = True
|
||||||
|
th_callback.start()
|
||||||
|
th_callback.join()
|
||||||
|
return {"status":200, "result":config.CTRANSLATE2_COMPUTE_TYPE}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def getWhisperWeightType(*args, **kwargs) -> dict:
|
def getWhisperWeightType(*args, **kwargs) -> dict:
|
||||||
return {"status":200, "result":config.WHISPER_WEIGHT_TYPE}
|
return {"status":200, "result":config.WHISPER_WEIGHT_TYPE}
|
||||||
@@ -1456,6 +1480,15 @@ class Controller:
|
|||||||
config.WHISPER_WEIGHT_TYPE = str(data)
|
config.WHISPER_WEIGHT_TYPE = str(data)
|
||||||
return {"status":200, "result": config.WHISPER_WEIGHT_TYPE}
|
return {"status":200, "result": config.WHISPER_WEIGHT_TYPE}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def getWhisperComputeType(*args, **kwargs) -> dict:
|
||||||
|
return {"status":200, "result":config.WHISPER_COMPUTE_TYPE}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def setWhisperComputeType(data, *args, **kwargs) -> dict:
|
||||||
|
config.WHISPER_COMPUTE_TYPE = str(data)
|
||||||
|
return {"status":200, "result":config.WHISPER_COMPUTE_TYPE}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def getSendMessageFormatParts(*args, **kwargs) -> dict:
|
def getSendMessageFormatParts(*args, **kwargs) -> dict:
|
||||||
return {"status":200, "result":config.SEND_MESSAGE_FORMAT_PARTS}
|
return {"status":200, "result":config.SEND_MESSAGE_FORMAT_PARTS}
|
||||||
|
|||||||
@@ -162,6 +162,9 @@ mapping = {
|
|||||||
"/get/data/ctranslate2_weight_type": {"status": True, "variable":controller.getCtranslate2WeightType},
|
"/get/data/ctranslate2_weight_type": {"status": True, "variable":controller.getCtranslate2WeightType},
|
||||||
"/set/data/ctranslate2_weight_type": {"status": True, "variable":controller.setCtranslate2WeightType},
|
"/set/data/ctranslate2_weight_type": {"status": True, "variable":controller.setCtranslate2WeightType},
|
||||||
|
|
||||||
|
"/get/data/ctranslate2_compute_type": {"status": True, "variable":controller.getCtranslateComputeType},
|
||||||
|
"/set/data/ctranslate2_compute_type": {"status": True, "variable":controller.setCtranslateComputeType},
|
||||||
|
|
||||||
"/run/download_ctranslate2_weight": {"status": True, "variable":controller.downloadCtranslate2Weight},
|
"/run/download_ctranslate2_weight": {"status": True, "variable":controller.downloadCtranslate2Weight},
|
||||||
|
|
||||||
"/get/data/deepl_auth_key": {"status": False, "variable":controller.getDeepLAuthKey},
|
"/get/data/deepl_auth_key": {"status": False, "variable":controller.getDeepLAuthKey},
|
||||||
@@ -261,8 +264,13 @@ mapping = {
|
|||||||
"/set/disable/check_speaker_threshold": {"status": True, "variable":controller.setDisableCheckSpeakerThreshold},
|
"/set/disable/check_speaker_threshold": {"status": True, "variable":controller.setDisableCheckSpeakerThreshold},
|
||||||
|
|
||||||
"/get/data/selectable_whisper_weight_type_dict": {"status": True, "variable":controller.getSelectableWhisperWeightTypeDict},
|
"/get/data/selectable_whisper_weight_type_dict": {"status": True, "variable":controller.getSelectableWhisperWeightTypeDict},
|
||||||
|
|
||||||
"/get/data/whisper_weight_type": {"status": True, "variable":controller.getWhisperWeightType},
|
"/get/data/whisper_weight_type": {"status": True, "variable":controller.getWhisperWeightType},
|
||||||
"/set/data/whisper_weight_type": {"status": True, "variable":controller.setWhisperWeightType},
|
"/set/data/whisper_weight_type": {"status": True, "variable":controller.setWhisperWeightType},
|
||||||
|
|
||||||
|
"/get/data/whisper_compute_type": {"status": True, "variable":controller.getWhisperComputeType},
|
||||||
|
"/set/data/whisper_compute_type": {"status": True, "variable":controller.setWhisperComputeType},
|
||||||
|
|
||||||
"/run/download_whisper_weight": {"status": True, "variable":controller.downloadWhisperWeight},
|
"/run/download_whisper_weight": {"status": True, "variable":controller.downloadWhisperWeight},
|
||||||
|
|
||||||
# VR
|
# VR
|
||||||
|
|||||||
@@ -112,10 +112,12 @@ class Model:
|
|||||||
|
|
||||||
def changeTranslatorCTranslate2Model(self):
|
def changeTranslatorCTranslate2Model(self):
|
||||||
self.translator.changeCTranslate2Model(
|
self.translator.changeCTranslate2Model(
|
||||||
config.PATH_LOCAL,
|
path=config.PATH_LOCAL,
|
||||||
config.CTRANSLATE2_WEIGHT_TYPE,
|
model_type=config.CTRANSLATE2_WEIGHT_TYPE,
|
||||||
config.SELECTED_TRANSLATION_COMPUTE_DEVICE["device"],
|
device=config.SELECTED_TRANSLATION_COMPUTE_DEVICE["device"],
|
||||||
config.SELECTED_TRANSLATION_COMPUTE_DEVICE["device_index"])
|
device_index=config.SELECTED_TRANSLATION_COMPUTE_DEVICE["device_index"],
|
||||||
|
compute_type=config.CTRANSLATE2_COMPUTE_TYPE
|
||||||
|
)
|
||||||
|
|
||||||
def downloadCTranslate2ModelWeight(self, weight_type, callback=None, end_callback=None):
|
def downloadCTranslate2ModelWeight(self, weight_type, callback=None, end_callback=None):
|
||||||
return downloadCTranslate2Weight(config.PATH_LOCAL, weight_type, callback, end_callback)
|
return downloadCTranslate2Weight(config.PATH_LOCAL, weight_type, callback, end_callback)
|
||||||
@@ -438,6 +440,7 @@ class Model:
|
|||||||
whisper_weight_type=config.WHISPER_WEIGHT_TYPE,
|
whisper_weight_type=config.WHISPER_WEIGHT_TYPE,
|
||||||
device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device"],
|
device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device"],
|
||||||
device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"],
|
device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"],
|
||||||
|
compute_type=config.WHISPER_COMPUTE_TYPE,
|
||||||
)
|
)
|
||||||
def sendMicTranscript():
|
def sendMicTranscript():
|
||||||
try:
|
try:
|
||||||
@@ -621,6 +624,7 @@ class Model:
|
|||||||
whisper_weight_type=config.WHISPER_WEIGHT_TYPE,
|
whisper_weight_type=config.WHISPER_WEIGHT_TYPE,
|
||||||
device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device"],
|
device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device"],
|
||||||
device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"],
|
device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"],
|
||||||
|
compute_type=config.WHISPER_COMPUTE_TYPE,
|
||||||
)
|
)
|
||||||
def sendSpeakerTranscript():
|
def sendSpeakerTranscript():
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ PHRASE_TIMEOUT = 3
|
|||||||
MAX_PHRASES = 10
|
MAX_PHRASES = 10
|
||||||
|
|
||||||
class AudioTranscriber:
|
class AudioTranscriber:
|
||||||
def __init__(self, speaker, source, phrase_timeout, max_phrases, transcription_engine, root=None, whisper_weight_type=None, device="cpu", device_index=0):
|
def __init__(self, speaker, source, phrase_timeout, max_phrases, transcription_engine, root=None, whisper_weight_type=None, device="cpu", device_index=0, compute_type="auto"):
|
||||||
self.speaker = speaker
|
self.speaker = speaker
|
||||||
self.phrase_timeout = phrase_timeout
|
self.phrase_timeout = phrase_timeout
|
||||||
self.max_phrases = max_phrases
|
self.max_phrases = max_phrases
|
||||||
@@ -41,7 +41,7 @@ class AudioTranscriber:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if transcription_engine == "Whisper" and checkWhisperWeight(root, whisper_weight_type) is True:
|
if transcription_engine == "Whisper" and checkWhisperWeight(root, whisper_weight_type) is True:
|
||||||
self.whisper_model = getWhisperModel(root, whisper_weight_type, device=device, device_index=device_index)
|
self.whisper_model = getWhisperModel(root, whisper_weight_type, device=device, device_index=device_index, compute_type=compute_type)
|
||||||
self.transcription_engine = "Whisper"
|
self.transcription_engine = "Whisper"
|
||||||
|
|
||||||
def transcribeAudioQueue(self, audio_queue, languages, countries, avg_logprob=-0.8, no_speech_prob=0.6):
|
def transcribeAudioQueue(self, audio_queue, languages, countries, avg_logprob=-0.8, no_speech_prob=0.6):
|
||||||
|
|||||||
@@ -74,8 +74,9 @@ def downloadWhisperWeight(root, weight_type, callback=None, end_callback=None):
|
|||||||
if isinstance(end_callback, Callable):
|
if isinstance(end_callback, Callable):
|
||||||
end_callback()
|
end_callback()
|
||||||
|
|
||||||
def getWhisperModel(root, weight_type, device="cpu", device_index=0):
|
def getWhisperModel(root, weight_type, device="cpu", device_index=0, compute_type="auto"):
|
||||||
path = os_path.join(root, "weights", "whisper", weight_type)
|
path = os_path.join(root, "weights", "whisper", weight_type)
|
||||||
|
if compute_type == "auto":
|
||||||
compute_type = getBestComputeType(device, device_index)
|
compute_type = getBestComputeType(device, device_index)
|
||||||
try:
|
try:
|
||||||
model = WhisperModel(
|
model = WhisperModel(
|
||||||
|
|||||||
@@ -36,13 +36,14 @@ class Translator():
|
|||||||
result = False
|
result = False
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def changeCTranslate2Model(self, path, model_type, device="cpu", device_index=0):
|
def changeCTranslate2Model(self, path, model_type, device="cpu", device_index=0, compute_type="auto"):
|
||||||
self.is_loaded_ctranslate2_model = False
|
self.is_loaded_ctranslate2_model = False
|
||||||
directory_name = ctranslate2_weights[model_type]["directory_name"]
|
directory_name = ctranslate2_weights[model_type]["directory_name"]
|
||||||
tokenizer = ctranslate2_weights[model_type]["tokenizer"]
|
tokenizer = ctranslate2_weights[model_type]["tokenizer"]
|
||||||
weight_path = os_path.join(path, "weights", "ctranslate2", directory_name)
|
weight_path = os_path.join(path, "weights", "ctranslate2", directory_name)
|
||||||
tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer")
|
tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer")
|
||||||
|
|
||||||
|
if compute_type == "auto":
|
||||||
compute_type = getBestComputeType(device, device_index)
|
compute_type = getBestComputeType(device, device_index)
|
||||||
self.ctranslate2_translator = ctranslate2.Translator(
|
self.ctranslate2_translator = ctranslate2.Translator(
|
||||||
weight_path,
|
weight_path,
|
||||||
|
|||||||
@@ -78,10 +78,13 @@ def isValidIpAddress(ip_address: str) -> bool:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def getComputeTypeList() -> list:
|
||||||
|
return ["int8_bfloat16", "int8_float16", "int8", "bfloat16", "float16", "int8_float32", "float32"]
|
||||||
|
|
||||||
def getBestComputeType(device, device_index) -> str:
|
def getBestComputeType(device, device_index) -> str:
|
||||||
compute_types = get_supported_compute_types(device, device_index)
|
compute_types = get_supported_compute_types(device, device_index)
|
||||||
compute_types = set(compute_types)
|
compute_types = set(compute_types)
|
||||||
preferred_types = ["int8_bfloat16", "int8_float16", "int8", "bfloat16", "float16", "int8_float32", "float32"]
|
preferred_types = getComputeTypeList()
|
||||||
|
|
||||||
for preferred_type in preferred_types:
|
for preferred_type in preferred_types:
|
||||||
if preferred_type in compute_types:
|
if preferred_type in compute_types:
|
||||||
|
|||||||
Reference in New Issue
Block a user