From 567907fc4db809e25810573c48f499e9e04455f8 Mon Sep 17 00:00:00 2001 From: misyaguziya <53165965+misyaguziya@users.noreply.github.com> Date: Thu, 2 Jan 2025 08:32:35 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B[bugfix]=20Model=20:=20Added=20the?= =?UTF-8?q?=20ability=20to=20automatically=20select=20calculation=20types?= =?UTF-8?q?=20from=20GPU=20devices?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../transcription/transcription_whisper.py | 3 +- .../translation/translation_translator.py | 4 +- src-python/utils.py | 65 +++---------------- 3 files changed, 12 insertions(+), 60 deletions(-) diff --git a/src-python/models/transcription/transcription_whisper.py b/src-python/models/transcription/transcription_whisper.py index 3eb574c7..080054b5 100644 --- a/src-python/models/transcription/transcription_whisper.py +++ b/src-python/models/transcription/transcription_whisper.py @@ -4,6 +4,7 @@ from typing import Callable import huggingface_hub from faster_whisper import WhisperModel import logging +from utils import getBestComputeType logger = logging.getLogger('faster_whisper') logger.setLevel(logging.CRITICAL) @@ -73,7 +74,7 @@ def downloadWhisperWeight(root, weight_type, callback=None, end_callback=None): def getWhisperModel(root, weight_type, device="cpu", device_index=0): path = os_path.join(root, "weights", "whisper", weight_type) - compute_type = "int8" if device == "cpu" else "float16" + compute_type = getBestComputeType(device, device_index) return WhisperModel( path, device=device, diff --git a/src-python/models/translation/translation_translator.py b/src-python/models/translation/translation_translator.py index 44004155..fad96744 100644 --- a/src-python/models/translation/translation_translator.py +++ b/src-python/models/translation/translation_translator.py @@ -6,7 +6,7 @@ from .translation_utils import ctranslate2_weights import ctranslate2 import transformers -from utils import errorLogging +from utils import errorLogging, getBestComputeType import warnings warnings.filterwarnings("ignore") @@ -37,7 +37,7 @@ class Translator(): weight_path = os_path.join(path, "weights", "ctranslate2", directory_name) tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer") - compute_type = "int8" if device == "cpu" else "float16" + compute_type = getBestComputeType(device, device_index) self.ctranslate2_translator = ctranslate2.Translator( weight_path, device=device, diff --git a/src-python/utils.py b/src-python/utils.py index 31d589bd..8e23af38 100644 --- a/src-python/utils.py +++ b/src-python/utils.py @@ -1,68 +1,19 @@ import base64 from typing import Any import json -import random -from typing import Union -from os import path as os_path, rename as os_rename import traceback import logging -from PIL.Image import open as Image_open -def getImageFile(file_name): - img = Image_open(os_path.join(os_path.dirname(__file__), "img", file_name)) - return img +from ctranslate2 import get_supported_compute_types -def callFunctionIfCallable(function, *args): - if callable(function) is True: - function(*args) +def getBestComputeType(device, device_index) -> str: + compute_types = get_supported_compute_types(device, device_index) + compute_types = set(compute_types) + preferred_types = ["int8_bfloat16", "int8_float16", "int8", "bfloat16", "float16", "int8_float32", "float32"] -def isEven(number): - return number % 2 == 0 - -def makeEven(number, minus:bool=False): - if minus is True: - return number if isEven(number) else number - 1 - return number if isEven(number) else number + 1 - -def intToPctStr(value:int): - return f"{value}%" - -def floatToPctStr(value:float): - return f"{int(value*100)}%" - -def strPctToInt(value:str): - return int(value.replace("%", "")) - -def isUniqueStrings(unique_strings:Union[str, list], input_string:str, require=False): - import re - if isinstance(unique_strings, str): - unique_strings = [unique_strings] - patterns = [re.escape(s) for s in unique_strings] - - counts = [len(re.findall(pattern, input_string)) for pattern in patterns] - - if require is True: - # If require is True, unique_strings must appear once - return all(count == 1 for count in counts) and counts.count(1) == 2 - else: - # If require is False, check if unique strings are used exactly once - return all(count == 1 for count in counts) - -# path先のweightフォルダがある場合にはそのフォルダ名をweightsに変更する -def renameWeightFolder(path): - weight_path = os_path.join(path, "weight") - if os_path.exists(weight_path): - os_rename(weight_path, os_path.join(path, "weights")) - -def splitList(lst:list, split_count:int, to_shuffle:bool=False): - if to_shuffle is True: - random.shuffle(lst) - - split_lists = [] - for i in range(0, len(lst), split_count): - sub_list = lst[i:i+split_count] - split_lists.append(sub_list) - return split_lists + for preferred_type in preferred_types: + if preferred_type in compute_types: + return preferred_type def encodeBase64(data:str) -> dict: return json.loads(base64.b64decode(data).decode('utf-8'))