Merge branch 'bugfix_compute_type' into develop
This commit is contained in:
@@ -11,7 +11,7 @@ from models.translation.translation_languages import translation_lang
|
||||
from models.translation.translation_utils import ctranslate2_weights
|
||||
from models.transcription.transcription_languages import transcription_lang
|
||||
from models.transcription.transcription_whisper import _MODELS as whisper_models
|
||||
from utils import errorLogging, validateDictStructure
|
||||
from utils import errorLogging, validateDictStructure, getComputeDeviceList
|
||||
|
||||
json_serializable_vars = {}
|
||||
def json_serializable(var_name):
|
||||
@@ -818,6 +818,18 @@ class Config:
|
||||
self._CTRANSLATE2_WEIGHT_TYPE = value
|
||||
self.saveConfig(inspect.currentframe().f_code.co_name, value)
|
||||
|
||||
@property
|
||||
@json_serializable('SELECTED_TRANSLATION_COMPUTE_TYPE')
|
||||
def SELECTED_TRANSLATION_COMPUTE_TYPE(self):
|
||||
return self._SELECTED_TRANSLATION_COMPUTE_TYPE
|
||||
|
||||
@SELECTED_TRANSLATION_COMPUTE_TYPE.setter
|
||||
def SELECTED_TRANSLATION_COMPUTE_TYPE(self, value):
|
||||
if isinstance(value, str):
|
||||
if value in self.SELECTED_TRANSLATION_COMPUTE_DEVICE["compute_types"]:
|
||||
self._SELECTED_TRANSLATION_COMPUTE_TYPE = value
|
||||
self.saveConfig(inspect.currentframe().f_code.co_name, value)
|
||||
|
||||
@property
|
||||
@json_serializable('WHISPER_WEIGHT_TYPE')
|
||||
def WHISPER_WEIGHT_TYPE(self):
|
||||
@@ -830,6 +842,18 @@ class Config:
|
||||
self._WHISPER_WEIGHT_TYPE = value
|
||||
self.saveConfig(inspect.currentframe().f_code.co_name, value)
|
||||
|
||||
@property
|
||||
@json_serializable('SELECTED_TRANSCRIPTION_COMPUTE_TYPE')
|
||||
def SELECTED_TRANSCRIPTION_COMPUTE_TYPE(self):
|
||||
return self._SELECTED_TRANSCRIPTION_COMPUTE_TYPE
|
||||
|
||||
@SELECTED_TRANSCRIPTION_COMPUTE_TYPE.setter
|
||||
def SELECTED_TRANSCRIPTION_COMPUTE_TYPE(self, value):
|
||||
if isinstance(value, str):
|
||||
if value in self.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["compute_types"]:
|
||||
self._SELECTED_TRANSCRIPTION_COMPUTE_TYPE = value
|
||||
self.saveConfig(inspect.currentframe().f_code.co_name, value)
|
||||
|
||||
@property
|
||||
@json_serializable('AUTO_CLEAR_MESSAGE_BOX')
|
||||
def AUTO_CLEAR_MESSAGE_BOX(self):
|
||||
@@ -1050,11 +1074,7 @@ class Config:
|
||||
self._SELECTABLE_TRANSCRIPTION_ENGINE_LIST = list(transcription_lang[list(transcription_lang.keys())[0]].values())[0].keys()
|
||||
self._SELECTABLE_UI_LANGUAGE_LIST = ["en", "ja", "ko", "zh-Hant", "zh-Hans"]
|
||||
self._COMPUTE_MODE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
self._SELECTABLE_COMPUTE_DEVICE_LIST = []
|
||||
if torch.cuda.is_available():
|
||||
for i in range(torch.cuda.device_count()):
|
||||
self._SELECTABLE_COMPUTE_DEVICE_LIST.append({"device":"cuda", "device_index": i, "device_name": torch.cuda.get_device_name(i)})
|
||||
self._SELECTABLE_COMPUTE_DEVICE_LIST.append({"device":"cpu", "device_index": 0, "device_name": "cpu"})
|
||||
self._SELECTABLE_COMPUTE_DEVICE_LIST = getComputeDeviceList()
|
||||
self._SEND_MESSAGE_BUTTON_TYPE_LIST = ["show", "hide", "show_and_disable_enter_key"]
|
||||
self._SEND_MESSAGE_FORMAT_PARTS = {
|
||||
"message": {
|
||||
@@ -1186,7 +1206,9 @@ class Config:
|
||||
self._SELECTED_TRANSLATION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0])
|
||||
self._SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = copy.deepcopy(self.SELECTABLE_COMPUTE_DEVICE_LIST[0])
|
||||
self._CTRANSLATE2_WEIGHT_TYPE = "small"
|
||||
self._SELECTED_TRANSLATION_COMPUTE_TYPE = "auto"
|
||||
self._WHISPER_WEIGHT_TYPE = "base"
|
||||
self._SELECTED_TRANSCRIPTION_COMPUTE_TYPE = "auto"
|
||||
self._AUTO_CLEAR_MESSAGE_BOX = True
|
||||
self._SEND_ONLY_TRANSLATED_MESSAGES = False
|
||||
self._OVERLAY_SMALL_LOG = False
|
||||
|
||||
@@ -747,13 +747,15 @@ class Controller:
|
||||
def getSelectedTranslationComputeDevice(*args, **kwargs) -> dict:
|
||||
return {"status":200, "result":config.SELECTED_TRANSLATION_COMPUTE_DEVICE}
|
||||
|
||||
@staticmethod
|
||||
def setSelectedTranslationComputeDevice(device:str, *args, **kwargs) -> dict:
|
||||
def setSelectedTranslationComputeDevice(self, device:str, *args, **kwargs) -> dict:
|
||||
printLog("setSelectedTranslationComputeDevice", device)
|
||||
pre_device = config.SELECTED_TRANSLATION_COMPUTE_DEVICE
|
||||
pre_compute_type = config.SELECTED_TRANSLATION_COMPUTE_TYPE
|
||||
config.SELECTED_TRANSLATION_COMPUTE_DEVICE = device
|
||||
config.SELECTED_TRANSLATION_COMPUTE_TYPE = "auto"
|
||||
try:
|
||||
model.changeTranslatorCTranslate2Model()
|
||||
self.run(200, self.run_mapping["selected_translation_compute_type"], config.SELECTED_TRANSLATION_COMPUTE_TYPE)
|
||||
except Exception as e:
|
||||
# VRAM不足エラーの検出(デバイス切り替え時)
|
||||
is_vram_error, error_message = model.detectVRAMError(e)
|
||||
@@ -761,6 +763,7 @@ class Controller:
|
||||
# 前のデバイス設定に戻す
|
||||
printLog("VRAM error detected, reverting device setting")
|
||||
config.SELECTED_TRANSLATION_COMPUTE_DEVICE = pre_device
|
||||
config.SELECTED_TRANSLATION_COMPUTE_TYPE = pre_compute_type
|
||||
model.changeTranslatorCTranslate2Model()
|
||||
else:
|
||||
# その他のエラーは通常通り処理
|
||||
@@ -775,10 +778,11 @@ class Controller:
|
||||
def getSelectedTranscriptionComputeDevice(*args, **kwargs) -> dict:
|
||||
return {"status":200, "result":config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE}
|
||||
|
||||
@staticmethod
|
||||
def setSelectedTranscriptionComputeDevice(device:str, *args, **kwargs) -> dict:
|
||||
def setSelectedTranscriptionComputeDevice(self, device:str, *args, **kwargs) -> dict:
|
||||
printLog("setSelectedTranscriptionComputeDevice", device)
|
||||
config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE = device
|
||||
config.SELECTED_TRANSCRIPTION_COMPUTE_TYPE = "auto"
|
||||
self.run(200, self.run_mapping["selected_transcription_compute_type"], config.SELECTED_TRANSCRIPTION_COMPUTE_TYPE)
|
||||
return {"status":200,"result":config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE}
|
||||
|
||||
@staticmethod
|
||||
@@ -1549,6 +1553,7 @@ class Controller:
|
||||
|
||||
@staticmethod
|
||||
def setCtranslate2WeightType(data, *args, **kwargs) -> dict:
|
||||
pre_weight_type = config.CTRANSLATE2_WEIGHT_TYPE
|
||||
config.CTRANSLATE2_WEIGHT_TYPE = str(data)
|
||||
if model.checkTranslatorCTranslate2ModelWeight(config.CTRANSLATE2_WEIGHT_TYPE):
|
||||
def callback():
|
||||
@@ -1557,8 +1562,29 @@ class Controller:
|
||||
th_callback.daemon = True
|
||||
th_callback.start()
|
||||
th_callback.join()
|
||||
else:
|
||||
config.CTRANSLATE2_WEIGHT_TYPE = pre_weight_type
|
||||
return {"status":200, "result":config.CTRANSLATE2_WEIGHT_TYPE}
|
||||
|
||||
@staticmethod
|
||||
def getSelectedTranslationComputeType(*args, **kwargs) -> dict:
|
||||
return {"status":200, "result":config.SELECTED_TRANSLATION_COMPUTE_TYPE}
|
||||
|
||||
@staticmethod
|
||||
def setSelectedTranslationComputeType(data, *args, **kwargs) -> dict:
|
||||
pre_compute_type = config.SELECTED_TRANSLATION_COMPUTE_TYPE
|
||||
config.SELECTED_TRANSLATION_COMPUTE_TYPE = str(data)
|
||||
if model.checkTranslatorCTranslate2ModelWeight(config.CTRANSLATE2_WEIGHT_TYPE):
|
||||
def callback():
|
||||
model.changeTranslatorCTranslate2Model()
|
||||
th_callback = Thread(target=callback)
|
||||
th_callback.daemon = True
|
||||
th_callback.start()
|
||||
th_callback.join()
|
||||
else:
|
||||
config.SELECTED_TRANSLATION_COMPUTE_TYPE = pre_compute_type
|
||||
return {"status":200, "result":config.SELECTED_TRANSLATION_COMPUTE_TYPE}
|
||||
|
||||
@staticmethod
|
||||
def getWhisperWeightType(*args, **kwargs) -> dict:
|
||||
return {"status":200, "result":config.WHISPER_WEIGHT_TYPE}
|
||||
@@ -1568,6 +1594,15 @@ class Controller:
|
||||
config.WHISPER_WEIGHT_TYPE = str(data)
|
||||
return {"status":200, "result": config.WHISPER_WEIGHT_TYPE}
|
||||
|
||||
@staticmethod
|
||||
def getSelectedTranscriptionComputeType(*args, **kwargs) -> dict:
|
||||
return {"status":200, "result":config.SELECTED_TRANSCRIPTION_COMPUTE_TYPE}
|
||||
|
||||
@staticmethod
|
||||
def setSelectedTranscriptionComputeType(data, *args, **kwargs) -> dict:
|
||||
config.SELECTED_TRANSCRIPTION_COMPUTE_TYPE = str(data)
|
||||
return {"status":200, "result":config.SELECTED_TRANSCRIPTION_COMPUTE_TYPE}
|
||||
|
||||
@staticmethod
|
||||
def getSendMessageFormatParts(*args, **kwargs) -> dict:
|
||||
return {"status":200, "result":config.SEND_MESSAGE_FORMAT_PARTS}
|
||||
|
||||
@@ -48,6 +48,9 @@ run_mapping = {
|
||||
"selected_translation_engines":"/run/selected_translation_engines",
|
||||
"translation_engines":"/run/translation_engines",
|
||||
|
||||
"selected_translation_compute_type":"/run/selected_translation_compute_type",
|
||||
"selected_transcription_compute_type":"/run/selected_transcription_compute_type",
|
||||
|
||||
"mic_host_list":"/run/mic_host_list",
|
||||
"mic_device_list":"/run/mic_device_list",
|
||||
"speaker_device_list":"/run/speaker_device_list",
|
||||
@@ -162,6 +165,9 @@ mapping = {
|
||||
"/get/data/ctranslate2_weight_type": {"status": True, "variable":controller.getCtranslate2WeightType},
|
||||
"/set/data/ctranslate2_weight_type": {"status": True, "variable":controller.setCtranslate2WeightType},
|
||||
|
||||
"/get/data/selected_translation_compute_type": {"status": True, "variable":controller.getSelectedTranslationComputeType},
|
||||
"/set/data/selected_translation_compute_type": {"status": True, "variable":controller.setSelectedTranslationComputeType},
|
||||
|
||||
"/run/download_ctranslate2_weight": {"status": True, "variable":controller.downloadCtranslate2Weight},
|
||||
|
||||
"/get/data/deepl_auth_key": {"status": False, "variable":controller.getDeepLAuthKey},
|
||||
@@ -261,8 +267,13 @@ mapping = {
|
||||
"/set/disable/check_speaker_threshold": {"status": True, "variable":controller.setDisableCheckSpeakerThreshold},
|
||||
|
||||
"/get/data/selectable_whisper_weight_type_dict": {"status": True, "variable":controller.getSelectableWhisperWeightTypeDict},
|
||||
|
||||
"/get/data/whisper_weight_type": {"status": True, "variable":controller.getWhisperWeightType},
|
||||
"/set/data/whisper_weight_type": {"status": True, "variable":controller.setWhisperWeightType},
|
||||
|
||||
"/get/data/selected_transcription_compute_type": {"status": True, "variable":controller.getSelectedTranscriptionComputeType},
|
||||
"/set/data/selected_transcription_compute_type": {"status": True, "variable":controller.setSelectedTranscriptionComputeType},
|
||||
|
||||
"/run/download_whisper_weight": {"status": True, "variable":controller.downloadWhisperWeight},
|
||||
|
||||
# VR
|
||||
|
||||
@@ -112,10 +112,12 @@ class Model:
|
||||
|
||||
def changeTranslatorCTranslate2Model(self):
|
||||
self.translator.changeCTranslate2Model(
|
||||
config.PATH_LOCAL,
|
||||
config.CTRANSLATE2_WEIGHT_TYPE,
|
||||
config.SELECTED_TRANSLATION_COMPUTE_DEVICE["device"],
|
||||
config.SELECTED_TRANSLATION_COMPUTE_DEVICE["device_index"])
|
||||
path=config.PATH_LOCAL,
|
||||
model_type=config.CTRANSLATE2_WEIGHT_TYPE,
|
||||
device=config.SELECTED_TRANSLATION_COMPUTE_DEVICE["device"],
|
||||
device_index=config.SELECTED_TRANSLATION_COMPUTE_DEVICE["device_index"],
|
||||
compute_type=config.SELECTED_TRANSLATION_COMPUTE_TYPE
|
||||
)
|
||||
|
||||
def downloadCTranslate2ModelWeight(self, weight_type, callback=None, end_callback=None):
|
||||
return downloadCTranslate2Weight(config.PATH_LOCAL, weight_type, callback, end_callback)
|
||||
@@ -446,6 +448,7 @@ class Model:
|
||||
whisper_weight_type=config.WHISPER_WEIGHT_TYPE,
|
||||
device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device"],
|
||||
device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"],
|
||||
compute_type=config.SELECTED_TRANSCRIPTION_COMPUTE_TYPE,
|
||||
)
|
||||
def sendMicTranscript():
|
||||
try:
|
||||
@@ -629,6 +632,7 @@ class Model:
|
||||
whisper_weight_type=config.WHISPER_WEIGHT_TYPE,
|
||||
device=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device"],
|
||||
device_index=config.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["device_index"],
|
||||
compute_type=config.SELECTED_TRANSCRIPTION_COMPUTE_TYPE,
|
||||
)
|
||||
def sendSpeakerTranscript():
|
||||
try:
|
||||
|
||||
@@ -21,7 +21,7 @@ PHRASE_TIMEOUT = 3
|
||||
MAX_PHRASES = 10
|
||||
|
||||
class AudioTranscriber:
|
||||
def __init__(self, speaker, source, phrase_timeout, max_phrases, transcription_engine, root=None, whisper_weight_type=None, device="cpu", device_index=0):
|
||||
def __init__(self, speaker, source, phrase_timeout, max_phrases, transcription_engine, root=None, whisper_weight_type=None, device="cpu", device_index=0, compute_type="auto"):
|
||||
self.speaker = speaker
|
||||
self.phrase_timeout = phrase_timeout
|
||||
self.max_phrases = max_phrases
|
||||
@@ -41,7 +41,7 @@ class AudioTranscriber:
|
||||
}
|
||||
|
||||
if transcription_engine == "Whisper" and checkWhisperWeight(root, whisper_weight_type) is True:
|
||||
self.whisper_model = getWhisperModel(root, whisper_weight_type, device=device, device_index=device_index)
|
||||
self.whisper_model = getWhisperModel(root, whisper_weight_type, device=device, device_index=device_index, compute_type=compute_type)
|
||||
self.transcription_engine = "Whisper"
|
||||
|
||||
def transcribeAudioQueue(self, audio_queue, languages, countries, avg_logprob=-0.8, no_speech_prob=0.6):
|
||||
|
||||
@@ -74,9 +74,10 @@ def downloadWhisperWeight(root, weight_type, callback=None, end_callback=None):
|
||||
if isinstance(end_callback, Callable):
|
||||
end_callback()
|
||||
|
||||
def getWhisperModel(root, weight_type, device="cpu", device_index=0):
|
||||
def getWhisperModel(root, weight_type, device="cpu", device_index=0, compute_type="auto"):
|
||||
path = os_path.join(root, "weights", "whisper", weight_type)
|
||||
compute_type = getBestComputeType(device, device_index)
|
||||
if compute_type == "auto":
|
||||
compute_type = getBestComputeType(device, device_index)
|
||||
try:
|
||||
model = WhisperModel(
|
||||
path,
|
||||
|
||||
@@ -36,14 +36,15 @@ class Translator():
|
||||
result = False
|
||||
return result
|
||||
|
||||
def changeCTranslate2Model(self, path, model_type, device="cpu", device_index=0):
|
||||
def changeCTranslate2Model(self, path, model_type, device="cpu", device_index=0, compute_type="auto"):
|
||||
self.is_loaded_ctranslate2_model = False
|
||||
directory_name = ctranslate2_weights[model_type]["directory_name"]
|
||||
tokenizer = ctranslate2_weights[model_type]["tokenizer"]
|
||||
weight_path = os_path.join(path, "weights", "ctranslate2", directory_name)
|
||||
tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer")
|
||||
|
||||
compute_type = getBestComputeType(device, device_index)
|
||||
if compute_type == "auto":
|
||||
compute_type = getBestComputeType(device, device_index)
|
||||
self.ctranslate2_translator = ctranslate2.Translator(
|
||||
weight_path,
|
||||
device=device,
|
||||
|
||||
@@ -5,6 +5,7 @@ import traceback
|
||||
import logging
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
import torch
|
||||
from ctranslate2 import get_supported_compute_types
|
||||
import requests
|
||||
import ipaddress
|
||||
@@ -78,14 +79,67 @@ def isValidIpAddress(ip_address: str) -> bool:
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def getBestComputeType(device, device_index) -> str:
|
||||
compute_types = get_supported_compute_types(device, device_index)
|
||||
compute_types = set(compute_types)
|
||||
preferred_types = ["int8_bfloat16", "int8_float16", "int8", "bfloat16", "float16", "int8_float32", "float32"]
|
||||
def getComputeDeviceList() -> dict:
|
||||
compute_types = [
|
||||
{
|
||||
"device": "cpu",
|
||||
"device_index": 0,
|
||||
"device_name": "cpu",
|
||||
"compute_types": ["auto"] + list(get_supported_compute_types("cpu", 0)),
|
||||
}
|
||||
]
|
||||
|
||||
for preferred_type in preferred_types:
|
||||
if preferred_type in compute_types:
|
||||
return preferred_type
|
||||
if torch.cuda.is_available():
|
||||
for device_index in range(torch.cuda.device_count()):
|
||||
gpu_device_name = torch.cuda.get_device_name(device_index)
|
||||
gpu_compute_types = ["auto"] + list(get_supported_compute_types("cuda", device_index))
|
||||
|
||||
# デバイスごとの計算タイプの制限
|
||||
if "GTX" in gpu_device_name:
|
||||
unsupported_types = {"int8_bfloat16", "bfloat16", "float16", "int8"}
|
||||
gpu_compute_types = [t for t in gpu_compute_types if t not in unsupported_types]
|
||||
elif not any(keyword in gpu_device_name for keyword in ["RTX", "Tesla", "A100", "Quadro"]):
|
||||
gpu_compute_types = ["float32"]
|
||||
|
||||
compute_types.append(
|
||||
{
|
||||
"device": "cuda",
|
||||
"device_index": device_index,
|
||||
"device_name": gpu_device_name,
|
||||
"compute_types": gpu_compute_types,
|
||||
}
|
||||
)
|
||||
|
||||
return compute_types
|
||||
|
||||
def getBestComputeType(device: str, device_index: int) -> str:
|
||||
compute_types = set(get_supported_compute_types(device, device_index))
|
||||
device_name = "cpu" if device == "cpu" else torch.cuda.get_device_name(device_index)
|
||||
|
||||
# デバイスごとの優先計算タイプ
|
||||
preferred_types = {
|
||||
"default": ["int8_bfloat16", "int8_float16", "int8", "bfloat16", "float16", "int8_float32", "float32"],
|
||||
"GTX": ["float32"],
|
||||
"RTX": ["int8_bfloat16", "int8_float16", "int8", "bfloat16", "float16", "int8_float32", "float32"],
|
||||
"Tesla": ["int8_bfloat16", "int8_float16", "int8", "bfloat16", "float16", "int8_float32", "float32"],
|
||||
"A100": ["int8_bfloat16", "int8_float16", "int8", "bfloat16", "float16", "int8_float32", "float32"],
|
||||
"Quadro": ["int8_bfloat16", "int8_float16", "int8", "bfloat16", "float16", "int8_float32", "float32"],
|
||||
}
|
||||
|
||||
# デバイス名に基づいて優先タイプを選択
|
||||
for key in preferred_types:
|
||||
if key in device_name:
|
||||
selected_types = preferred_types[key]
|
||||
break
|
||||
else:
|
||||
selected_types = preferred_types["default"]
|
||||
|
||||
# 利用可能な計算タイプを返す
|
||||
for compute_type in selected_types:
|
||||
if compute_type in compute_types:
|
||||
return compute_type
|
||||
|
||||
return "float32"
|
||||
|
||||
def encodeBase64(data:str) -> dict:
|
||||
return json.loads(base64.b64decode(data).decode('utf-8'))
|
||||
@@ -175,4 +229,7 @@ def errorLogging() -> None:
|
||||
if error_logger is None:
|
||||
error_logger = setupLogger("error", "error.log", logging.ERROR)
|
||||
|
||||
error_logger.error(traceback.format_exc())
|
||||
error_logger.error(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(getComputeDeviceList())
|
||||
Reference in New Issue
Block a user