diff --git a/src-python/config.py b/src-python/config.py index 1a605701..bd3af244 100644 --- a/src-python/config.py +++ b/src-python/config.py @@ -11,7 +11,7 @@ from models.translation.translation_languages import translation_lang from models.translation.translation_utils import ctranslate2_weights from models.transcription.transcription_languages import transcription_lang from models.transcription.transcription_whisper import _MODELS as whisper_models -from utils import errorLogging, validateDictStructure, getComputeTypeList +from utils import errorLogging, validateDictStructure, getComputeDeviceList json_serializable_vars = {} def json_serializable(var_name): @@ -135,14 +135,6 @@ class Config: def SELECTABLE_COMPUTE_DEVICE_LIST(self): return self._SELECTABLE_COMPUTE_DEVICE_LIST - @property - def SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST(self): - return self._SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST - - @property - def SELECTABLE_WHISPER_COMPUTE_TYPE_LIST(self): - return self._SELECTABLE_WHISPER_COMPUTE_TYPE_LIST - @property def SEND_MESSAGE_BUTTON_TYPE_LIST(self): return self._SEND_MESSAGE_BUTTON_TYPE_LIST @@ -830,7 +822,7 @@ class Config: @CTRANSLATE2_COMPUTE_TYPE.setter def CTRANSLATE2_COMPUTE_TYPE(self, value): if isinstance(value, str): - if value in self.SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST: + if value in self.SELECTED_TRANSLATION_COMPUTE_DEVICE["compute_type"]: self._CTRANSLATE2_COMPUTE_TYPE = value self.saveConfig(inspect.currentframe().f_code.co_name, value) @@ -854,7 +846,7 @@ class Config: @WHISPER_COMPUTE_TYPE.setter def WHISPER_COMPUTE_TYPE(self, value): if isinstance(value, str): - if value in self.SELECTABLE_WHISPER_COMPUTE_TYPE_LIST: + if value in self.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE["compute_type"]: self._WHISPER_COMPUTE_TYPE = value self.saveConfig(inspect.currentframe().f_code.co_name, value) @@ -1078,13 +1070,7 @@ class Config: self._SELECTABLE_TRANSCRIPTION_ENGINE_LIST = 
list(transcription_lang[list(transcription_lang.keys())[0]].values())[0].keys() self._SELECTABLE_UI_LANGUAGE_LIST = ["en", "ja", "ko", "zh-Hant", "zh-Hans"] self._COMPUTE_MODE = "cuda" if torch.cuda.is_available() else "cpu" - self._SELECTABLE_COMPUTE_DEVICE_LIST = [] - if torch.cuda.is_available(): - for i in range(torch.cuda.device_count()): - self._SELECTABLE_COMPUTE_DEVICE_LIST.append({"device":"cuda", "device_index": i, "device_name": torch.cuda.get_device_name(i)}) - self._SELECTABLE_COMPUTE_DEVICE_LIST.append({"device":"cpu", "device_index": 0, "device_name": "cpu"}) - self._SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST = ["auto"] + getComputeTypeList() - self._SELECTABLE_WHISPER_COMPUTE_TYPE_LIST = ["auto"] + getComputeTypeList() + self._SELECTABLE_COMPUTE_DEVICE_LIST = getComputeDeviceList() self._SEND_MESSAGE_BUTTON_TYPE_LIST = ["show", "hide", "show_and_disable_enter_key"] self._SEND_MESSAGE_FORMAT_PARTS = { "message": { diff --git a/src-python/controller.py b/src-python/controller.py index 77717918..5ea86430 100644 --- a/src-python/controller.py +++ b/src-python/controller.py @@ -652,14 +652,6 @@ class Controller: def getComputeDeviceList(*args, **kwargs) -> dict: return {"status":200, "result":config.SELECTABLE_COMPUTE_DEVICE_LIST} - @staticmethod - def getCTranslate2ComputeTypeList(*args, **kwargs) -> dict: - return {"status":200, "result":config.SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST} - - @staticmethod - def getWhisperComputeTypeList(*args, **kwargs) -> dict: - return {"status":200, "result":config.SELECTABLE_WHISPER_COMPUTE_TYPE_LIST} - @staticmethod def getSelectedTranslationComputeDevice(*args, **kwargs) -> dict: return {"status":200, "result":config.SELECTED_TRANSLATION_COMPUTE_DEVICE} @@ -1455,10 +1447,6 @@ class Controller: th_callback.join() return {"status":200, "result":config.CTRANSLATE2_WEIGHT_TYPE} - @staticmethod - def getCtranslate2ComputeTypeList(*args, **kwargs) -> dict: - return {"status":200, 
"result":config.SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST} - @staticmethod def getCtranslate2ComputeType(*args, **kwargs) -> dict: return {"status":200, "result":config.CTRANSLATE2_COMPUTE_TYPE} diff --git a/src-python/mainloop.py b/src-python/mainloop.py index c7b03ea6..73e75594 100644 --- a/src-python/mainloop.py +++ b/src-python/mainloop.py @@ -162,8 +162,6 @@ mapping = { "/get/data/ctranslate2_weight_type": {"status": True, "variable":controller.getCtranslate2WeightType}, "/set/data/ctranslate2_weight_type": {"status": True, "variable":controller.setCtranslate2WeightType}, - "/get/data/ctranslate2_compute_type_list": {"status": True, "variable":controller.getCtranslate2ComputeTypeList}, - "/get/data/ctranslate2_compute_type": {"status": True, "variable":controller.getCtranslate2ComputeType}, "/set/data/ctranslate2_compute_type": {"status": True, "variable":controller.setCtranslate2ComputeType}, @@ -270,8 +268,6 @@ mapping = { "/get/data/whisper_weight_type": {"status": True, "variable":controller.getWhisperWeightType}, "/set/data/whisper_weight_type": {"status": True, "variable":controller.setWhisperWeightType}, - "/get/data/whisper_compute_type_list": {"status": True, "variable":controller.getWhisperComputeTypeList}, - "/get/data/whisper_compute_type": {"status": True, "variable":controller.getWhisperComputeType}, "/set/data/whisper_compute_type": {"status": True, "variable":controller.setWhisperComputeType}, diff --git a/src-python/utils.py b/src-python/utils.py index 1b28fcf6..fab62d51 100644 --- a/src-python/utils.py +++ b/src-python/utils.py @@ -5,6 +5,7 @@ import traceback import logging from logging.handlers import RotatingFileHandler +import torch from ctranslate2 import get_supported_compute_types import requests import ipaddress @@ -78,17 +79,67 @@ def isValidIpAddress(ip_address: str) -> bool: except ValueError: return False -def getComputeTypeList() -> list: - return ["int8_bfloat16", "int8_float16", "int8", "bfloat16", "float16", "int8_float32", 
# Preference order used when a compute type has to be picked automatically.
# It is also used to give the selectable lists a stable, reproducible order.
_COMPUTE_TYPE_PREFERENCE = (
    "int8_bfloat16",
    "int8_float16",
    "int8",
    "bfloat16",
    "float16",
    "int8_float32",
    "float32",
)

# Compute types that are hidden from the selectable list on GPUs whose name
# contains "GTX".
_GTX_UNSUPPORTED_COMPUTE_TYPES = frozenset({"int8_bfloat16", "bfloat16", "float16", "int8"})

# GPU name keywords for which the list reported by CTranslate2 is offered as-is.
_UNRESTRICTED_GPU_KEYWORDS = ("RTX", "Tesla", "A100", "Quadro")


def _orderComputeTypes(compute_types) -> list:
    """Return ``compute_types`` as a list in a stable, preference-first order.

    Names listed in ``_COMPUTE_TYPE_PREFERENCE`` come first, in that order;
    any other name follows, sorted alphabetically.
    """
    rank = {name: position for position, name in enumerate(_COMPUTE_TYPE_PREFERENCE)}
    return sorted(compute_types, key=lambda name: (rank.get(name, len(rank)), name))


def _selectableGpuComputeTypes(device_name: str, supported_types) -> list:
    """Return the compute types offered to the user for one CUDA device."""
    ordered = _orderComputeTypes(supported_types)
    if "GTX" in device_name:
        return ["auto"] + [name for name in ordered if name not in _GTX_UNSUPPORTED_COMPUTE_TYPES]
    if any(keyword in device_name for keyword in _UNRESTRICTED_GPU_KEYWORDS):
        return ["auto"] + ordered
    # Unrecognised GPU family: only float32 is offered (no "auto" entry).
    return ["float32"]


def getComputeDeviceList() -> list:
    """Return the selectable compute devices and their compute types.

    The CPU entry always comes first, followed by one entry per CUDA device
    when ``torch.cuda.is_available()`` is true.

    Returns:
        A list of dicts with the keys ``"device"`` (``"cpu"`` or ``"cuda"``),
        ``"device_index"``, ``"device_name"`` and ``"compute_types"``. Note
        that the last key is plural; consumers must index entries with exactly
        ``"compute_types"``.
    """
    # CTranslate2 reports supported types as an unordered collection (a set
    # according to its API reference), so the names are ordered explicitly to
    # keep the list identical from one run to the next.
    devices = [
        {
            "device": "cpu",
            "device_index": 0,
            "device_name": "cpu",
            "compute_types": ["auto"] + _orderComputeTypes(get_supported_compute_types("cpu", 0)),
        }
    ]

    if torch.cuda.is_available():
        for device_index in range(torch.cuda.device_count()):
            device_name = torch.cuda.get_device_name(device_index)
            devices.append(
                {
                    "device": "cuda",
                    "device_index": device_index,
                    "device_name": device_name,
                    "compute_types": _selectableGpuComputeTypes(
                        device_name,
                        get_supported_compute_types("cuda", device_index),
                    ),
                }
            )

    return devices


def getBestComputeType(device: str, device_index: int) -> str:
    """Return the most preferred compute type supported by the given device.

    Args:
        device: ``"cpu"``; any other value is treated as a CUDA device and its
            name is looked up with ``torch.cuda.get_device_name``.
        device_index: Index of the device within its kind.

    Returns:
        The first entry of the applicable preference list that the device
        supports, or ``"float32"`` when none of them is supported.
    """
    supported_types = set(get_supported_compute_types(device, device_index))
    device_name = "cpu" if device == "cpu" else torch.cuda.get_device_name(device_index)

    # GPUs whose name contains "GTX" are pinned to float32; every other device
    # uses the shared preference order.
    candidates = ("float32",) if "GTX" in device_name else _COMPUTE_TYPE_PREFERENCE

    return next((name for name in candidates if name in supported_types), "float32")


def encodeBase64(data:str) -> dict:
    """Decode a Base64 string holding UTF-8 JSON and return the parsed value.

    Despite its name, this function decodes: Base64 -> UTF-8 text -> JSON.
    """
    return json.loads(base64.b64decode(data).decode('utf-8'))


if __name__ == "__main__":
    # Manual check: print the devices detected on this machine.
    print(getComputeDeviceList())