[Update] Add compute type management for CTranslate2 and Whisper models
This commit is contained in:
@@ -78,10 +78,13 @@ def isValidIpAddress(ip_address: str) -> bool:
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def getComputeTypeList() -> list:
|
||||
return ["int8_bfloat16", "int8_float16", "int8", "bfloat16", "float16", "int8_float32", "float32"]
|
||||
|
||||
def getBestComputeType(device, device_index) -> str:
|
||||
compute_types = get_supported_compute_types(device, device_index)
|
||||
compute_types = set(compute_types)
|
||||
preferred_types = ["int8_bfloat16", "int8_float16", "int8", "bfloat16", "float16", "int8_float32", "float32"]
|
||||
preferred_types = getComputeTypeList()
|
||||
|
||||
for preferred_type in preferred_types:
|
||||
if preferred_type in compute_types:
|
||||
|
||||
Reference in New Issue
Block a user