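Update the documentation and add type annotations to improve code readability and maintainability. Make each module's usage examples and dependencies explicit, and improve error handling.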

misyaguziya
2025-10-09 17:35:55 +09:00
parent b26129af68
commit 690a2f081b
5 changed files with 276 additions and 60 deletions


@@ -1,3 +1,78 @@
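## Transcription module (models.transcription)
This document covers the design, setup, usage examples, testing policy, and troubleshooting for `models/transcription`.
### Overview
- `models/transcription` converts audio input into text. Its main files are:
- `transcription_recorder.py`: wrappers for capturing audio from a microphone or speaker
- `transcription_transcriber.py`: logic that passes audio buffers to a recognition engine and produces transcripts
- `transcription_whisper.py`: helpers for downloading and loading faster-whisper (`WhisperModel`) models
- `transcription_languages.py`: per-engine language code map for each language and country
### Recent changes
- Added type annotations and docstrings to each module, which improves maintainability.
- Added an implementation to `transcription_whisper.py` that makes the download progress callback explicit.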
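### Dependencies
Main dependencies:
- `speech_recognition`: wrapper for audio recording and Google speech recognition
- `pyaudiowpatch`: cross-platform audio configuration
- `pydub`: channel conversion and other audio processing
- `faster_whisper` (optional): needed when running Whisper locally
- `huggingface_hub` (optional): downloads model artifacts
Note: `pydub` requires `ffmpeg`. If ffmpeg is missing from the environment, a warning is emitted.
Recommended installation (optional):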
```powershell
pip install speechrecognition pyaudiowpatch pydub faster-whisper huggingface-hub
```
Tests mock most external dependencies, so you do not need to install all of them.
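As a rough way to see which of the dependencies listed above are present, the following sketch checks each import name and looks for an `ffmpeg` executable on `PATH`. The helper name `check_environment` is hypothetical and not part of this repository; the module names are the import names from the list above.

```python
import importlib.util
import shutil
from typing import Dict

# Import names taken from the dependency list above.
MODULES = ["speech_recognition", "pyaudiowpatch", "pydub", "faster_whisper", "huggingface_hub"]


def check_environment() -> Dict[str, bool]:
    """Hypothetical helper: report which modules are importable and whether ffmpeg is on PATH."""
    report = {name: importlib.util.find_spec(name) is not None for name in MODULES}
    # shutil.which returns None when no executable with that name is found.
    report["ffmpeg"] = shutil.which("ffmpeg") is not None
    return report


for name, available in check_environment().items():
    print(f"{name}: {'ok' if available else 'missing'}")
```

A missing optional entry such as `faster_whisper` only matters when the local Whisper engine is used.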
### Initial setup
1. Install `ffmpeg` if needed (required for pydub to work).
2. To use a local Whisper model, call `transcription_whisper.downloadWhisperWeight(root, weight_type, callback, end_callback)` to fetch the model.
- `callback(progress: float)` reports progress as a value from 0.0 to 1.0.
- Example:
```python
from models.transcription import transcription_whisper as tw
tw.downloadWhisperWeight("./", "tiny", callback=lambda p: print(f"{p*100:.1f}%"), end_callback=lambda: print("done"))
```
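Because the progress callback may fire once per downloaded chunk, a UI might prefer to react only at coarse steps. The sketch below wraps any reporting function so that it is called only when a new step is reached. It relies only on the documented contract that the callback receives a float from 0.0 to 1.0; `make_throttled_progress` and its parameters are hypothetical names, not part of the module.

```python
from typing import Callable, List


def make_throttled_progress(
    report: Callable[[int], None], step_percent: int = 10
) -> Callable[[float], None]:
    """Return a callback(progress) that calls `report` only when a new step is reached."""
    last_bucket = -1

    def on_progress(progress: float) -> None:
        nonlocal last_bucket
        # Clamp defensively; the documented range is 0.0 to 1.0.
        percent = int(max(0.0, min(1.0, progress)) * 100)
        bucket = percent // step_percent
        if bucket > last_bucket:
            last_bucket = bucket
            report(bucket * step_percent)

    return on_progress


# Simulate a sequence of progress notifications.
seen: List[int] = []
callback = make_throttled_progress(seen.append, step_percent=25)
for value in (0.0, 0.1, 0.26, 0.3, 0.5, 0.99, 1.0):
    callback(value)
print(seen)  # [0, 25, 50, 75, 100]
```

The returned function could then be passed as the `callback` argument of `downloadWhisperWeight`.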
### API usage example
Basic usage of `AudioTranscriber`:
```python
from models.transcription.transcription_transcriber import AudioTranscriber
# source is an audio source object provided by the library
tr = AudioTranscriber(speaker=False, source=source, phrase_timeout=3, max_phrases=10, transcription_engine="Google")
# audio_queue is the queue that the recording thread pushes into
tr.transcribeAudioQueue(audio_queue, languages=["English"], countries=["United States"])
```
For return values and error-handling rules, see each function's docstring.
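In this commit, `transcribeAudioQueue` returns `False` when the queue is empty and `True` after a result has been stored, and `getTranscript` returns a dict with `confidence`, `text`, and `language`. A consumer loop built on that contract might look like the sketch below. `StubTranscriber` and `drain` are hypothetical stand-ins so the example runs without audio hardware; they are not the real classes.

```python
from datetime import datetime
from queue import Queue
from typing import Any, Dict, List


class StubTranscriber:
    """Hypothetical stand-in exposing the same two methods as AudioTranscriber."""

    def __init__(self) -> None:
        self.transcript_data: List[Dict[str, Any]] = []

    def transcribeAudioQueue(self, audio_queue: Queue, languages: List[str], countries: List[str]) -> bool:
        if audio_queue.empty():
            return False
        audio, _time_spoken = audio_queue.get()
        # A real engine call would happen here; the stub just reports the byte count.
        self.transcript_data.append(
            {"confidence": 1.0, "text": f"{len(audio)} bytes", "language": languages[0]}
        )
        return True

    def getTranscript(self) -> Dict[str, Any]:
        if self.transcript_data:
            return self.transcript_data.pop(-1)
        return {"confidence": 0, "text": "", "language": None}


def drain(transcriber: Any, audio_queue: Queue, languages: List[str], countries: List[str]) -> List[str]:
    """Process queued audio until the queue is empty and collect non-empty texts."""
    texts: List[str] = []
    while transcriber.transcribeAudioQueue(audio_queue, languages, countries):
        result = transcriber.getTranscript()
        if result["text"]:
            texts.append(result["text"])
    return texts


audio_queue: Queue = Queue()
audio_queue.put((b"\x00" * 320, datetime.now()))
audio_queue.put((b"\x00" * 640, datetime.now()))
print(drain(StubTranscriber(), audio_queue, ["English"], ["United States"]))
```

In a long-running application the loop would typically run in its own thread and keep polling instead of stopping at the first empty queue.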
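### Testing policy
- `AudioTranscriber` and the `Whisper` wrapper are verified in unit tests using mocks.
- Recommended: use `pytest` and `unittest.mock` to cover the following cases:
- Happy path: successful Google/Whisper runs (mocks return the expected text)
- Edge cases: silence, low confidence, multiple languages
- Fallback: behavior when Whisper is unavailable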
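As an illustration of the fallback case, the sketch below tests the engine-selection rule with `unittest.mock`. The constructor in this commit switches to Whisper only when the requested engine is `"Whisper"` and `checkWhisperWeight(...)` returns `True`; otherwise it stays on Google. `choose_engine` is a hypothetical pure-function reduction of that rule, written so the tests run without the real dependencies.

```python
from typing import Callable
from unittest.mock import Mock


def choose_engine(requested: str, check_weight: Callable[[str, str], bool], root: str, weight_type: str) -> str:
    """Hypothetical reduction of the constructor's engine-selection rule."""
    if requested == "Whisper" and check_weight(root, weight_type) is True:
        return "Whisper"
    return "Google"


def test_falls_back_to_google_when_weights_missing() -> None:
    check = Mock(return_value=False)
    assert choose_engine("Whisper", check, "./", "tiny") == "Google"
    check.assert_called_once_with("./", "tiny")


def test_uses_whisper_when_weights_present() -> None:
    check = Mock(return_value=True)
    assert choose_engine("Whisper", check, "./", "tiny") == "Whisper"


def test_google_request_never_checks_weights() -> None:
    check = Mock(return_value=True)
    assert choose_engine("Google", check, "./", "tiny") == "Google"
    check.assert_not_called()
```

Against the real class, the same idea would use `unittest.mock.patch` on the `checkWhisperWeight` name imported by the transcriber module; the exact patch target depends on how that module imports it.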
### Troubleshooting
- ffmpeg not found: `pydub` emits a warning. Install ffmpeg for your OS.
- VRAM error when loading Whisper: `getWhisperModel` detects insufficient VRAM and raises `ValueError("VRAM_OUT_OF_MEMORY", message)`. Adjust the device setting or `compute_type`.
- Hash mismatch or download failure: delete the cache or the weights directory and download again.
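The last two items can be handled in code roughly as follows. This is a sketch under stated assumptions: `load_with_cpu_fallback`, `reset_weights`, and `fake_loader` are hypothetical names. `fake_loader` stands in for `getWhisperModel` so the example runs without a GPU, and the weights path follows the `weights/whisper/<weight_type>` layout used by `transcription_whisper.py`.

```python
import os
import shutil
import tempfile
from typing import Any, Callable, Tuple


def load_with_cpu_fallback(
    load_model: Callable[..., Any], root: str, weight_type: str, device: str = "cuda"
) -> Tuple[Any, str]:
    """Try `device` first; retry on CPU only for the documented VRAM signal."""
    try:
        return load_model(root, weight_type, device=device), device
    except ValueError as error:
        is_vram_error = bool(error.args) and error.args[0] == "VRAM_OUT_OF_MEMORY"
        if is_vram_error and device != "cpu":
            return load_model(root, weight_type, device="cpu"), "cpu"
        raise


def reset_weights(root: str, weight_type: str) -> str:
    """Delete one model's weights directory so the next download starts clean."""
    path = os.path.join(root, "weights", "whisper", weight_type)
    shutil.rmtree(path, ignore_errors=True)
    return path


def fake_loader(root: str, weight_type: str, device: str = "cpu") -> str:
    """Hypothetical stand-in for getWhisperModel that always fails on "cuda"."""
    if device == "cuda":
        raise ValueError("VRAM_OUT_OF_MEMORY", "CUDA out of memory")
    return f"{weight_type} model on {device}"


model, used_device = load_with_cpu_fallback(fake_loader, "./", "tiny")
print(model, "| device:", used_device)

with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "weights", "whisper", "tiny"))
    removed = reset_weights(root, "tiny")
    print(os.path.exists(removed))  # False
```

Falling back to CPU trades speed for reliability; choosing a smaller `weight_type` or a lighter `compute_type` are alternatives when GPU inference is required.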
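### Changelog
- 2025-10-09: Added type annotations and docstrings; documented the download and callback behavior.
---
This document is a concise reference. More detailed procedures (how to collect logs, example ffmpeg installation steps, and so on) can be added if needed.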
# transcription - Transcription module
Overview: A set of modules that record microphone/speaker audio and transcribe it using engines such as Whisper and Google. The main classes are the Recorder classes for recording and `AudioTranscriber`.


@@ -1,3 +1,8 @@
"""Language table used by transcription components.
Maps a display language and country to engine-specific language codes.
"""
transcription_lang = { transcription_lang = {
"Afrikaans":{ "Afrikaans":{
"South Africa":{ "South Africa":{


@@ -1,9 +1,18 @@
"""Recorders that wrap speech_recognition microphone interfaces.
These classes provide small adapters that push raw audio bytes into queues.
They intentionally keep a thin API so the rest of the system can mock them
in tests.
"""
from typing import Any
from speech_recognition import Recognizer, Microphone from speech_recognition import Recognizer, Microphone
from pyaudiowpatch import get_sample_size, paInt16 from pyaudiowpatch import get_sample_size, paInt16
from datetime import datetime from datetime import datetime
class BaseRecorder: class BaseRecorder:
def __init__(self, source, energy_threshold, dynamic_energy_threshold, record_timeout): def __init__(self, source: Any, energy_threshold: int, dynamic_energy_threshold: bool, record_timeout: int) -> None:
self.recorder = Recognizer() self.recorder = Recognizer()
self.recorder.energy_threshold = energy_threshold self.recorder.energy_threshold = energy_threshold
self.recorder.dynamic_energy_threshold = dynamic_energy_threshold self.recorder.dynamic_energy_threshold = dynamic_energy_threshold
@@ -15,27 +24,29 @@ class BaseRecorder:
self.source = source self.source = source
def adjustForNoise(self): def adjustForNoise(self) -> None:
with self.source: with self.source:
self.recorder.adjust_for_ambient_noise(self.source) self.recorder.adjust_for_ambient_noise(self.source)
def recordIntoQueue(self, audio_queue): def recordIntoQueue(self, audio_queue: Any) -> None:
def record_callback(_, audio): def record_callback(_, audio):
audio_queue.put((audio.get_raw_data(), datetime.now())) audio_queue.put((audio.get_raw_data(), datetime.now()))
self.stop, self.pause, self.resume = self.recorder.listen_in_background(self.source, record_callback, phrase_time_limit=self.record_timeout) self.stop, self.pause, self.resume = self.recorder.listen_in_background(self.source, record_callback, phrase_time_limit=self.record_timeout)
class SelectedMicRecorder(BaseRecorder): class SelectedMicRecorder(BaseRecorder):
def __init__(self, device, energy_threshold, dynamic_energy_threshold, record_timeout): def __init__(self, device: dict, energy_threshold: int, dynamic_energy_threshold: bool, record_timeout: int) -> None:
source=Microphone( source = Microphone(
device_index=device['index'], device_index=device['index'],
sample_rate=int(device["defaultSampleRate"]), sample_rate=int(device["defaultSampleRate"]),
) )
super().__init__(source=source, energy_threshold=energy_threshold, dynamic_energy_threshold=dynamic_energy_threshold, record_timeout=record_timeout) super().__init__(source=source, energy_threshold=energy_threshold, dynamic_energy_threshold=dynamic_energy_threshold, record_timeout=record_timeout)
# self.adjustForNoise() # self.adjustForNoise()
class SelectedSpeakerRecorder(BaseRecorder): class SelectedSpeakerRecorder(BaseRecorder):
def __init__(self, device, energy_threshold, dynamic_energy_threshold, record_timeout): def __init__(self, device: dict, energy_threshold: int, dynamic_energy_threshold: bool, record_timeout: int) -> None:
source = Microphone(speaker=True, source = Microphone(speaker=True,
device_index= device["index"], device_index= device["index"],
@@ -47,7 +58,7 @@ class SelectedSpeakerRecorder(BaseRecorder):
# self.adjustForNoise() # self.adjustForNoise()
class BaseEnergyRecorder: class BaseEnergyRecorder:
def __init__(self, source): def __init__(self, source: Any) -> None:
self.recorder = Recognizer() self.recorder = Recognizer()
self.recorder.energy_threshold = 0 self.recorder.energy_threshold = 0
self.recorder.dynamic_energy_threshold = False self.recorder.dynamic_energy_threshold = False
@@ -59,27 +70,29 @@ class BaseEnergyRecorder:
self.source = source self.source = source
def adjustForNoise(self): def adjustForNoise(self) -> None:
with self.source: with self.source:
self.recorder.adjust_for_ambient_noise(self.source) self.recorder.adjust_for_ambient_noise(self.source)
def recordIntoQueue(self, energy_queue): def recordIntoQueue(self, energy_queue: Any) -> None:
def recordCallback(_, energy): def recordCallback(_, energy):
energy_queue.put(energy) energy_queue.put(energy)
self.stop, self.pause, self.resume = self.recorder.listen_energy_in_background(self.source, recordCallback) self.stop, self.pause, self.resume = self.recorder.listen_energy_in_background(self.source, recordCallback)
class SelectedMicEnergyRecorder(BaseEnergyRecorder): class SelectedMicEnergyRecorder(BaseEnergyRecorder):
def __init__(self, device): def __init__(self, device: dict) -> None:
source=Microphone( source = Microphone(
device_index=device['index'], device_index=device['index'],
sample_rate=int(device["defaultSampleRate"]), sample_rate=int(device["defaultSampleRate"]),
) )
super().__init__(source=source) super().__init__(source=source)
# self.adjustForNoise() # self.adjustForNoise()
class SelectedSpeakerEnergyRecorder(BaseEnergyRecorder): class SelectedSpeakerEnergyRecorder(BaseEnergyRecorder):
def __init__(self, device): def __init__(self, device: dict) -> None:
source = Microphone(speaker=True, source = Microphone(speaker=True,
device_index= device["index"], device_index= device["index"],
@@ -90,7 +103,15 @@ class SelectedSpeakerEnergyRecorder(BaseEnergyRecorder):
# self.adjustForNoise() # self.adjustForNoise()
class BaseEnergyAndAudioRecorder: class BaseEnergyAndAudioRecorder:
def __init__(self, source, energy_threshold, dynamic_energy_threshold, phrase_time_limit, phrase_timeout, record_timeout): def __init__(
self,
source: Any,
energy_threshold: int,
dynamic_energy_threshold: bool,
phrase_time_limit: int,
phrase_timeout: int,
record_timeout: int,
) -> None:
self.recorder = Recognizer() self.recorder = Recognizer()
self.recorder.energy_threshold = energy_threshold self.recorder.energy_threshold = energy_threshold
self.recorder.dynamic_energy_threshold = dynamic_energy_threshold self.recorder.dynamic_energy_threshold = dynamic_energy_threshold
@@ -104,11 +125,11 @@ class BaseEnergyAndAudioRecorder:
self.source = source self.source = source
def adjustForNoise(self): def adjustForNoise(self) -> None:
with self.source: with self.source:
self.recorder.adjust_for_ambient_noise(self.source) self.recorder.adjust_for_ambient_noise(self.source)
def recordIntoQueue(self, audio_queue, energy_queue=None): def recordIntoQueue(self, audio_queue: Any, energy_queue: Any = None) -> None:
def audioRecordCallback(_, audio): def audioRecordCallback(_, audio):
audio_queue.put((audio.get_raw_data(), datetime.now())) audio_queue.put((audio.get_raw_data(), datetime.now()))
@@ -121,11 +142,21 @@ class BaseEnergyAndAudioRecorder:
phrase_time_limit=self.phrase_time_limit, phrase_time_limit=self.phrase_time_limit,
callback_energy=energyRecordCallback if energy_queue is not None else None, callback_energy=energyRecordCallback if energy_queue is not None else None,
phrase_timeout=self.phrase_timeout, phrase_timeout=self.phrase_timeout,
record_timeout=self.record_timeout) record_timeout=self.record_timeout,
)
class SelectedMicEnergyAndAudioRecorder(BaseEnergyAndAudioRecorder): class SelectedMicEnergyAndAudioRecorder(BaseEnergyAndAudioRecorder):
def __init__(self, device, energy_threshold, dynamic_energy_threshold, phrase_time_limit, phrase_timeout:int=1, record_timeout:int=5): def __init__(
source=Microphone( self,
device: dict,
energy_threshold: int,
dynamic_energy_threshold: bool,
phrase_time_limit: int,
phrase_timeout: int = 1,
record_timeout: int = 5,
) -> None:
source = Microphone(
device_index=device['index'], device_index=device['index'],
sample_rate=int(device["defaultSampleRate"]), sample_rate=int(device["defaultSampleRate"]),
) )
@@ -139,14 +170,23 @@ class SelectedMicEnergyAndAudioRecorder(BaseEnergyAndAudioRecorder):
) )
# self.adjustForNoise() # self.adjustForNoise()
class SelectedSpeakerEnergyAndAudioRecorder(BaseEnergyAndAudioRecorder): class SelectedSpeakerEnergyAndAudioRecorder(BaseEnergyAndAudioRecorder):
def __init__(self, device, energy_threshold, dynamic_energy_threshold, phrase_time_limit, phrase_timeout:int=1, record_timeout:int=5): def __init__(
self,
device: dict,
energy_threshold: int,
dynamic_energy_threshold: bool,
phrase_time_limit: int,
phrase_timeout: int = 1,
record_timeout: int = 5,
) -> None:
source = Microphone(speaker=True, source = Microphone(speaker=True,
device_index= device["index"], device_index= device["index"],
sample_rate=int(device["defaultSampleRate"]), sample_rate=int(device["defaultSampleRate"]),
chunk_size=get_sample_size(paInt16), chunk_size=get_sample_size(paInt16),
channels=device["maxInputChannels"] channels=device["maxInputChannels"],
) )
super().__init__( super().__init__(
source=source, source=source,


@@ -1,7 +1,14 @@
"""Runtime transcriber that wraps Google SpeechRecognition and faster-whisper.
This class focuses on converting incoming raw audio buffers into text using
either the Google web recognizer (online) or a local Whisper model (offline).
"""
import time import time
from io import BytesIO from io import BytesIO
from threading import Event from threading import Event
import wave import wave
from typing import Any, Callable, Dict, List, Optional, Tuple
from speech_recognition import Recognizer, AudioData, AudioFile from speech_recognition import Recognizer, AudioData, AudioFile
from speech_recognition.exceptions import UnknownValueError from speech_recognition.exceptions import UnknownValueError
from datetime import timedelta from datetime import timedelta
@@ -20,38 +27,71 @@ warnings.simplefilter('ignore', RuntimeWarning)
PHRASE_TIMEOUT = 3 PHRASE_TIMEOUT = 3
MAX_PHRASES = 10 MAX_PHRASES = 10
class AudioTranscriber: class AudioTranscriber:
def __init__(self, speaker, source, phrase_timeout, max_phrases, transcription_engine, root=None, whisper_weight_type=None, device="cpu", device_index=0, compute_type="auto"): """Convert queued audio buffers into transcripts.
Public attributes set by the constructor:
- speaker: bool
- phrase_timeout: int
- max_phrases: int
Methods are intentionally permissive about input types to match the
existing codebase; this wrapper adds typing for clarity.
"""
def __init__(
self,
speaker: bool,
source: Any,
phrase_timeout: int,
max_phrases: int,
transcription_engine: str,
root: Optional[str] = None,
whisper_weight_type: Optional[str] = None,
device: str = "cpu",
device_index: int = 0,
compute_type: str = "auto",
) -> None:
self.speaker = speaker self.speaker = speaker
self.phrase_timeout = phrase_timeout self.phrase_timeout = phrase_timeout
self.max_phrases = max_phrases self.max_phrases = max_phrases
self.transcript_data = [] self.transcript_data: List[Dict[str, Any]] = []
self.transcript_changed_event = Event() self.transcript_changed_event = Event()
self.audio_recognizer = Recognizer() self.audio_recognizer = Recognizer()
self.transcription_engine = "Google" self.transcription_engine = "Google"
self.whisper_model = None self.whisper_model = None
self.audio_sources = { self.audio_sources: Dict[str, Any] = {
"sample_rate": source.SAMPLE_RATE, "sample_rate": source.SAMPLE_RATE,
"sample_width": source.SAMPLE_WIDTH, "sample_width": source.SAMPLE_WIDTH,
"channels": source.channels, "channels": source.channels,
"last_sample": bytes(), "last_sample": bytes(),
"last_spoken": None, "last_spoken": None,
"new_phrase": True, "new_phrase": True,
"process_data_func": self.processSpeakerData if speaker else self.processSpeakerData "process_data_func": self.processSpeakerData if speaker else self.processSpeakerData,
} }
if transcription_engine == "Whisper" and checkWhisperWeight(root, whisper_weight_type) is True: if transcription_engine == "Whisper" and checkWhisperWeight(root, whisper_weight_type) is True:
self.whisper_model = getWhisperModel(root, whisper_weight_type, device=device, device_index=device_index, compute_type=compute_type) self.whisper_model = getWhisperModel(
root, whisper_weight_type, device=device, device_index=device_index, compute_type=compute_type
)
self.transcription_engine = "Whisper" self.transcription_engine = "Whisper"
def transcribeAudioQueue(self, audio_queue, languages, countries, avg_logprob=-0.8, no_speech_prob=0.6): def transcribeAudioQueue(
self,
audio_queue: Any,
languages: List[str],
countries: List[str],
avg_logprob: float = -0.8,
no_speech_prob: float = 0.6,
) -> bool:
if audio_queue.empty(): if audio_queue.empty():
time.sleep(0.01) time.sleep(0.01)
return False return False
audio, time_spoken = audio_queue.get() audio, time_spoken = audio_queue.get()
self.updateLastSampleAndPhraseStatus(audio, time_spoken) self.updateLastSampleAndPhraseStatus(audio, time_spoken)
confidences = [{"confidence": 0, "text": "", "language": None}] confidences: List[Dict[str, Any]] = [{"confidence": 0, "text": "", "language": None}]
try: try:
audio_data = self.audio_sources["process_data_func"]() audio_data = self.audio_sources["process_data_func"]()
match self.transcription_engine: match self.transcription_engine:
@@ -67,13 +107,19 @@ class AudioTranscriber:
except Exception: except Exception:
pass pass
case "Whisper": case "Whisper":
audio_data = np.frombuffer(audio_data.get_raw_data(convert_rate=16000, convert_width=2), np.int16).flatten().astype(np.float32) / 32768.0 audio_data = np.frombuffer(
audio_data.get_raw_data(convert_rate=16000, convert_width=2), np.int16
).flatten().astype(np.float32) / 32768.0
if isinstance(audio_data, torch.Tensor): if isinstance(audio_data, torch.Tensor):
audio_data = audio_data.detach().numpy() audio_data = audio_data.detach().numpy()
for language, country in zip(languages, countries): for language, country in zip(languages, countries):
text = "" text = ""
source_language = transcription_lang[language][country][self.transcription_engine] if len(languages) == 1 else None source_language = (
transcription_lang[language][country][self.transcription_engine]
if len(languages) == 1
else None
)
segments, info = self.whisper_model.transcribe( segments, info = self.whisper_model.transcribe(
audio_data, audio_data,
beam_size=5, beam_size=5,
@@ -91,7 +137,9 @@ class AudioTranscriber:
continue continue
text += s.text text += s.text
confidences.append({"confidence": info.language_probability, "text": text, "language": language}) confidences.append({"confidence": info.language_probability, "text": text, "language": language})
if (len(languages) == 1) or (transcription_lang[language][country][self.transcription_engine] == info.language): if (len(languages) == 1) or (
transcription_lang[language][country][self.transcription_engine] == info.language
):
break break
except UnknownValueError: except UnknownValueError:
@@ -106,7 +154,7 @@ class AudioTranscriber:
self.updateTranscript(result) self.updateTranscript(result)
return True return True
def updateLastSampleAndPhraseStatus(self, data, time_spoken): def updateLastSampleAndPhraseStatus(self, data: bytes, time_spoken) -> None:
source_info = self.audio_sources source_info = self.audio_sources
if source_info["last_spoken"] and time_spoken - source_info["last_spoken"] > timedelta(seconds=self.phrase_timeout): if source_info["last_spoken"] and time_spoken - source_info["last_spoken"] > timedelta(seconds=self.phrase_timeout):
source_info["last_sample"] = bytes() source_info["last_sample"] = bytes()
@@ -117,11 +165,13 @@ class AudioTranscriber:
source_info["last_sample"] += data source_info["last_sample"] += data
source_info["last_spoken"] = time_spoken source_info["last_spoken"] = time_spoken
def processMicData(self): def processMicData(self) -> AudioData:
audio_data = AudioData(self.audio_sources["last_sample"], self.audio_sources["sample_rate"], self.audio_sources["sample_width"]) audio_data = AudioData(
self.audio_sources["last_sample"], self.audio_sources["sample_rate"], self.audio_sources["sample_width"]
)
return audio_data return audio_data
def processSpeakerData(self): def processSpeakerData(self) -> AudioData:
temp_file = BytesIO() temp_file = BytesIO()
with wave.open(temp_file, 'wb') as wf: with wave.open(temp_file, 'wb') as wf:
wf.setnchannels(self.audio_sources["channels"]) wf.setnchannels(self.audio_sources["channels"])
@@ -141,7 +191,7 @@ class AudioTranscriber:
audio = self.audio_recognizer.record(source) audio = self.audio_recognizer.record(source)
return audio return audio
def updateTranscript(self, result): def updateTranscript(self, result: dict) -> None:
source_info = self.audio_sources source_info = self.audio_sources
transcript = self.transcript_data transcript = self.transcript_data
@@ -152,14 +202,14 @@ class AudioTranscriber:
else: else:
transcript[0] = result transcript[0] = result
def getTranscript(self): def getTranscript(self) -> dict:
if len(self.transcript_data) > 0: if len(self.transcript_data) > 0:
result = self.transcript_data.pop(-1) result = self.transcript_data.pop(-1)
else: else:
result = {"confidence": 0, "text": "", "language": None} result = {"confidence": 0, "text": "", "language": None}
return result return result
def clearTranscriptData(self): def clearTranscriptData(self) -> None:
self.transcript_data.clear() self.transcript_data.clear()
self.audio_sources["last_sample"] = bytes() self.audio_sources["last_sample"] = bytes()
self.audio_sources["new_phrase"] = True self.audio_sources["new_phrase"] = True


@@ -1,6 +1,17 @@
"""Helpers for downloading and loading Whisper (faster-whisper) models.
This module exposes small utilities used by the transcription subsystem:
- downloadFile: stream-download a file with optional progress callback
- checkWhisperWeight: quick local availability check
- downloadWhisperWeight: download model artifacts from HF hub
- getWhisperModel: construct and return a WhisperModel instance
The functions are defensive: failures are caught and reported by the caller.
"""
from os import path as os_path, makedirs as os_makedirs from os import path as os_path, makedirs as os_makedirs
from requests import get as requests_get from requests import get as requests_get
from typing import Callable from typing import Callable, Optional
import huggingface_hub import huggingface_hub
from faster_whisper import WhisperModel from faster_whisper import WhisperModel
import logging import logging
@@ -30,24 +41,36 @@ _FILENAMES = [
"vocabulary.json", "vocabulary.json",
] ]
def downloadFile(url, path, func=None): def downloadFile(url: str, path: str, func: Optional[Callable[[float], None]] = None) -> None:
"""Download a file from `url` to `path`.
Args:
url: remote URL to download from
path: local filepath to write
func: optional callback(progress: float) called with a 0.0-1.0 progress
"""
try: try:
res = requests_get(url, stream=True) res = requests_get(url, stream=True)
res.raise_for_status() res.raise_for_status()
file_size = int(res.headers.get('content-length', 0)) file_size = int(res.headers.get('content-length', 0))
total_chunk = 0 total_chunk = 0
with open(os_path.join(path), 'wb') as file: with open(os_path.join(path), 'wb') as file:
for chunk in res.iter_content(chunk_size=1024*2000): for chunk in res.iter_content(chunk_size=1024 * 2000):
file.write(chunk) file.write(chunk)
if isinstance(func, Callable): if callable(func) and file_size:
total_chunk += len(chunk) total_chunk += len(chunk)
func(total_chunk/file_size) func(total_chunk / file_size)
except Exception: except Exception:
# Silent failure here; caller may re-check or log
pass pass
def checkWhisperWeight(root, weight_type): def checkWhisperWeight(root: str, weight_type: str) -> bool:
"""Return True if a Whisper model for `weight_type` is loadable from disk.
This attempts to construct a local `WhisperModel` with local_files_only=True
to verify required files exist and are compatible.
"""
path = os_path.join(root, "weights", "whisper", weight_type) path = os_path.join(root, "weights", "whisper", weight_type)
result = False
try: try:
WhisperModel( WhisperModel(
path, path,
@@ -58,23 +81,47 @@ def checkWhisperWeight(root, weight_type):
num_workers=1, num_workers=1,
local_files_only=True, local_files_only=True,
) )
result = True return True
except Exception: except Exception:
pass return False
return result
def downloadWhisperWeight(root, weight_type, callback=None, end_callback=None): def downloadWhisperWeight(
root: str,
weight_type: str,
callback: Optional[Callable[[float], None]] = None,
end_callback: Optional[Callable[[], None]] = None,
) -> None:
"""Ensure Whisper weight files are present locally; download them if missing.
Args:
root: project root where `weights/whisper` lives
weight_type: key from `_MODELS` (eg. "tiny", "base")
callback: progress callback for the main model file
end_callback: called when download completes
"""
path = os_path.join(root, "weights", "whisper", weight_type) path = os_path.join(root, "weights", "whisper", weight_type)
os_makedirs(path, exist_ok=True) os_makedirs(path, exist_ok=True)
if checkWhisperWeight(root, weight_type) is False: if not checkWhisperWeight(root, weight_type):
for filename in _FILENAMES: for filename in _FILENAMES:
file_path = os_path.join(path, filename) file_path = os_path.join(path, filename)
url = huggingface_hub.hf_hub_url(_MODELS[weight_type], filename) url = huggingface_hub.hf_hub_url(_MODELS[weight_type], filename)
downloadFile(url, file_path, func=callback if filename == "model.bin" else None) downloadFile(url, file_path, func=callback if filename == "model.bin" else None)
if isinstance(end_callback, Callable): if callable(end_callback):
end_callback() end_callback()
def getWhisperModel(root, weight_type, device="cpu", device_index=0, compute_type="auto"): def getWhisperModel(
root: str,
weight_type: str,
device: str = "cpu",
device_index: int = 0,
compute_type: str = "auto",
) -> WhisperModel:
"""Return a `WhisperModel` instance loaded from local weights.
Raises:
ValueError: when VRAM shortage is detected (wrapped from RuntimeError)
Exception: other loading errors are propagated.
"""
path = os_path.join(root, "weights", "whisper", weight_type) path = os_path.join(root, "weights", "whisper", weight_type)
if compute_type == "auto": if compute_type == "auto":
compute_type = getBestComputeType(device, device_index) compute_type = getBestComputeType(device, device_index)
@@ -90,11 +137,10 @@ def getWhisperModel(root, weight_type, device="cpu", device_index=0, compute_typ
) )
return model return model
except RuntimeError as e: except RuntimeError as e:
# Detect VRAM shortage errors # Detect VRAM out-of-memory-like errors and raise a clear ValueError
error_message = str(e) error_message = str(e)
if "CUDA out of memory" in error_message or "CUBLAS_STATUS_ALLOC_FAILED" in error_message: if "CUDA out of memory" in error_message or "CUBLAS_STATUS_ALLOC_FAILED" in error_message:
raise ValueError("VRAM_OUT_OF_MEMORY", error_message) raise ValueError("VRAM_OUT_OF_MEMORY", error_message)
# Re-raise any other error as usual
raise raise
if __name__ == "__main__": if __name__ == "__main__":