from __future__ import annotations

import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, Dict, List, Optional

from openai import OpenAI

from app.cache import get_cached, set_cache

def _env_int(name: str, default: int) -> int:
|
|
value = os.getenv(name)
|
|
if value is None or value == "":
|
|
return default
|
|
try:
|
|
return int(value)
|
|
except ValueError:
|
|
return default
|
|
|
|
|
|
def _env_float(name: str, default: float) -> float:
|
|
value = os.getenv(name)
|
|
if value is None or value == "":
|
|
return default
|
|
try:
|
|
return float(value)
|
|
except ValueError:
|
|
return default
|
|
|
|
|
|
class FastTranslator:
    """Concurrent subtitle translator backed by the OpenAI chat API.

    Translations are looked up in (and written back to) the app cache so
    repeated lines cost a single API call. All settings fall back to
    ``SUBFOX_*`` environment variables when not passed explicitly.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        workers: Optional[int] = None,
        max_retries: Optional[int] = None,
        retry_base_delay: Optional[float] = None,
        **kwargs,
    ):
        """Configure the translator.

        Args:
            api_key: OpenAI API key; defaults to ``OPENAI_API_KEY``.
            model: Chat model name; defaults to ``SUBFOX_MODEL`` or
                ``gpt-4o-mini``.
            workers: Thread-pool size; defaults to ``SUBFOX_WORKERS`` or 4.
            max_retries: Attempts per text; defaults to
                ``SUBFOX_MAX_RETRIES`` or 3.
            retry_base_delay: Initial backoff delay in seconds; defaults to
                ``SUBFOX_RETRY_BASE_DELAY`` or 1.0.
            **kwargs: Extra options; stored on ``self.kwargs`` but not
                otherwise used by this class.
        """
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        self.model = model or os.getenv("SUBFOX_MODEL", "gpt-4o-mini")
        self.workers = workers if workers is not None else _env_int("SUBFOX_WORKERS", 4)
        self.max_retries = (
            max_retries if max_retries is not None else _env_int("SUBFOX_MAX_RETRIES", 3)
        )
        self.retry_base_delay = (
            retry_base_delay
            if retry_base_delay is not None
            else _env_float("SUBFOX_RETRY_BASE_DELAY", 1.0)
        )
        self.kwargs = kwargs
        # OpenAI() with no explicit key reads OPENAI_API_KEY itself; we pass
        # the key explicitly when we resolved one so the source of the
        # credential is unambiguous.
        self.client = OpenAI(api_key=self.api_key) if self.api_key else OpenAI()

    def _translate_text(self, text: str, source_lang: str, target_lang: str) -> str:
        """Translate one subtitle string, using the cache and retrying on error.

        Returns the translated text (possibly empty if the model returned
        nothing). Raises RuntimeError once all attempts are exhausted.
        """
        cached = get_cached(source_lang, target_lang, text)
        if cached:
            return cached

        prompt = (
            f"Translate the following subtitle text from {source_lang} to {target_lang}. "
            "Preserve meaning, keep it natural, and return only the translated text.\n\n"
            f"{text}"
        )

        last_error: Optional[Exception] = None

        # Guard against max_retries <= 0: without this the loop body would
        # never run and we would raise "failed after 0 attempts: None"
        # without ever calling the API.
        attempts = max(1, self.max_retries)

        for attempt in range(1, attempts + 1):
            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {
                            "role": "system",
                            "content": (
                                "You are a subtitle translator. "
                                "Return only the translated text with no explanations."
                            ),
                        },
                        {"role": "user", "content": prompt},
                    ],
                    temperature=0,  # deterministic output keeps the cache stable
                )

                content = response.choices[0].message.content or ""
                result = content.strip()

                # Only cache non-empty results so a transient empty reply
                # does not poison the cache.
                if result:
                    set_cache(source_lang, target_lang, text, result, self.model)

                return result

            except Exception as e:  # API/network errors; retried below
                last_error = e

                if attempt >= attempts:
                    break

                # Exponential backoff: base, 2*base, 4*base, ...
                delay = self.retry_base_delay * (2 ** (attempt - 1))
                time.sleep(delay)

        raise RuntimeError(
            f"Translation failed after {attempts} attempts: {last_error}"
        )

    def _translate_one(self, block: Dict, source_lang: str, target_lang: str) -> Dict:
        """Return a copy of *block* with its "text" value translated."""
        new_block = dict(block)
        new_block["text"] = self._translate_text(
            block["text"],
            source_lang,
            target_lang,
        )
        return new_block

    def translate_blocks(
        self,
        blocks: List[Dict],
        source_lang: str,
        target_lang: str,
        progress_callback: Optional[Callable[[int, int], None]] = None,
    ) -> List[Dict]:
        """Translate every block concurrently, preserving input order.

        Args:
            blocks: Subtitle blocks; each must carry a "text" key.
            source_lang: Language translated from.
            target_lang: Language translated to.
            progress_callback: Optional ``callback(done, total)`` invoked
                after each block finishes (in completion order).

        Propagates the first exception raised by a worker.
        """
        total = len(blocks)
        # Indexed slots so results land in input order even though the
        # futures complete out of order.
        output: List[Optional[Dict]] = [None] * total

        # max(1, ...) keeps a misconfigured SUBFOX_WORKERS=0 from crashing:
        # ThreadPoolExecutor requires max_workers >= 1.
        with ThreadPoolExecutor(max_workers=max(1, self.workers)) as executor:
            futures = {
                executor.submit(self._translate_one, block, source_lang, target_lang): i
                for i, block in enumerate(blocks)
            }

            done = 0
            for future in as_completed(futures):
                idx = futures[future]
                output[idx] = future.result()
                done += 1

                if progress_callback:
                    progress_callback(done, total)

        return [block for block in output if block is not None]

    def translate_batch(
        self,
        batch: List[Dict],
        source_lang: str,
        target_lang: str,
    ) -> List[str]:
        """Translate *batch* and return just the translated strings, in order."""
        translated_blocks = self.translate_blocks(batch, source_lang, target_lang)
        return [block["text"] for block in translated_blocks]