👍️[Update] Model: cpu/cudaをtranslationもしくはtranscriptionで選択できるように実装
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from os import path as os_path, makedirs as os_makedirs
|
||||
from requests import get as requests_get
|
||||
from typing import Callable
|
||||
import torch
|
||||
import huggingface_hub
|
||||
from faster_whisper import WhisperModel
|
||||
import logging
|
||||
@@ -51,7 +52,7 @@ def checkWhisperWeight(root, weight_type):
|
||||
try:
|
||||
WhisperModel(
|
||||
path,
|
||||
device="cuda",
|
||||
device="cpu",
|
||||
device_index=0,
|
||||
compute_type="int8",
|
||||
cpu_threads=4,
|
||||
@@ -75,13 +76,14 @@ def downloadWhisperWeight(root, weight_type, callbackFunc):
|
||||
url = huggingface_hub.hf_hub_url(_MODELS[weight_type], filename)
|
||||
downloadFile(url, file_path, func=callbackFunc)
|
||||
|
||||
def getWhisperModel(root, weight_type):
|
||||
def getWhisperModel(root, weight_type, device="cpu", device_index=0):
|
||||
path = os_path.join(root, "weights", "whisper", weight_type)
|
||||
compute_type = "int8" if device == "cpu" else "float16"
|
||||
return WhisperModel(
|
||||
path,
|
||||
device="cuda",
|
||||
device_index=0,
|
||||
compute_type="int8",
|
||||
device=device,
|
||||
device_index=device_index,
|
||||
compute_type=compute_type,
|
||||
cpu_threads=4,
|
||||
num_workers=1,
|
||||
local_files_only=True,
|
||||
|
||||
Reference in New Issue
Block a user