🐛[bugfix] Model : Added the ability to automatically select calculation types from GPU devices
This commit is contained in:
@@ -4,6 +4,7 @@ from typing import Callable
|
|||||||
import huggingface_hub
|
import huggingface_hub
|
||||||
from faster_whisper import WhisperModel
|
from faster_whisper import WhisperModel
|
||||||
import logging
|
import logging
|
||||||
|
from utils import getBestComputeType
|
||||||
|
|
||||||
logger = logging.getLogger('faster_whisper')
|
logger = logging.getLogger('faster_whisper')
|
||||||
logger.setLevel(logging.CRITICAL)
|
logger.setLevel(logging.CRITICAL)
|
||||||
@@ -73,7 +74,7 @@ def downloadWhisperWeight(root, weight_type, callback=None, end_callback=None):
|
|||||||
|
|
||||||
def getWhisperModel(root, weight_type, device="cpu", device_index=0):
|
def getWhisperModel(root, weight_type, device="cpu", device_index=0):
|
||||||
path = os_path.join(root, "weights", "whisper", weight_type)
|
path = os_path.join(root, "weights", "whisper", weight_type)
|
||||||
compute_type = "int8" if device == "cpu" else "float16"
|
compute_type = getBestComputeType(device, device_index)
|
||||||
return WhisperModel(
|
return WhisperModel(
|
||||||
path,
|
path,
|
||||||
device=device,
|
device=device,
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from .translation_utils import ctranslate2_weights
|
|||||||
|
|
||||||
import ctranslate2
|
import ctranslate2
|
||||||
import transformers
|
import transformers
|
||||||
from utils import errorLogging
|
from utils import errorLogging, getBestComputeType
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
@@ -37,7 +37,7 @@ class Translator():
|
|||||||
weight_path = os_path.join(path, "weights", "ctranslate2", directory_name)
|
weight_path = os_path.join(path, "weights", "ctranslate2", directory_name)
|
||||||
tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer")
|
tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer")
|
||||||
|
|
||||||
compute_type = "int8" if device == "cpu" else "float16"
|
compute_type = getBestComputeType(device, device_index)
|
||||||
self.ctranslate2_translator = ctranslate2.Translator(
|
self.ctranslate2_translator = ctranslate2.Translator(
|
||||||
weight_path,
|
weight_path,
|
||||||
device=device,
|
device=device,
|
||||||
|
|||||||
@@ -1,68 +1,19 @@
|
|||||||
import base64
|
import base64
|
||||||
from typing import Any
|
from typing import Any
|
||||||
import json
|
import json
|
||||||
import random
|
|
||||||
from typing import Union
|
|
||||||
from os import path as os_path, rename as os_rename
|
|
||||||
import traceback
|
import traceback
|
||||||
import logging
|
import logging
|
||||||
from PIL.Image import open as Image_open
|
|
||||||
|
|
||||||
def getImageFile(file_name):
|
from ctranslate2 import get_supported_compute_types
|
||||||
img = Image_open(os_path.join(os_path.dirname(__file__), "img", file_name))
|
|
||||||
return img
|
|
||||||
|
|
||||||
def callFunctionIfCallable(function, *args):
|
def getBestComputeType(device, device_index) -> str:
|
||||||
if callable(function) is True:
|
compute_types = get_supported_compute_types(device, device_index)
|
||||||
function(*args)
|
compute_types = set(compute_types)
|
||||||
|
preferred_types = ["int8_bfloat16", "int8_float16", "int8", "bfloat16", "float16", "int8_float32", "float32"]
|
||||||
|
|
||||||
def isEven(number):
|
for preferred_type in preferred_types:
|
||||||
return number % 2 == 0
|
if preferred_type in compute_types:
|
||||||
|
return preferred_type
|
||||||
def makeEven(number, minus:bool=False):
|
|
||||||
if minus is True:
|
|
||||||
return number if isEven(number) else number - 1
|
|
||||||
return number if isEven(number) else number + 1
|
|
||||||
|
|
||||||
def intToPctStr(value:int):
|
|
||||||
return f"{value}%"
|
|
||||||
|
|
||||||
def floatToPctStr(value:float):
|
|
||||||
return f"{int(value*100)}%"
|
|
||||||
|
|
||||||
def strPctToInt(value:str):
|
|
||||||
return int(value.replace("%", ""))
|
|
||||||
|
|
||||||
def isUniqueStrings(unique_strings:Union[str, list], input_string:str, require=False):
|
|
||||||
import re
|
|
||||||
if isinstance(unique_strings, str):
|
|
||||||
unique_strings = [unique_strings]
|
|
||||||
patterns = [re.escape(s) for s in unique_strings]
|
|
||||||
|
|
||||||
counts = [len(re.findall(pattern, input_string)) for pattern in patterns]
|
|
||||||
|
|
||||||
if require is True:
|
|
||||||
# If require is True, unique_strings must appear once
|
|
||||||
return all(count == 1 for count in counts) and counts.count(1) == 2
|
|
||||||
else:
|
|
||||||
# If require is False, check if unique strings are used exactly once
|
|
||||||
return all(count == 1 for count in counts)
|
|
||||||
|
|
||||||
# path先のweightフォルダがある場合にはそのフォルダ名をweightsに変更する
|
|
||||||
def renameWeightFolder(path):
|
|
||||||
weight_path = os_path.join(path, "weight")
|
|
||||||
if os_path.exists(weight_path):
|
|
||||||
os_rename(weight_path, os_path.join(path, "weights"))
|
|
||||||
|
|
||||||
def splitList(lst:list, split_count:int, to_shuffle:bool=False):
|
|
||||||
if to_shuffle is True:
|
|
||||||
random.shuffle(lst)
|
|
||||||
|
|
||||||
split_lists = []
|
|
||||||
for i in range(0, len(lst), split_count):
|
|
||||||
sub_list = lst[i:i+split_count]
|
|
||||||
split_lists.append(sub_list)
|
|
||||||
return split_lists
|
|
||||||
|
|
||||||
def encodeBase64(data:str) -> dict:
|
def encodeBase64(data:str) -> dict:
|
||||||
return json.loads(base64.b64decode(data).decode('utf-8'))
|
return json.loads(base64.b64decode(data).decode('utf-8'))
|
||||||
|
|||||||
Reference in New Issue
Block a user