[Update] Refactor compute type management: unify device list retrieval and remove deprecated methods
This commit is contained in:
@@ -5,6 +5,7 @@ import traceback
|
||||
import logging
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
import torch
|
||||
from ctranslate2 import get_supported_compute_types
|
||||
import requests
|
||||
import ipaddress
|
||||
@@ -78,17 +79,67 @@ def isValidIpAddress(ip_address: str) -> bool:
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def getComputeTypeList() -> list:
    """Return the known compute type names as a new list, in a fixed order."""
    ordered_types = [
        "int8_bfloat16",
        "int8_float16",
        "int8",
        "bfloat16",
        "float16",
        "int8_float32",
        "float32",
    ]
    return ordered_types
|
||||
def getComputeDeviceList() -> list:
    """Enumerate the compute devices and the compute types offered on each.

    The CPU entry always comes first. When ``torch.cuda.is_available()`` is
    true, one entry per CUDA device index follows.

    Returns:
        A list of dicts, each with the keys ``device``, ``device_index``,
        ``device_name`` and ``compute_types``. ``compute_types`` is a list of
        strings that starts with ``"auto"`` followed by whatever
        ``get_supported_compute_types`` reports, except where the per-device
        restrictions below narrow it.
    """
    devices = [
        {
            "device": "cpu",
            "device_index": 0,
            "device_name": "cpu",
            "compute_types": ["auto"] + list(get_supported_compute_types("cpu", 0)),
        }
    ]

    if not torch.cuda.is_available():
        return devices

    # Types withheld from devices whose name contains "GTX".
    gtx_excluded_types = {"int8_bfloat16", "bfloat16", "float16", "int8"}
    # Name fragments of device families that keep the full reported list.
    unrestricted_keywords = ("RTX", "Tesla", "A100", "Quadro")

    for device_index in range(torch.cuda.device_count()):
        gpu_device_name = torch.cuda.get_device_name(device_index)
        gpu_compute_types = ["auto"] + list(get_supported_compute_types("cuda", device_index))

        # Per-device restrictions on the offered compute types.
        if "GTX" in gpu_device_name:
            gpu_compute_types = [t for t in gpu_compute_types if t not in gtx_excluded_types]
        elif not any(keyword in gpu_device_name for keyword in unrestricted_keywords):
            # Names matching none of the known families are offered float32
            # only; note that this also drops the "auto" entry.
            gpu_compute_types = ["float32"]

        devices.append(
            {
                "device": "cuda",
                "device_index": device_index,
                "device_name": gpu_device_name,
                "compute_types": gpu_compute_types,
            }
        )

    return devices
|
||||
|
||||
def getBestComputeType(device: str, device_index: int) -> str:
    """Pick the most preferred compute type that the given device supports.

    Args:
        device: Device kind. ``"cpu"`` is handled without querying torch; any
            other value is passed to ``torch.cuda.get_device_name``.
        device_index: Index of the device within its kind.

    Returns:
        The first entry of the device's preference list that
        ``get_supported_compute_types`` reports as supported, or
        ``"float32"`` when none of them is.
    """
    full_ladder = ["int8_bfloat16", "int8_float16", "int8", "bfloat16", "float16", "int8_float32", "float32"]

    # Preference lists keyed by a fragment of the device name.
    ladders_by_keyword = {
        "default": full_ladder,
        "GTX": ["float32"],
        "RTX": full_ladder,
        "Tesla": full_ladder,
        "A100": full_ladder,
        "Quadro": full_ladder,
    }

    supported = set(get_supported_compute_types(device, device_index))

    if device == "cpu":
        device_name = "cpu"
    else:
        device_name = torch.cuda.get_device_name(device_index)

    # Choose the preference list from the device name: the first key (in
    # insertion order) found in the name wins, otherwise use "default".
    matched_keyword = next(
        (keyword for keyword in ladders_by_keyword if keyword in device_name),
        "default",
    )
    candidates = ladders_by_keyword[matched_keyword]

    # Return the first preferred type the device actually supports.
    return next((name for name in candidates if name in supported), "float32")
|
||||
|
||||
def encodeBase64(data:str) -> dict:
    """Decode a Base64 string and parse the resulting UTF-8 text as JSON.

    Despite the name, this function decodes: ``data`` is Base64-decoded,
    the bytes are read as UTF-8, and the text is handed to ``json.loads``.
    """
    raw_bytes = base64.b64decode(data)
    json_text = raw_bytes.decode('utf-8')
    return json.loads(json_text)
|
||||
@@ -178,4 +229,7 @@ def errorLogging() -> None:
|
||||
if error_logger is None:
|
||||
error_logger = setupLogger("error", "error.log", logging.ERROR)
|
||||
|
||||
error_logger.error(traceback.format_exc())
|
||||
error_logger.error(traceback.format_exc())
|
||||
|
||||
if __name__ == "__main__":
    # When run as a script, print the detected compute device list.
    detected_devices = getComputeDeviceList()
    print(detected_devices)
|
||||
Reference in New Issue
Block a user