[Update] Add compute type management for CTranslate2 and Whisper models

This commit is contained in:
misyaguziya
2025-09-17 10:35:34 +09:00
parent 6bf09970a4
commit 245855d0ca
8 changed files with 98 additions and 12 deletions

View File

@@ -21,7 +21,7 @@ PHRASE_TIMEOUT = 3
MAX_PHRASES = 10
class AudioTranscriber:
def __init__(self, speaker, source, phrase_timeout, max_phrases, transcription_engine, root=None, whisper_weight_type=None, device="cpu", device_index=0):
def __init__(self, speaker, source, phrase_timeout, max_phrases, transcription_engine, root=None, whisper_weight_type=None, device="cpu", device_index=0, compute_type="auto"):
self.speaker = speaker
self.phrase_timeout = phrase_timeout
self.max_phrases = max_phrases
@@ -41,7 +41,7 @@ class AudioTranscriber:
}
if transcription_engine == "Whisper" and checkWhisperWeight(root, whisper_weight_type) is True:
self.whisper_model = getWhisperModel(root, whisper_weight_type, device=device, device_index=device_index)
self.whisper_model = getWhisperModel(root, whisper_weight_type, device=device, device_index=device_index, compute_type=compute_type)
self.transcription_engine = "Whisper"
def transcribeAudioQueue(self, audio_queue, languages, countries, avg_logprob=-0.8, no_speech_prob=0.6):

View File

@@ -74,9 +74,10 @@ def downloadWhisperWeight(root, weight_type, callback=None, end_callback=None):
if isinstance(end_callback, Callable):
end_callback()
def getWhisperModel(root, weight_type, device="cpu", device_index=0):
def getWhisperModel(root, weight_type, device="cpu", device_index=0, compute_type="auto"):
path = os_path.join(root, "weights", "whisper", weight_type)
compute_type = getBestComputeType(device, device_index)
if compute_type == "auto":
compute_type = getBestComputeType(device, device_index)
try:
model = WhisperModel(
path,