diff --git a/src-python/models/transcription/transcription_whisper.py b/src-python/models/transcription/transcription_whisper.py index 3eb574c7..080054b5 100644 --- a/src-python/models/transcription/transcription_whisper.py +++ b/src-python/models/transcription/transcription_whisper.py @@ -4,6 +4,7 @@ from typing import Callable import huggingface_hub from faster_whisper import WhisperModel import logging +from utils import getBestComputeType logger = logging.getLogger('faster_whisper') logger.setLevel(logging.CRITICAL) @@ -73,7 +74,7 @@ def downloadWhisperWeight(root, weight_type, callback=None, end_callback=None): def getWhisperModel(root, weight_type, device="cpu", device_index=0): path = os_path.join(root, "weights", "whisper", weight_type) - compute_type = "int8" if device == "cpu" else "float16" + compute_type = getBestComputeType(device, device_index) return WhisperModel( path, device=device, diff --git a/src-python/models/translation/translation_translator.py b/src-python/models/translation/translation_translator.py index 44004155..fad96744 100644 --- a/src-python/models/translation/translation_translator.py +++ b/src-python/models/translation/translation_translator.py @@ -6,7 +6,7 @@ from .translation_utils import ctranslate2_weights import ctranslate2 import transformers -from utils import errorLogging +from utils import errorLogging, getBestComputeType import warnings warnings.filterwarnings("ignore") @@ -37,7 +37,7 @@ class Translator(): weight_path = os_path.join(path, "weights", "ctranslate2", directory_name) tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer") - compute_type = "int8" if device == "cpu" else "float16" + compute_type = getBestComputeType(device, device_index) self.ctranslate2_translator = ctranslate2.Translator( weight_path, device=device, diff --git a/src-python/utils.py b/src-python/utils.py index 31d589bd..8e23af38 100644 --- a/src-python/utils.py +++ b/src-python/utils.py @@ -1,68 +1,19 @@ import base64 from typing import Any import json -import random -from typing import Union -from os import path as os_path, rename as os_rename import traceback import logging -from PIL.Image import open as Image_open -def getImageFile(file_name): - img = Image_open(os_path.join(os_path.dirname(__file__), "img", file_name)) - return img +from ctranslate2 import get_supported_compute_types -def callFunctionIfCallable(function, *args): - if callable(function) is True: - function(*args) +def getBestComputeType(device, device_index) -> str: + compute_types = get_supported_compute_types(device, device_index) + compute_types = set(compute_types) + preferred_types = ["int8_bfloat16", "int8_float16", "int8", "bfloat16", "float16", "int8_float32", "float32"] -def isEven(number): - return number % 2 == 0 - -def makeEven(number, minus:bool=False): - if minus is True: - return number if isEven(number) else number - 1 - return number if isEven(number) else number + 1 - -def intToPctStr(value:int): - return f"{value}%" - -def floatToPctStr(value:float): - return f"{int(value*100)}%" - -def strPctToInt(value:str): - return int(value.replace("%", "")) - -def isUniqueStrings(unique_strings:Union[str, list], input_string:str, require=False): - import re - if isinstance(unique_strings, str): - unique_strings = [unique_strings] - patterns = [re.escape(s) for s in unique_strings] - - counts = [len(re.findall(pattern, input_string)) for pattern in patterns] - - if require is True: - # If require is True, unique_strings must appear once - return all(count == 1 for count in counts) and counts.count(1) == 2 - else: - # If require is False, check if unique strings are used exactly once - return all(count == 1 for count in counts) - -# path先のweightフォルダがある場合にはそのフォルダ名をweightsに変更する -def renameWeightFolder(path): - weight_path = os_path.join(path, "weight") - if os_path.exists(weight_path): - os_rename(weight_path, os_path.join(path, "weights")) - -def splitList(lst:list, split_count:int, to_shuffle:bool=False): - if to_shuffle is True: - random.shuffle(lst) - - split_lists = [] - for i in range(0, len(lst), split_count): - sub_list = lst[i:i+split_count] - split_lists.append(sub_list) - return split_lists + for preferred_type in preferred_types: + if preferred_type in compute_types: + return preferred_type def encodeBase64(data:str) -> dict: return json.loads(base64.b64decode(data).decode('utf-8'))