🐛[bugfix] Model : Added the ability to automatically select calculation types from GPU devices

This commit is contained in:
misyaguziya
2025-01-02 08:32:35 +09:00
parent 7616a11a46
commit 567907fc4d
3 changed files with 12 additions and 60 deletions

View File

@@ -4,6 +4,7 @@ from typing import Callable
import huggingface_hub
from faster_whisper import WhisperModel
import logging
from utils import getBestComputeType
logger = logging.getLogger('faster_whisper')
logger.setLevel(logging.CRITICAL)
@@ -73,7 +74,7 @@ def downloadWhisperWeight(root, weight_type, callback=None, end_callback=None):
def getWhisperModel(root, weight_type, device="cpu", device_index=0):
path = os_path.join(root, "weights", "whisper", weight_type)
compute_type = "int8" if device == "cpu" else "float16"
compute_type = getBestComputeType(device, device_index)
return WhisperModel(
path,
device=device,

View File

@@ -6,7 +6,7 @@ from .translation_utils import ctranslate2_weights
import ctranslate2
import transformers
from utils import errorLogging
from utils import errorLogging, getBestComputeType
import warnings
warnings.filterwarnings("ignore")
@@ -37,7 +37,7 @@ class Translator():
weight_path = os_path.join(path, "weights", "ctranslate2", directory_name)
tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer")
compute_type = "int8" if device == "cpu" else "float16"
compute_type = getBestComputeType(device, device_index)
self.ctranslate2_translator = ctranslate2.Translator(
weight_path,
device=device,

View File

@@ -1,68 +1,19 @@
import base64
from typing import Any
import json
import random
from typing import Union
from os import path as os_path, rename as os_rename
import traceback
import logging
from PIL.Image import open as Image_open
def getImageFile(file_name):
img = Image_open(os_path.join(os_path.dirname(__file__), "img", file_name))
return img
from ctranslate2 import get_supported_compute_types
def callFunctionIfCallable(function, *args):
if callable(function) is True:
function(*args)
def getBestComputeType(device, device_index) -> str:
compute_types = get_supported_compute_types(device, device_index)
compute_types = set(compute_types)
preferred_types = ["int8_bfloat16", "int8_float16", "int8", "bfloat16", "float16", "int8_float32", "float32"]
def isEven(number):
return number % 2 == 0
def makeEven(number, minus:bool=False):
if minus is True:
return number if isEven(number) else number - 1
return number if isEven(number) else number + 1
def intToPctStr(value:int):
return f"{value}%"
def floatToPctStr(value:float):
return f"{int(value*100)}%"
def strPctToInt(value:str):
return int(value.replace("%", ""))
def isUniqueStrings(unique_strings:Union[str, list], input_string:str, require=False):
import re
if isinstance(unique_strings, str):
unique_strings = [unique_strings]
patterns = [re.escape(s) for s in unique_strings]
counts = [len(re.findall(pattern, input_string)) for pattern in patterns]
if require is True:
# If require is True, unique_strings must appear once
return all(count == 1 for count in counts) and counts.count(1) == 2
else:
# If require is False, check if unique strings are used exactly once
return all(count == 1 for count in counts)
# path先のweightフォルダがある場合にはそのフォルダ名をweightsに変更する
def renameWeightFolder(path):
weight_path = os_path.join(path, "weight")
if os_path.exists(weight_path):
os_rename(weight_path, os_path.join(path, "weights"))
def splitList(lst:list, split_count:int, to_shuffle:bool=False):
if to_shuffle is True:
random.shuffle(lst)
split_lists = []
for i in range(0, len(lst), split_count):
sub_list = lst[i:i+split_count]
split_lists.append(sub_list)
return split_lists
for preferred_type in preferred_types:
if preferred_type in compute_types:
return preferred_type
def encodeBase64(data:str) -> dict:
return json.loads(base64.b64decode(data).decode('utf-8'))