torchとctranslate2のインポートをガードし、安全なデフォルトを提供。型注釈とdocstringを追加して可読性を向上。ログ設定の重複ハンドラ追加を防ぐチェックを導入。encodeBase64はデコード失敗時に空辞書を返すように変更。getComputeDeviceListはGPU情報取得失敗時にCPU情報を返すように例外保護を追加。

This commit is contained in:
misyaguziya
2025-10-09 18:53:42 +09:00
parent 35e8d7dda9
commit eca5e31429
2 changed files with 176 additions and 59 deletions

View File

@@ -1,12 +1,22 @@
import base64
from typing import Any, List, Dict
from typing import Any, List, Dict, Optional
import json
import traceback
import logging
from logging.handlers import RotatingFileHandler
import torch
from ctranslate2 import get_supported_compute_types
try:
import torch
except Exception:
torch = None # type: ignore
try:
from ctranslate2 import get_supported_compute_types
except Exception:
# Fallback: if ctranslate2 is not installed, provide a safe stub.
def get_supported_compute_types(device: str, device_index: int) -> List[str]:
return []
import requests
import ipaddress
import socket
@@ -47,32 +57,32 @@ def validateDictStructure(data: dict, structure: dict) -> bool:
return True
def isConnectedNetwork(url="http://www.google.com", timeout=3) -> bool:
"""Quick network connectivity check by requesting `url`.
Returns True when a 200 response is returned within `timeout` seconds.
"""
try:
response = requests.get(url, timeout=timeout)
return response.status_code == 200
except requests.RequestException:
return False
def isAvailableWebSocketServer(host:str, port:int) -> bool:
"""WebSocketサーバーのポートが使用中かどうかを確認する"""
response = True
def isAvailableWebSocketServer(host: str, port: int) -> bool:
"""Return True if the given host/port appear available for binding.
Note: This attempts to bind a TCP socket to the address. If bind
succeeds the function returns True (meaning the address was available).
"""
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as chk:
try:
# SO_REUSEADDRを設定してソケットの再利用を許可
chk.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
chk.bind((host, port))
# シャットダウン前にリッスン状態にする必要はない
chk.close()
except Exception:
response = False
chk.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
chk.bind((host, port))
return True
except Exception:
errorLogging()
response = False
return response
return False
def isValidIpAddress(ip_address: str) -> bool:
"""Return True if `ip_address` is a valid IPv4/IPv6 address."""
try:
ipaddress.ip_address(ip_address)
return True
@@ -80,7 +90,12 @@ def isValidIpAddress(ip_address: str) -> bool:
return False
def getComputeDeviceList() -> List[Dict[str, Any]]:
compute_types = [
"""Return a list of available compute devices and supported compute types.
The returned list contains dicts describing CPU and (if available)
CUDA devices. This function is defensive to missing optional packages.
"""
compute_types: List[Dict[str, Any]] = [
{
"device": "cpu",
"device_index": 0,
@@ -89,32 +104,47 @@ def getComputeDeviceList() -> List[Dict[str, Any]]:
}
]
if torch.cuda.is_available():
for device_index in range(torch.cuda.device_count()):
gpu_device_name = torch.cuda.get_device_name(device_index)
gpu_compute_types = ["auto"] + list(get_supported_compute_types("cuda", device_index))
try:
if torch is not None and hasattr(torch, "cuda") and torch.cuda.is_available():
for device_index in range(torch.cuda.device_count()):
gpu_device_name = torch.cuda.get_device_name(device_index)
gpu_compute_types = ["auto"] + list(get_supported_compute_types("cuda", device_index))
# デバイスごとの計算タイプの制限
if "GTX" in gpu_device_name:
unsupported_types = {"int8_bfloat16", "bfloat16", "float16", "int8"}
gpu_compute_types = [t for t in gpu_compute_types if t not in unsupported_types]
elif not any(keyword in gpu_device_name for keyword in ["RTX", "Tesla", "A100", "Quadro"]):
gpu_compute_types = ["float32"]
# デバイスごとの計算タイプの制限
if "GTX" in gpu_device_name:
unsupported_types = {"int8_bfloat16", "bfloat16", "float16", "int8"}
gpu_compute_types = [t for t in gpu_compute_types if t not in unsupported_types]
elif not any(keyword in gpu_device_name for keyword in ["RTX", "Tesla", "A100", "Quadro"]):
gpu_compute_types = ["float32"]
compute_types.append(
{
"device": "cuda",
"device_index": device_index,
"device_name": gpu_device_name,
"compute_types": gpu_compute_types,
}
)
compute_types.append(
{
"device": "cuda",
"device_index": device_index,
"device_name": gpu_device_name,
"compute_types": gpu_compute_types,
}
)
except Exception:
# If querying GPU devices fails, return at least the CPU entry
errorLogging()
return compute_types
def getBestComputeType(device: str, device_index: int) -> str:
compute_types = set(get_supported_compute_types(device, device_index))
device_name = "cpu" if device == "cpu" else torch.cuda.get_device_name(device_index)
"""Pick the best available compute type for a device.
Falls back to "float32" when no preferred type is available.
"""
try:
compute_types = set(get_supported_compute_types(device, device_index))
except Exception:
compute_types = set()
try:
device_name = "cpu" if device == "cpu" else (torch.cuda.get_device_name(device_index) if torch is not None else "")
except Exception:
device_name = ""
# デバイスごとの優先計算タイプ
preferred_types = {
@@ -141,14 +171,26 @@ def getBestComputeType(device: str, device_index: int) -> str:
return "float32"
def encodeBase64(data:str) -> dict:
return json.loads(base64.b64decode(data).decode('utf-8'))
def encodeBase64(data: str) -> Dict[str, Any]:
"""Decode a base64-encoded JSON string and return the parsed object.
def removeLog():
with open('process.log', 'w', encoding="utf-8") as f:
f.write("")
Returns an empty dict on failure.
"""
try:
return json.loads(base64.b64decode(data).decode('utf-8'))
except Exception:
errorLogging()
return {}
def setupLogger(name, log_file, level=logging.INFO):
def removeLog() -> None:
"""Truncate the process log file (process.log) if present."""
try:
with open('process.log', 'w', encoding="utf-8") as f:
f.write("")
except Exception:
errorLogging()
def setupLogger(name: str, log_file: str, level: int = logging.INFO) -> logging.Logger:
"""
特定の名前とログファイルを持つロガーを設定します。
"""
@@ -174,13 +216,17 @@ def setupLogger(name, log_file, level=logging.INFO):
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
# ロガーにハンドラーを追加
logger.addHandler(file_handler)
# ロガーにハンドラーを追加(重複追加を避ける)
if not any(isinstance(h, RotatingFileHandler) and getattr(h, 'baseFilename', None) == getattr(file_handler, 'baseFilename', None) for h in logger.handlers):
logger.addHandler(file_handler)
return logger
process_logger = None
def printLog(log:str, data:Any=None) -> None:
process_logger: Optional[logging.Logger] = None
def printLog(log: str, data: Any = None) -> None:
"""Log and print a structured process log message."""
global process_logger
if process_logger is None:
process_logger = setupLogger("process", "process.log", logging.INFO)
@@ -194,7 +240,11 @@ def printLog(log:str, data:Any=None) -> None:
serialized = json.dumps(response)
print(serialized, flush=True)
def printResponse(status:int, endpoint:str, result:Any=None) -> None:
def printResponse(status: int, endpoint: str, result: Any = None) -> None:
"""Log and print a structured response object.
If JSON serialization fails, record the error and emit a generic error payload.
"""
global process_logger
if process_logger is None:
process_logger = setupLogger("process", "process.log", logging.INFO)
@@ -208,28 +258,37 @@ def printResponse(status:int, endpoint:str, result:Any=None) -> None:
try:
serialized_response = json.dumps(response)
except OSError as e:
errorLogging() # Log the full traceback of the OSError
process_logger.error(f"Problematic response object before json.dumps: {response}")
process_logger.error(f"OSError during json.dumps: {e}")
# Optionally, print a generic error JSON to stdout if needed, or re-raise
# For now, we'll print a simple error message to stdout as a fallback
except Exception as e:
errorLogging() # Log the full traceback of the exception
try:
process_logger.error(f"Problematic response object before json.dumps: {response}")
process_logger.error(f"Exception during json.dumps: {e}")
except Exception:
pass
# Fallback generic error payload
error_json = json.dumps({
"status": 500,
"endpoint": endpoint,
"result": {"error": "Failed to serialize response due to OSError", "details": str(e)}
"result": {"error": "Failed to serialize response", "details": str(e)},
})
print(error_json, flush=True)
else:
print(serialized_response, flush=True)
error_logger = None
error_logger: Optional[logging.Logger] = None
def errorLogging() -> None:
"""Log the current exception traceback to the error logger."""
global error_logger
if error_logger is None:
error_logger = setupLogger("error", "error.log", logging.ERROR)
error_logger.error(traceback.format_exc())
try:
error_logger.error(traceback.format_exc())
except Exception:
# As a last resort, print the traceback to stdout
print(traceback.format_exc(), flush=True)
if __name__ == "__main__":
print(getComputeDeviceList())