Merge branch 'bugfix_compute_type' into develop
This commit is contained in:
@@ -4,6 +4,7 @@ from typing import Callable
|
||||
import huggingface_hub
|
||||
from faster_whisper import WhisperModel
|
||||
import logging
|
||||
from utils import getBestComputeType
|
||||
|
||||
logger = logging.getLogger('faster_whisper')
|
||||
logger.setLevel(logging.CRITICAL)
|
||||
@@ -73,7 +74,7 @@ def downloadWhisperWeight(root, weight_type, callback=None, end_callback=None):
|
||||
|
||||
def getWhisperModel(root, weight_type, device="cpu", device_index=0):
|
||||
path = os_path.join(root, "weights", "whisper", weight_type)
|
||||
compute_type = "int8" if device == "cpu" else "float16"
|
||||
compute_type = getBestComputeType(device, device_index)
|
||||
return WhisperModel(
|
||||
path,
|
||||
device=device,
|
||||
|
||||
@@ -6,7 +6,7 @@ from .translation_utils import ctranslate2_weights
|
||||
|
||||
import ctranslate2
|
||||
import transformers
|
||||
from utils import errorLogging
|
||||
from utils import errorLogging, getBestComputeType
|
||||
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
@@ -37,7 +37,7 @@ class Translator():
|
||||
weight_path = os_path.join(path, "weights", "ctranslate2", directory_name)
|
||||
tokenizer_path = os_path.join(path, "weights", "ctranslate2", directory_name, "tokenizer")
|
||||
|
||||
compute_type = "int8" if device == "cpu" else "float16"
|
||||
compute_type = getBestComputeType(device, device_index)
|
||||
self.ctranslate2_translator = ctranslate2.Translator(
|
||||
weight_path,
|
||||
device=device,
|
||||
|
||||
@@ -1,68 +1,19 @@
|
||||
import base64
|
||||
from typing import Any
|
||||
import json
|
||||
import random
|
||||
from typing import Union
|
||||
from os import path as os_path, rename as os_rename
|
||||
import traceback
|
||||
import logging
|
||||
from PIL.Image import open as Image_open
|
||||
|
||||
def getImageFile(file_name):
|
||||
img = Image_open(os_path.join(os_path.dirname(__file__), "img", file_name))
|
||||
return img
|
||||
from ctranslate2 import get_supported_compute_types
|
||||
|
||||
def callFunctionIfCallable(function, *args):
|
||||
if callable(function) is True:
|
||||
function(*args)
|
||||
def getBestComputeType(device, device_index) -> str:
|
||||
compute_types = get_supported_compute_types(device, device_index)
|
||||
compute_types = set(compute_types)
|
||||
preferred_types = ["int8_bfloat16", "int8_float16", "int8", "bfloat16", "float16", "int8_float32", "float32"]
|
||||
|
||||
def isEven(number):
|
||||
return number % 2 == 0
|
||||
|
||||
def makeEven(number, minus:bool=False):
|
||||
if minus is True:
|
||||
return number if isEven(number) else number - 1
|
||||
return number if isEven(number) else number + 1
|
||||
|
||||
def intToPctStr(value:int):
|
||||
return f"{value}%"
|
||||
|
||||
def floatToPctStr(value:float):
|
||||
return f"{int(value*100)}%"
|
||||
|
||||
def strPctToInt(value:str):
|
||||
return int(value.replace("%", ""))
|
||||
|
||||
def isUniqueStrings(unique_strings:Union[str, list], input_string:str, require=False):
|
||||
import re
|
||||
if isinstance(unique_strings, str):
|
||||
unique_strings = [unique_strings]
|
||||
patterns = [re.escape(s) for s in unique_strings]
|
||||
|
||||
counts = [len(re.findall(pattern, input_string)) for pattern in patterns]
|
||||
|
||||
if require is True:
|
||||
# If require is True, unique_strings must appear once
|
||||
return all(count == 1 for count in counts) and counts.count(1) == 2
|
||||
else:
|
||||
# If require is False, check if unique strings are used exactly once
|
||||
return all(count == 1 for count in counts)
|
||||
|
||||
# path先のweightフォルダがある場合にはそのフォルダ名をweightsに変更する
|
||||
def renameWeightFolder(path):
|
||||
weight_path = os_path.join(path, "weight")
|
||||
if os_path.exists(weight_path):
|
||||
os_rename(weight_path, os_path.join(path, "weights"))
|
||||
|
||||
def splitList(lst:list, split_count:int, to_shuffle:bool=False):
|
||||
if to_shuffle is True:
|
||||
random.shuffle(lst)
|
||||
|
||||
split_lists = []
|
||||
for i in range(0, len(lst), split_count):
|
||||
sub_list = lst[i:i+split_count]
|
||||
split_lists.append(sub_list)
|
||||
return split_lists
|
||||
for preferred_type in preferred_types:
|
||||
if preferred_type in compute_types:
|
||||
return preferred_type
|
||||
|
||||
def encodeBase64(data:str) -> dict:
|
||||
return json.loads(base64.b64decode(data).decode('utf-8'))
|
||||
|
||||
Reference in New Issue
Block a user