[Update] Refactor compute type management: unify device list retrieval and remove deprecated methods
This commit is contained in:
@@ -11,7 +11,7 @@ from models.translation.translation_languages import translation_lang
|
|||||||
from models.translation.translation_utils import ctranslate2_weights
|
from models.translation.translation_utils import ctranslate2_weights
|
||||||
from models.transcription.transcription_languages import transcription_lang
|
from models.transcription.transcription_languages import transcription_lang
|
||||||
from models.transcription.transcription_whisper import _MODELS as whisper_models
|
from models.transcription.transcription_whisper import _MODELS as whisper_models
|
||||||
from utils import errorLogging, validateDictStructure, getComputeTypeList
|
from utils import errorLogging, validateDictStructure, getComputeDeviceList
|
||||||
|
|
||||||
json_serializable_vars = {}
|
json_serializable_vars = {}
|
||||||
def json_serializable(var_name):
|
def json_serializable(var_name):
|
||||||
@@ -135,14 +135,6 @@ class Config:
|
|||||||
def SELECTABLE_COMPUTE_DEVICE_LIST(self):
|
def SELECTABLE_COMPUTE_DEVICE_LIST(self):
|
||||||
return self._SELECTABLE_COMPUTE_DEVICE_LIST
|
return self._SELECTABLE_COMPUTE_DEVICE_LIST
|
||||||
|
|
||||||
@property
|
|
||||||
def SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST(self):
|
|
||||||
return self._SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST
|
|
||||||
|
|
||||||
@property
|
|
||||||
def SELECTABLE_WHISPER_COMPUTE_TYPE_LIST(self):
|
|
||||||
return self._SELECTABLE_WHISPER_COMPUTE_TYPE_LIST
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def SEND_MESSAGE_BUTTON_TYPE_LIST(self):
|
def SEND_MESSAGE_BUTTON_TYPE_LIST(self):
|
||||||
return self._SEND_MESSAGE_BUTTON_TYPE_LIST
|
return self._SEND_MESSAGE_BUTTON_TYPE_LIST
|
||||||
@@ -830,7 +822,7 @@ class Config:
|
|||||||
@CTRANSLATE2_COMPUTE_TYPE.setter
def CTRANSLATE2_COMPUTE_TYPE(self, value):
    """Set the CTranslate2 compute type if the selected translation device offers it.

    Non-string values, and strings that are not among the compute types of
    the currently selected translation compute device, are ignored.
    """
    if not isinstance(value, str):
        return

    device = self.SELECTED_TRANSLATION_COMPUTE_DEVICE
    # Device entries built by getComputeDeviceList() keep their allowed values
    # under "compute_types" (plural); looking up "compute_type" on such an
    # entry raises KeyError.
    # NOTE(review): the singular key is kept only as a fallback in case some
    # entry still uses it - confirm and drop.
    allowed_types = device.get("compute_types", device.get("compute_type", ()))

    if value in allowed_types:
        self._CTRANSLATE2_COMPUTE_TYPE = value
        # co_name is this setter's own name; it must be read here, not in a helper.
        self.saveConfig(inspect.currentframe().f_code.co_name, value)
|
||||||
|
|
||||||
@@ -854,7 +846,7 @@ class Config:
|
|||||||
@WHISPER_COMPUTE_TYPE.setter
def WHISPER_COMPUTE_TYPE(self, value):
    """Set the Whisper compute type if the selected transcription device offers it.

    Non-string values, and strings that are not among the compute types of
    the currently selected transcription compute device, are ignored.
    """
    if not isinstance(value, str):
        return

    device = self.SELECTED_TRANSCRIPTION_COMPUTE_DEVICE
    # Device entries built by getComputeDeviceList() keep their allowed values
    # under "compute_types" (plural); looking up "compute_type" on such an
    # entry raises KeyError.
    # NOTE(review): the singular key is kept only as a fallback in case some
    # entry still uses it - confirm and drop.
    allowed_types = device.get("compute_types", device.get("compute_type", ()))

    if value in allowed_types:
        self._WHISPER_COMPUTE_TYPE = value
        # co_name is this setter's own name; it must be read here, not in a helper.
        self.saveConfig(inspect.currentframe().f_code.co_name, value)
|
||||||
|
|
||||||
@@ -1078,13 +1070,7 @@ class Config:
|
|||||||
self._SELECTABLE_TRANSCRIPTION_ENGINE_LIST = list(transcription_lang[list(transcription_lang.keys())[0]].values())[0].keys()
|
self._SELECTABLE_TRANSCRIPTION_ENGINE_LIST = list(transcription_lang[list(transcription_lang.keys())[0]].values())[0].keys()
|
||||||
self._SELECTABLE_UI_LANGUAGE_LIST = ["en", "ja", "ko", "zh-Hant", "zh-Hans"]
|
self._SELECTABLE_UI_LANGUAGE_LIST = ["en", "ja", "ko", "zh-Hant", "zh-Hans"]
|
||||||
self._COMPUTE_MODE = "cuda" if torch.cuda.is_available() else "cpu"
|
self._COMPUTE_MODE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
self._SELECTABLE_COMPUTE_DEVICE_LIST = []
|
self._SELECTABLE_COMPUTE_DEVICE_LIST = getComputeDeviceList()
|
||||||
if torch.cuda.is_available():
|
|
||||||
for i in range(torch.cuda.device_count()):
|
|
||||||
self._SELECTABLE_COMPUTE_DEVICE_LIST.append({"device":"cuda", "device_index": i, "device_name": torch.cuda.get_device_name(i)})
|
|
||||||
self._SELECTABLE_COMPUTE_DEVICE_LIST.append({"device":"cpu", "device_index": 0, "device_name": "cpu"})
|
|
||||||
self._SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST = ["auto"] + getComputeTypeList()
|
|
||||||
self._SELECTABLE_WHISPER_COMPUTE_TYPE_LIST = ["auto"] + getComputeTypeList()
|
|
||||||
self._SEND_MESSAGE_BUTTON_TYPE_LIST = ["show", "hide", "show_and_disable_enter_key"]
|
self._SEND_MESSAGE_BUTTON_TYPE_LIST = ["show", "hide", "show_and_disable_enter_key"]
|
||||||
self._SEND_MESSAGE_FORMAT_PARTS = {
|
self._SEND_MESSAGE_FORMAT_PARTS = {
|
||||||
"message": {
|
"message": {
|
||||||
|
|||||||
@@ -652,14 +652,6 @@ class Controller:
|
|||||||
def getComputeDeviceList(*args, **kwargs) -> dict:
|
def getComputeDeviceList(*args, **kwargs) -> dict:
|
||||||
return {"status":200, "result":config.SELECTABLE_COMPUTE_DEVICE_LIST}
|
return {"status":200, "result":config.SELECTABLE_COMPUTE_DEVICE_LIST}
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def getCTranslate2ComputeTypeList(*args, **kwargs) -> dict:
|
|
||||||
return {"status":200, "result":config.SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def getWhisperComputeTypeList(*args, **kwargs) -> dict:
|
|
||||||
return {"status":200, "result":config.SELECTABLE_WHISPER_COMPUTE_TYPE_LIST}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def getSelectedTranslationComputeDevice(*args, **kwargs) -> dict:
|
def getSelectedTranslationComputeDevice(*args, **kwargs) -> dict:
|
||||||
return {"status":200, "result":config.SELECTED_TRANSLATION_COMPUTE_DEVICE}
|
return {"status":200, "result":config.SELECTED_TRANSLATION_COMPUTE_DEVICE}
|
||||||
@@ -1455,10 +1447,6 @@ class Controller:
|
|||||||
th_callback.join()
|
th_callback.join()
|
||||||
return {"status":200, "result":config.CTRANSLATE2_WEIGHT_TYPE}
|
return {"status":200, "result":config.CTRANSLATE2_WEIGHT_TYPE}
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def getCtranslate2ComputeTypeList(*args, **kwargs) -> dict:
|
|
||||||
return {"status":200, "result":config.SELECTABLE_CTRANSLATE2_COMPUTE_TYPE_LIST}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def getCtranslate2ComputeType(*args, **kwargs) -> dict:
|
def getCtranslate2ComputeType(*args, **kwargs) -> dict:
|
||||||
return {"status":200, "result":config.CTRANSLATE2_COMPUTE_TYPE}
|
return {"status":200, "result":config.CTRANSLATE2_COMPUTE_TYPE}
|
||||||
|
|||||||
@@ -162,8 +162,6 @@ mapping = {
|
|||||||
"/get/data/ctranslate2_weight_type": {"status": True, "variable":controller.getCtranslate2WeightType},
|
"/get/data/ctranslate2_weight_type": {"status": True, "variable":controller.getCtranslate2WeightType},
|
||||||
"/set/data/ctranslate2_weight_type": {"status": True, "variable":controller.setCtranslate2WeightType},
|
"/set/data/ctranslate2_weight_type": {"status": True, "variable":controller.setCtranslate2WeightType},
|
||||||
|
|
||||||
"/get/data/ctranslate2_compute_type_list": {"status": True, "variable":controller.getCtranslate2ComputeTypeList},
|
|
||||||
|
|
||||||
"/get/data/ctranslate2_compute_type": {"status": True, "variable":controller.getCtranslate2ComputeType},
|
"/get/data/ctranslate2_compute_type": {"status": True, "variable":controller.getCtranslate2ComputeType},
|
||||||
"/set/data/ctranslate2_compute_type": {"status": True, "variable":controller.setCtranslate2ComputeType},
|
"/set/data/ctranslate2_compute_type": {"status": True, "variable":controller.setCtranslate2ComputeType},
|
||||||
|
|
||||||
@@ -270,8 +268,6 @@ mapping = {
|
|||||||
"/get/data/whisper_weight_type": {"status": True, "variable":controller.getWhisperWeightType},
|
"/get/data/whisper_weight_type": {"status": True, "variable":controller.getWhisperWeightType},
|
||||||
"/set/data/whisper_weight_type": {"status": True, "variable":controller.setWhisperWeightType},
|
"/set/data/whisper_weight_type": {"status": True, "variable":controller.setWhisperWeightType},
|
||||||
|
|
||||||
"/get/data/whisper_compute_type_list": {"status": True, "variable":controller.getWhisperComputeTypeList},
|
|
||||||
|
|
||||||
"/get/data/whisper_compute_type": {"status": True, "variable":controller.getWhisperComputeType},
|
"/get/data/whisper_compute_type": {"status": True, "variable":controller.getWhisperComputeType},
|
||||||
"/set/data/whisper_compute_type": {"status": True, "variable":controller.setWhisperComputeType},
|
"/set/data/whisper_compute_type": {"status": True, "variable":controller.setWhisperComputeType},
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import traceback
|
|||||||
import logging
|
import logging
|
||||||
from logging.handlers import RotatingFileHandler
|
from logging.handlers import RotatingFileHandler
|
||||||
|
|
||||||
|
import torch
|
||||||
from ctranslate2 import get_supported_compute_types
|
from ctranslate2 import get_supported_compute_types
|
||||||
import requests
|
import requests
|
||||||
import ipaddress
|
import ipaddress
|
||||||
@@ -78,17 +79,67 @@ def isValidIpAddress(ip_address: str) -> bool:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def getComputeDeviceList() -> list:
    """Build the list of selectable compute devices and their compute types.

    The CPU entry always comes first. When CUDA is available, one entry per
    CUDA device follows in device-index order.

    Returns:
        A list of dicts, each with the keys "device" ("cpu" or "cuda"),
        "device_index", "device_name" and "compute_types".
    """
    devices = [
        {
            "device": "cpu",
            "device_index": 0,
            "device_name": "cpu",
            # CTranslate2 documents the result as a set, so sort it to keep the
            # option order identical from one run to the next.
            "compute_types": ["auto"] + sorted(get_supported_compute_types("cpu", 0)),
        }
    ]

    if not torch.cuda.is_available():
        return devices

    # Compute types treated as unsupported on devices whose name contains "GTX".
    gtx_unsupported_types = {"int8_bfloat16", "bfloat16", "float16", "int8"}
    # Name fragments of GPU families that keep the full reported list.
    unrestricted_families = ("RTX", "Tesla", "A100", "Quadro")

    for device_index in range(torch.cuda.device_count()):
        gpu_device_name = torch.cuda.get_device_name(device_index)
        gpu_compute_types = ["auto"] + sorted(get_supported_compute_types("cuda", device_index))

        # Per-device restrictions on the offered compute types.
        if "GTX" in gpu_device_name:
            gpu_compute_types = [t for t in gpu_compute_types if t not in gtx_unsupported_types]
        elif not any(family in gpu_device_name for family in unrestricted_families):
            # Unrecognised GPU family: only float32 is offered ("auto" is dropped too).
            gpu_compute_types = ["float32"]

        devices.append(
            {
                "device": "cuda",
                "device_index": device_index,
                "device_name": gpu_device_name,
                "compute_types": gpu_compute_types,
            }
        )

    return devices
|
||||||
|
|
||||||
|
def getBestComputeType(device: str, device_index: int) -> str:
    """Pick the most preferred compute type that the given device supports.

    Args:
        device: Device kind. "cpu" is handled directly; for any other value
            the device name is read via torch.cuda.get_device_name.
        device_index: Index of the device.

    Returns:
        The first entry of the applicable preference list that the device
        supports, or "float32" when none of them is supported.
    """
    supported_types = set(get_supported_compute_types(device, device_index))
    device_name = "cpu" if device == "cpu" else torch.cuda.get_device_name(device_index)

    # Preference order used when no device family below matches, most preferred first.
    default_order = ["int8_bfloat16", "int8_float16", "int8", "bfloat16", "float16", "int8_float32", "float32"]

    # Preferred compute types per device family, keyed by a fragment of the
    # device name. The default order is kept out of this table so that it can
    # never be selected by an accidental substring match on the device name.
    family_orders = {
        "GTX": ["float32"],
        "RTX": default_order,
        "Tesla": default_order,
        "A100": default_order,
        "Quadro": default_order,
    }

    # Choose the preference list from the device name; the first matching family wins.
    selected_types = next(
        (order for family, order in family_orders.items() if family in device_name),
        default_order,
    )

    # Return the first preferred compute type the device actually supports.
    for compute_type in selected_types:
        if compute_type in supported_types:
            return compute_type

    return "float32"
|
||||||
|
|
||||||
def encodeBase64(data:str) -> dict:
|
def encodeBase64(data:str) -> dict:
|
||||||
return json.loads(base64.b64decode(data).decode('utf-8'))
|
return json.loads(base64.b64decode(data).decode('utf-8'))
|
||||||
@@ -179,3 +230,6 @@ def errorLogging() -> None:
|
|||||||
error_logger = setupLogger("error", "error.log", logging.ERROR)
|
error_logger = setupLogger("error", "error.log", logging.ERROR)
|
||||||
|
|
||||||
error_logger.error(traceback.format_exc())
|
error_logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print(getComputeDeviceList())
|
||||||
Reference in New Issue
Block a user