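Guard the torch and ctranslate2 imports and provide safe defaults. Add type annotations and docstrings to improve readability. Add a check that prevents duplicate log handlers from being attached. Change encodeBase64 to return an empty dict when decoding fails. Add exception protection to getComputeDeviceList so that it still returns the CPU entry when GPU information cannot be retrieved.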
@@ -1,3 +1,61 @@
## utils module (src-python/utils.py)

This document summarizes the recent refactoring of `src-python/utils.py`, its public API, usage notes, and how to test it.

### Overview
- `utils.py` provides general-purpose utilities used across the project. Main contents:
  - Network connectivity check (`isConnectedNetwork`)
  - Free-port check for the socket server (`isAvailableWebSocketServer`)
  - IP address validation (`isValidIpAddress`)
  - Compute device listing (`getComputeDeviceList` / `getBestComputeType`)
  - Base64 decoding of JSON (`encodeBase64`; despite the name, it decodes)
  - Logger setup and log output helpers (`setupLogger`, `printLog`, `printResponse`, `errorLogging`)

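### This refactor (key points)
- Fallback for optional dependencies: the `torch` and `ctranslate2` imports are guarded so the module still works when they are not installed, and safe defaults are returned instead.
- Added type annotations and docstrings to improve readability.
- Added a check that prevents duplicate handlers from being attached during logger setup.
- `encodeBase64` now returns an empty dict instead of raising when decoding fails (the safer default).
- `getComputeDeviceList` is wrapped in exception handling so that it still returns the CPU entry when querying GPU information fails.
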
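
The import guard described in the first bullet can be sketched roughly as follows. This is an illustration of the pattern rather than the module's exact code: `optional_import` is a hypothetical helper introduced here, and `no_such_module_for_demo` is a deliberately nonexistent module name used only to exercise the fallback path.

```python
import importlib
from types import ModuleType
from typing import List, Optional


def optional_import(name: str) -> Optional[ModuleType]:
    """Return the imported module, or None if it cannot be imported."""
    try:
        return importlib.import_module(name)
    except Exception:
        # Broad catch, as in the refactor: a broken install should not crash us.
        return None


torch = optional_import("torch")        # None when torch is not installed
ct2 = optional_import("ctranslate2")    # None when ctranslate2 is not installed

if ct2 is not None and hasattr(ct2, "get_supported_compute_types"):
    get_supported_compute_types = ct2.get_supported_compute_types
else:
    def get_supported_compute_types(device: str, device_index: int = 0) -> List[str]:
        """Safe stub used when ctranslate2 is unavailable."""
        return []


# A module that does not exist takes the fallback path.
print(optional_import("no_such_module_for_demo"))  # None
```

With this shape, code further down can test `torch is not None` before touching `torch.cuda`, and can call `get_supported_compute_types` unconditionally.
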
### Important usage notes (breaking/behavior changes)
- Optional dependencies
  - Without `torch`, GPU information cannot be retrieved (`getComputeDeviceList` returns only the CPU entry).
  - If `get_supported_compute_types` cannot be imported from `ctranslate2`, a stub that returns an empty list is used instead.
  - Because this behavior depends on the environment, callers should check for availability and implement their own fallbacks.

- `encodeBase64` behavior
  - Invalid base64/JSON input no longer raises; the function returns `{}`. Take care if existing code expects an exception.

- `isAvailableWebSocketServer` semantics
  - Returns True if binding to the given host:port succeeds. This is the inverse of an "is the port in use?" check, so take care when reading call sites.

- Logging
  - `setupLogger` does not add a second handler for the same log file. `errorLogging()` falls back to printing the traceback to stdout if writing to the log fails.

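
The bind-based check and its inverted reading can be illustrated with a small standalone sketch. The names `is_port_bindable` and `is_port_in_use` are hypothetical and are not part of `utils.py`; the probe mirrors the approach shown in the diff (set `SO_REUSEADDR`, then try to bind). `SO_REUSEADDR` semantics differ by platform, so the result on Windows may differ from Linux.

```python
import socket


def is_port_bindable(host: str, port: int) -> bool:
    """Return True when a TCP socket can be bound to (host, port)."""
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            probe.bind((host, port))
            return True
    except (OSError, OverflowError):
        # OSError: address in use or invalid host; OverflowError: port out of range.
        return False


def is_port_in_use(host: str, port: int) -> bool:
    """Inverse reading of the bind check: True when the bind fails."""
    return not is_port_bindable(host, port)


# Occupy an ephemeral port with a listening socket, then probe it.
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
    server.bind(("127.0.0.1", 0))  # port 0 lets the OS pick a free port
    server.listen(1)
    busy_port = server.getsockname()[1]
    print("bindable:", is_port_bindable("127.0.0.1", busy_port))
    print("in use:  ", is_port_in_use("127.0.0.1", busy_port))
```

Giving the inverse its own name at the call site is one way to avoid misreading the boolean.
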
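The duplicate-handler guard from the logging note can be sketched as below. `setup_logger_once` is a hypothetical stand-in for `setupLogger`, and the `maxBytes` and `backupCount` values are assumptions because the real values are not visible in this diff. Unlike the code in the diff, this variant checks before constructing the handler, so no second file handle is opened.

```python
import logging
import os
import tempfile
from logging.handlers import RotatingFileHandler


def setup_logger_once(name: str, log_file: str, level: int = logging.INFO) -> logging.Logger:
    """Attach at most one RotatingFileHandler per (logger, file) pair."""
    logger = logging.getLogger(name)
    logger.setLevel(level)
    target = os.path.abspath(log_file)  # FileHandler stores an absolute baseFilename
    already_attached = any(
        isinstance(h, RotatingFileHandler) and h.baseFilename == target
        for h in logger.handlers
    )
    if not already_attached:
        # maxBytes/backupCount are illustrative values, not the module's settings.
        handler = RotatingFileHandler(log_file, maxBytes=1_000_000, backupCount=1, encoding="utf-8")
        handler.setFormatter(
            logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        )
        logger.addHandler(handler)
    return logger


log_path = os.path.join(tempfile.mkdtemp(), "demo.log")
first = setup_logger_once("demo", log_path)
second = setup_logger_once("demo", log_path)
print(first is second, len(first.handlers))
```

Calling the setup function twice returns the same logger object with a single handler, so each record is written once.
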
### API usage (short example)

```python
from utils import getComputeDeviceList, encodeBase64, printResponse

# List the CPU entry plus any CUDA devices that could be detected
devices = getComputeDeviceList()
print(devices)

# Decode a base64-encoded JSON string
obj = encodeBase64('eyAia2V5IjogInZhbHVlIiB9')  # -> {'key': 'value'}

# Log and print a structured response as JSON
printResponse(200, '/health', {'status': 'ok'})
```

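
The example above shows only the success path of `encodeBase64`. The failure path can be sketched with a standalone function that follows the same approach; `decode_base64_json` is a hypothetical name used so the snippet runs without the project installed, and the caller-side default is purely illustrative.

```python
import base64
import json
from typing import Any, Dict


def decode_base64_json(data: str) -> Dict[str, Any]:
    """Decode base64-encoded JSON; return {} instead of raising on bad input."""
    try:
        return json.loads(base64.b64decode(data).decode("utf-8"))
    except Exception:
        return {}


payload = base64.b64encode(json.dumps({"key": "value"}).encode("utf-8")).decode("ascii")
print(decode_base64_json(payload))        # {'key': 'value'}
print(decode_base64_json("not base64!"))  # {}

# An empty dict is ambiguous: it may mean "decode failed" or "valid empty object".
settings = decode_base64_json("not base64!")
if not settings:
    settings = {"key": "default"}  # hypothetical caller-side default
print(settings)
```

Callers that previously relied on an exception to detect bad input need an explicit check like the one at the end.
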
### Testing approach
- To cover differences in optional dependencies, unit tests should mock `torch` and `ctranslate2`.
- Example: a test confirming that `getComputeDeviceList()` returns the CPU entry even in an environment without a GPU.

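
One way to write such tests is with `unittest.mock.patch.dict` on `sys.modules`. The sketch below uses a hypothetical `list_devices` function that imports `torch` lazily, so it is self-contained. The real module binds `torch` at import time, so tests against it would more likely patch the module-level `torch` name.

```python
import sys
from typing import Any, Dict, List
from unittest import mock


def list_devices() -> List[Dict[str, Any]]:
    """Hypothetical stand-in for getComputeDeviceList (imports torch lazily)."""
    devices: List[Dict[str, Any]] = [{"device": "cpu", "device_index": 0}]
    try:
        import torch
        if torch.cuda.is_available():
            for index in range(torch.cuda.device_count()):
                devices.append({
                    "device": "cuda",
                    "device_index": index,
                    "device_name": torch.cuda.get_device_name(index),
                })
    except Exception:
        pass  # fall back to the CPU entry only
    return devices


def test_cpu_only_when_torch_missing() -> None:
    # A None entry in sys.modules makes `import torch` raise ImportError.
    with mock.patch.dict(sys.modules, {"torch": None}):
        assert list_devices() == [{"device": "cpu", "device_index": 0}]


def test_gpu_listed_with_fake_torch() -> None:
    fake_torch = mock.MagicMock()
    fake_torch.cuda.is_available.return_value = True
    fake_torch.cuda.device_count.return_value = 1
    fake_torch.cuda.get_device_name.return_value = "Fake RTX"
    with mock.patch.dict(sys.modules, {"torch": fake_torch}):
        devices = list_devices()
    assert [d["device"] for d in devices] == ["cpu", "cuda"]
    assert devices[1]["device_name"] == "Fake RTX"


test_cpu_only_when_torch_missing()
test_gpu_listed_with_fake_torch()
print("ok")
```

Both tests give the same result whether or not `torch` is installed on the machine running them.
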
### Troubleshooting
- Log file write errors: check permissions and free disk space. Verify that `error.log` and `process.log` exist and are writable.
- If `getComputeDeviceList()` returns only the CPU entry, or entries with empty compute types, check that `torch` and `ctranslate2` are installed.

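
A quick way to check the first item is to try opening the log file for appending. `can_write_log` below is a hypothetical helper, not part of `utils.py`. Note that a successful check creates the file if it does not already exist.

```python
import os
import tempfile


def can_write_log(path: str) -> bool:
    """Return True if `path` can be opened for appending (creating it if needed)."""
    try:
        with open(path, "a", encoding="utf-8"):
            return True
    except OSError:
        # Covers missing directories, permission errors, and a full disk.
        return False


log_dir = tempfile.mkdtemp()
print(can_write_log(os.path.join(log_dir, "process.log")))             # True
print(can_write_log(os.path.join(log_dir, "missing_dir", "error.log")))  # False
```
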
### Changelog
- 2025-10-09: Added type annotations and docstrings, guarded optional imports, and hardened logging.

# utils.py: function list and usage examples

Purpose: provides common utilities (logging, JSON output, network/port checks, device and compute-type enumeration, validation, and so on).

@@ -1,12 +1,22 @@
 import base64
-from typing import Any, List, Dict
+from typing import Any, List, Dict, Optional
 import json
 import traceback
 import logging
 from logging.handlers import RotatingFileHandler
 
-import torch
-from ctranslate2 import get_supported_compute_types
+try:
+    import torch
+except Exception:
+    torch = None  # type: ignore
+
+try:
+    from ctranslate2 import get_supported_compute_types
+except Exception:
+    # Fallback: if ctranslate2 is not installed, provide a safe stub.
+    def get_supported_compute_types(device: str, device_index: int) -> List[str]:
+        return []
+
 import requests
 import ipaddress
 import socket
@@ -47,32 +57,32 @@ def validateDictStructure(data: dict, structure: dict) -> bool:
     return True
 
 def isConnectedNetwork(url="http://www.google.com", timeout=3) -> bool:
+    """Quick network connectivity check by requesting `url`.
+
+    Returns True when a 200 response is returned within `timeout` seconds.
+    """
     try:
         response = requests.get(url, timeout=timeout)
         return response.status_code == 200
     except requests.RequestException:
         return False
 
-def isAvailableWebSocketServer(host:str, port:int) -> bool:
-    """Check whether the WebSocket server port is in use"""
-    response = True
+def isAvailableWebSocketServer(host: str, port: int) -> bool:
+    """Return True if the given host/port appear available for binding.
+
+    Note: This attempts to bind a TCP socket to the address. If bind
+    succeeds the function returns True (meaning the address was available).
+    """
     try:
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as chk:
-            try:
-                # Set SO_REUSEADDR to allow the socket address to be reused
-                chk.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-                chk.bind((host, port))
-                # No need to enter the listening state before shutting down
-                chk.close()
-            except Exception:
-                response = False
+            chk.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+            chk.bind((host, port))
+            return True
     except Exception:
-        errorLogging()
-        response = False
-
-    return response
+        return False
 
 def isValidIpAddress(ip_address: str) -> bool:
+    """Return True if `ip_address` is a valid IPv4/IPv6 address."""
     try:
         ipaddress.ip_address(ip_address)
         return True
@@ -80,7 +90,12 @@ def isValidIpAddress(ip_address: str) -> bool:
         return False
 
 def getComputeDeviceList() -> List[Dict[str, Any]]:
-    compute_types = [
+    """Return a list of available compute devices and supported compute types.
+
+    The returned list contains dicts describing CPU and (if available)
+    CUDA devices. This function is defensive to missing optional packages.
+    """
+    compute_types: List[Dict[str, Any]] = [
         {
             "device": "cpu",
             "device_index": 0,
@@ -89,32 +104,47 @@ def getComputeDeviceList() -> List[Dict[str, Any]]:
         }
     ]
 
-    if torch.cuda.is_available():
-        for device_index in range(torch.cuda.device_count()):
-            gpu_device_name = torch.cuda.get_device_name(device_index)
-            gpu_compute_types = ["auto"] + list(get_supported_compute_types("cuda", device_index))
+    try:
+        if torch is not None and hasattr(torch, "cuda") and torch.cuda.is_available():
+            for device_index in range(torch.cuda.device_count()):
+                gpu_device_name = torch.cuda.get_device_name(device_index)
+                gpu_compute_types = ["auto"] + list(get_supported_compute_types("cuda", device_index))
 
-            # Per-device compute type restrictions
-            if "GTX" in gpu_device_name:
-                unsupported_types = {"int8_bfloat16", "bfloat16", "float16", "int8"}
-                gpu_compute_types = [t for t in gpu_compute_types if t not in unsupported_types]
-            elif not any(keyword in gpu_device_name for keyword in ["RTX", "Tesla", "A100", "Quadro"]):
-                gpu_compute_types = ["float32"]
+                # Per-device compute type restrictions
+                if "GTX" in gpu_device_name:
+                    unsupported_types = {"int8_bfloat16", "bfloat16", "float16", "int8"}
+                    gpu_compute_types = [t for t in gpu_compute_types if t not in unsupported_types]
+                elif not any(keyword in gpu_device_name for keyword in ["RTX", "Tesla", "A100", "Quadro"]):
+                    gpu_compute_types = ["float32"]
 
-            compute_types.append(
-                {
-                    "device": "cuda",
-                    "device_index": device_index,
-                    "device_name": gpu_device_name,
-                    "compute_types": gpu_compute_types,
-                }
-            )
+                compute_types.append(
+                    {
+                        "device": "cuda",
+                        "device_index": device_index,
+                        "device_name": gpu_device_name,
+                        "compute_types": gpu_compute_types,
+                    }
+                )
+    except Exception:
+        # If querying GPU devices fails, return at least the CPU entry
+        errorLogging()
 
     return compute_types
 
 def getBestComputeType(device: str, device_index: int) -> str:
-    compute_types = set(get_supported_compute_types(device, device_index))
-    device_name = "cpu" if device == "cpu" else torch.cuda.get_device_name(device_index)
+    """Pick the best available compute type for a device.
+
+    Falls back to "float32" when no preferred type is available.
+    """
+    try:
+        compute_types = set(get_supported_compute_types(device, device_index))
+    except Exception:
+        compute_types = set()
+
+    try:
+        device_name = "cpu" if device == "cpu" else (torch.cuda.get_device_name(device_index) if torch is not None else "")
+    except Exception:
+        device_name = ""
 
     # Preferred compute types per device
     preferred_types = {
@@ -141,14 +171,26 @@ def getBestComputeType(device: str, device_index: int) -> str:
 
     return "float32"
 
-def encodeBase64(data:str) -> dict:
-    return json.loads(base64.b64decode(data).decode('utf-8'))
+def encodeBase64(data: str) -> Dict[str, Any]:
+    """Decode a base64-encoded JSON string and return the parsed object.
 
-def removeLog():
-    with open('process.log', 'w', encoding="utf-8") as f:
-        f.write("")
+    Returns an empty dict on failure.
+    """
+    try:
+        return json.loads(base64.b64decode(data).decode('utf-8'))
+    except Exception:
+        errorLogging()
+        return {}
 
-def setupLogger(name, log_file, level=logging.INFO):
+def removeLog() -> None:
+    """Truncate the process log file (process.log) if present."""
+    try:
+        with open('process.log', 'w', encoding="utf-8") as f:
+            f.write("")
+    except Exception:
+        errorLogging()
+
+def setupLogger(name: str, log_file: str, level: int = logging.INFO) -> logging.Logger:
     """
     Set up a logger with the given name and log file.
     """
@@ -174,13 +216,17 @@ def setupLogger(name, log_file, level=logging.INFO):
     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
     file_handler.setFormatter(formatter)
 
-    # Add the handler to the logger
-    logger.addHandler(file_handler)
+    # Add the handler to the logger (avoiding duplicates)
+    if not any(isinstance(h, RotatingFileHandler) and getattr(h, 'baseFilename', None) == getattr(file_handler, 'baseFilename', None) for h in logger.handlers):
+        logger.addHandler(file_handler)
 
     return logger
 
-process_logger = None
-def printLog(log:str, data:Any=None) -> None:
+process_logger: Optional[logging.Logger] = None
+
+
+def printLog(log: str, data: Any = None) -> None:
+    """Log and print a structured process log message."""
     global process_logger
     if process_logger is None:
         process_logger = setupLogger("process", "process.log", logging.INFO)
@@ -194,7 +240,11 @@ def printLog(log:str, data:Any=None) -> None:
     serialized = json.dumps(response)
     print(serialized, flush=True)
 
-def printResponse(status:int, endpoint:str, result:Any=None) -> None:
+def printResponse(status: int, endpoint: str, result: Any = None) -> None:
+    """Log and print a structured response object.
+
+    If JSON serialization fails, record the error and emit a generic error payload.
+    """
     global process_logger
     if process_logger is None:
         process_logger = setupLogger("process", "process.log", logging.INFO)
@@ -208,28 +258,37 @@ def printResponse(status:int, endpoint:str, result:Any=None) -> None:
 
     try:
         serialized_response = json.dumps(response)
-    except OSError as e:
-        errorLogging() # Log the full traceback of the OSError
-        process_logger.error(f"Problematic response object before json.dumps: {response}")
-        process_logger.error(f"OSError during json.dumps: {e}")
-        # Optionally, print a generic error JSON to stdout if needed, or re-raise
-        # For now, we'll print a simple error message to stdout as a fallback
+    except Exception as e:
+        errorLogging() # Log the full traceback of the exception
+        try:
+            process_logger.error(f"Problematic response object before json.dumps: {response}")
+            process_logger.error(f"Exception during json.dumps: {e}")
+        except Exception:
+            pass
+        # Fallback generic error payload
         error_json = json.dumps({
             "status": 500,
             "endpoint": endpoint,
-            "result": {"error": "Failed to serialize response due to OSError", "details": str(e)}
+            "result": {"error": "Failed to serialize response", "details": str(e)},
         })
         print(error_json, flush=True)
     else:
         print(serialized_response, flush=True)
 
-error_logger = None
+error_logger: Optional[logging.Logger] = None
+
+
 def errorLogging() -> None:
+    """Log the current exception traceback to the error logger."""
     global error_logger
     if error_logger is None:
         error_logger = setupLogger("error", "error.log", logging.ERROR)
 
-    error_logger.error(traceback.format_exc())
+    try:
+        error_logger.error(traceback.format_exc())
+    except Exception:
+        # As a last resort, print the traceback to stdout
+        print(traceback.format_exc(), flush=True)
 
 if __name__ == "__main__":
     print(getComputeDeviceList())