Files
VRCT/src-python/models/translation/translation_openai.py
misyaguziya dac903e07c feat: Implement translation prompt history injection for Chat/Mic/Speaker
- Added a history management system in model.py to store and retrieve recent messages from Chat, Mic, and Speaker.
- Updated controller.py to automatically add messages to the translation history after processing.
- Enhanced translation clients (OpenAI, Gemini, Groq, etc.) to accept and utilize context history for improved translation quality.
- Introduced YAML configuration options for enabling history injection and customizing its behavior across different translation models.
- Ensured that only original messages are stored in history to optimize token usage during translation.
2025-12-15 01:15:47 +09:00

201 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from openai import OpenAI
from langchain_openai import ChatOpenAI
from pydantic import SecretStr
try:
from .translation_languages import translation_lang
from .translation_utils import loadTranslatePromptConfig
except Exception:
import sys
from os import path as os_path
sys.path.append(os_path.dirname(os_path.dirname(os_path.dirname(os_path.abspath(__file__)))))
from translation_languages import translation_lang, loadTranslationLanguages
from translation_utils import loadTranslatePromptConfig
translation_lang = loadTranslationLanguages(path=".", force=True)
def _authentication_check(api_key: str, base_url: str | None = None) -> bool:
"""Check if the provided API key is valid by attempting to list models.
"""
try:
client = OpenAI(api_key=api_key, base_url=base_url)
client.models.list()
return True
except Exception:
return False
def _get_available_text_models(api_key: str, base_url: str | None = None) -> list[str]:
"""Extract only GPT models suitable for translation and chat applications (plus those with fine-tuning)
"""
client = OpenAI(api_key=api_key, base_url=base_url)
res = client.models.list()
allowed_models = []
for model in res.data:
model_id = model.id
root = getattr(model, "root", "")
# 除外対象のキーワード
exclude_keywords = [
"whisper", # 音声認識
"embedding", # 埋め込み
"image", # 画像生成
"tts", # 音声合成
"audio", # 音声系transcribe, diarize含む
"search", # 検索補助モデル
"transcribe", # 音声→文字起こし
"diarize", # 話者分離
"vision" # 画像入力系旧gpt-4-visionなど
]
# 除外キーワードが含まれているモデルをスキップ
if any(kw in model_id for kw in exclude_keywords):
continue
# GPTモデルまたはFine-tune GPTモデルのみ対象
if model_id.startswith("gpt-"):
allowed_models.append(model_id)
elif model_id.startswith("ft:") and root.startswith("gpt-"):
allowed_models.append(model_id)
allowed_models.sort()
return allowed_models
class OpenAIClient:
"""OpenAI Translation simple wrapper.
prompt/translation_openai.yml から system_prompt / supported_languages を読み込む。
"""
def __init__(self, base_url: str | None = None, root_path: str = None):
self.api_key = None
self.model = None
self.base_url = base_url # None の場合は公式エンドポイント
prompt_config = loadTranslatePromptConfig(root_path, "translation_openai.yml")
self.supported_languages = list(translation_lang["OpenAI_API"]["source"].keys())
self.prompt_template = prompt_config["system_prompt"]
# history config (optional)
self.history_cfg = prompt_config.get("history", {
"use_history": False,
"sources": [],
"max_messages": 0,
"max_chars": 0,
"header_template": "",
"item_template": "[{source}] {role}: {text}",
})
self._context_history: list[dict] = []
self.openai_llm = None
def getModelList(self) -> list[str]:
return _get_available_text_models(self.api_key, self.base_url) if self.api_key else []
def getAuthKey(self) -> str:
return self.api_key
def setAuthKey(self, api_key: str) -> bool:
result = _authentication_check(api_key, self.base_url)
if result:
self.api_key = api_key
return result
def getModel(self) -> str:
return self.model
def setModel(self, model: str) -> bool:
if model in self.getModelList():
self.model = model
return True
else:
return False
def updateClient(self) -> None:
self.openai_llm = ChatOpenAI(
base_url=self.base_url,
model=self.model,
api_key=SecretStr(self.api_key),
streaming=False,
)
def setContextHistory(self, history_items: list[dict]) -> None:
"""Set recent conversation history for prompt injection.
Each item should be a dict containing:
- source: "chat" | "mic" | "speaker"
- text: message string
- timestamp: ISO format datetime string
"""
self._context_history = history_items or []
def translate(self, text: str, input_lang: str, output_lang: str) -> str:
system_prompt = self.prompt_template.format(
supported_languages=self.supported_languages,
input_lang=input_lang,
output_lang=output_lang,
)
# Inject recent conversation history if enabled by YAML config
if self.history_cfg.get("use_history"):
allowed_sources = set(self.history_cfg.get("sources", []))
max_messages = int(self.history_cfg.get("max_messages", 0))
max_chars = int(self.history_cfg.get("max_chars", 0))
item_tmpl = self.history_cfg.get("item_template", "[{source}] {role}: {text}")
header_tmpl = self.history_cfg.get("header_template", "{history}")
# filter by source and take newest N
filtered = [h for h in self._context_history if h.get("source") in allowed_sources]
recent = filtered[-max_messages:] if max_messages > 0 else filtered
# format items
formatted_items = []
for h in recent:
# Format timestamp as HH:MM to save tokens
timestamp_str = ''
if 'timestamp' in h:
from datetime import datetime
try:
ts = datetime.fromisoformat(h['timestamp'])
timestamp_str = ts.strftime('%H:%M')
except:
timestamp_str = ''
formatted_items.append(
item_tmpl.format(
timestamp=timestamp_str,
source=h.get("source", ""),
text=h.get("text", ""),
)
)
history_blob = "\n".join(formatted_items).strip()
# truncate by char limit to mitigate token use
if max_chars and len(history_blob) > max_chars:
history_blob = history_blob[-max_chars:]
# assemble header and append to system prompt
history_header = header_tmpl.format(max_messages=max_messages, history=history_blob)
if history_header:
system_prompt = f"{system_prompt}\n\n{history_header}"
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": text},
]
resp = self.openai_llm.invoke(messages)
content = ""
if isinstance(resp.content, str):
content = resp.content
elif isinstance(resp.content, list):
for part in resp.content:
if isinstance(part, str):
content += part
elif isinstance(part, dict) and "content" in part and isinstance(part["content"], str):
content += part["content"]
return content.strip()
if __name__ == "__main__":
AUTH_KEY = "OPENAI_API_KEY"
client = OpenAIClient()
client.setAuthKey(AUTH_KEY)
models = client.getModelList()
if models:
print("Available models:", models)
model = input("Select a model: ")
client.setModel(model)
client.updateClient()
print(client.translate("こんにちは世界", "Japanese", "English"))