👍️[Update] Model: cpu/cudaをtranslationもしくはtranscriptionで選択できるように実装

This commit is contained in:
misyaguziya
2024-10-23 13:41:34 +09:00
parent 2136865493
commit af3fe1f0f9
7 changed files with 84 additions and 13 deletions

View File

@@ -1,6 +1,7 @@
from os import path as os_path, makedirs as os_makedirs
from requests import get as requests_get
from typing import Callable
import torch
import huggingface_hub
from faster_whisper import WhisperModel
import logging
@@ -51,7 +52,7 @@ def checkWhisperWeight(root, weight_type):
try:
WhisperModel(
path,
device="cuda",
device="cpu",
device_index=0,
compute_type="int8",
cpu_threads=4,
@@ -75,13 +76,14 @@ def downloadWhisperWeight(root, weight_type, callbackFunc):
url = huggingface_hub.hf_hub_url(_MODELS[weight_type], filename)
downloadFile(url, file_path, func=callbackFunc)
def getWhisperModel(root, weight_type):
def getWhisperModel(root, weight_type, device="cpu", device_index=0):
path = os_path.join(root, "weights", "whisper", weight_type)
compute_type = "int8" if device == "cpu" else "float16"
return WhisperModel(
path,
device="cuda",
device_index=0,
compute_type="int8",
device=device,
device_index=device_index,
compute_type=compute_type,
cpu_threads=4,
num_workers=1,
local_files_only=True,