| |
| from __future__ import annotations |
|
|
| import base64 |
| import binascii |
| import json |
| import os |
| import re |
| import subprocess |
| import uuid |
| from concurrent.futures import ThreadPoolExecutor, as_completed |
| from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer |
| from pathlib import Path |
| from time import monotonic |
| from typing import Any |
|
|
def _env_int(name: str, default: int) -> int:
    """Read integer setting *name* from the environment, or return *default* when unset."""
    raw = os.environ.get(name)
    return default if raw is None else int(raw)


# Filesystem layout: static assets and uploads live next to this module.
APP_DIR = Path(__file__).resolve().parent
STATIC_DIR = APP_DIR.joinpath("static")
UPLOADS_DIR = APP_DIR.joinpath("uploads")

# HTTP bind address.
HOST = os.environ.get("HOST", "127.0.0.1")
PORT = _env_int("PORT", 8080)

# Gemini CLI invocation; the model name is fixed and not configurable.
GEMINI_TIMEOUT_SEC = _env_int("GEMINI_TIMEOUT_SEC", 90)
GEMINI_CLI_BINARY = os.environ.get("GEMINI_CLI_BINARY", "gemini")
LOCKED_GEMINI_MODEL = "gemini-3-flash-preview"

# Upload / batch limits. At least one worker is always allowed.
MAX_IMAGE_BYTES = _env_int("MAX_IMAGE_BYTES", 8 * 1024 * 1024)
MAX_BATCH_IMAGES = _env_int("MAX_BATCH_IMAGES", 20)
MAX_PARALLEL_WORKERS = max(1, _env_int("MAX_PARALLEL_WORKERS", 4))
|
|
# MIME types accepted for uploaded creatives, mapped to the extension used on disk.
ALLOWED_IMAGE_MIME_TO_EXT = dict(
    (
        ("image/png", "png"),
        ("image/jpeg", "jpg"),
        ("image/jpg", "jpg"),  # non-standard alias of image/jpeg
        ("image/webp", "webp"),
        ("image/gif", "gif"),
    )
)
# Whole-string match for "data:<mime>;base64,<payload>"; the payload may contain whitespace.
DATA_URL_RE = re.compile(r"^data:(?P<mime>[-\w.+/]+);base64,(?P<data>[A-Za-z0-9+/=\s]+)$")
|
|
# System instruction used when a request supplies no (or an empty) system_prompt.
DEFAULT_SYSTEM_PROMPT = (
    "You are a UK fintech ad compliance screening assistant. Return only valid JSON "
    "and nothing else."
)


# Example of the expected reply shape. build_prompt serializes it into every
# prompt so the model knows which keys to return.
JSON_SCHEMA_HINT = dict(
    risk_level="low | medium | high",
    summary="short sentence",
    violations=[
        dict(
            issue="what is risky",
            rule_refs=["FCA handbook or principle area"],
            why="why this is a risk",
            fix="specific rewrite guidance",
        )
    ],
    safe_rewrite="optional ad rewrite",
)
|
|
|
|
def build_prompt(ad_text: str, extra_context: str, system_prompt: str, image_at_path: str | None) -> str:
    """Assemble the single text prompt passed to the Gemini CLI.

    Layout: system prompt, task line, JSON shape hint, an optional
    ``@<image_at_path>`` image reference, the ad copy ("[Not provided]" when
    blank), and optional extra context. Sections are separated by blank lines.

    NOTE(review): ad_text, extra_context and system_prompt come from the HTTP
    request (see do_POST) and are embedded verbatim. The image is referenced
    with the CLI's ``@path`` syntax, so an ``@...`` token in user-supplied text
    may also be treated as a file reference by the Gemini CLI. Confirm this and
    consider escaping.
    """
    copy_text = ad_text.strip() or "[Not provided]"
    context_text = extra_context.strip()

    sections: list[str] = [
        system_prompt.strip(),
        "",
        "Task: Screen this ad copy for UK fintech/FCA-style risk.",
        "Output format: Return only JSON in this shape:",
        json.dumps(JSON_SCHEMA_HINT, ensure_ascii=True, indent=2),
        "",
    ]
    if image_at_path:
        sections.extend(
            (
                "Creative image reference:",
                f"@{image_at_path}",
                "Use this image as part of your compliance risk review.",
                "",
            )
        )
    sections.extend(("Ad copy:", copy_text))
    if context_text:
        sections.extend(("", "Extra context:", context_text))
    return "\n".join(sections)
|
|
|
|
def gemini_cmd_candidates(prompt: str) -> list[list[str]]:
    """Return alternative argv spellings for one Gemini CLI call.

    Every model-flag / prompt-flag combination is produced, so CLI versions
    that accept only some spellings still work. ``run_gemini`` tries them in
    this order: ``--model``/``-p``, ``-m``/``-p``, ``--model``/``--prompt``,
    ``-m``/``--prompt``.
    """
    return [
        [GEMINI_CLI_BINARY, model_flag, LOCKED_GEMINI_MODEL, prompt_flag, prompt]
        for prompt_flag in ("-p", "--prompt")
        for model_flag in ("--model", "-m")
    ]
|
|
|
|
def is_flag_parse_error(stderr: str, stdout: str) -> bool:
    """Heuristically decide whether the CLI rejected its command-line flags.

    Both output streams are searched, case-insensitively, for phrases that
    argument parsers typically print on an unknown or invalid option.
    """
    markers = (
        "unknown option",
        "unknown argument",
        "invalid option",
        "unrecognized option",
        "unrecognized argument",
        "unexpected argument",
        "did you mean",
    )
    haystack = "\n".join((stderr, stdout)).lower()
    for marker in markers:
        if marker in haystack:
            return True
    return False
|
|
|
|
def run_gemini(prompt: str) -> str:
    """Run the Gemini CLI on *prompt* and return its trimmed stdout.

    The argv shapes from ``gemini_cmd_candidates`` are tried in order. The next
    shape is tried only when the failure looks like a flag-parsing error (see
    ``is_flag_parse_error``); any other non-zero exit stops immediately.

    Raises:
        RuntimeError: the CLI exited non-zero. The message is its stderr, else
            its stdout, else the exit code.
        FileNotFoundError / OSError: the CLI binary could not be launched.
        subprocess.TimeoutExpired: an attempt exceeded ``GEMINI_TIMEOUT_SEC``.
    """
    attempts = gemini_cmd_candidates(prompt)
    last_error = "Gemini CLI invocation failed."
    child_env = os.environ.copy()

    # Hand the key to the CLI as GEMINI_API_KEY when only GOOGLE_API_KEY is
    # configured, then drop GOOGLE_API_KEY so the child sees a single key variable.
    if not child_env.get("GEMINI_API_KEY") and child_env.get("GOOGLE_API_KEY"):
        child_env["GEMINI_API_KEY"] = child_env["GOOGLE_API_KEY"]
    child_env.pop("GOOGLE_API_KEY", None)

    for idx, cmd in enumerate(attempts):
        proc = subprocess.run(
            cmd,
            capture_output=True,
            # Decode as UTF-8 explicitly: the locale default (e.g. cp1252 on
            # Windows) can mis-decode or reject characters such as "£".
            text=True,
            encoding="utf-8",
            errors="replace",
            # Never let the CLI read from (or wait on) the server's own stdin.
            stdin=subprocess.DEVNULL,
            cwd=str(APP_DIR),
            env=child_env,
            timeout=GEMINI_TIMEOUT_SEC,
            check=False,
        )
        if proc.returncode == 0:
            return (proc.stdout or "").strip()

        stderr = (proc.stderr or "").strip()
        stdout = (proc.stdout or "").strip()
        last_error = stderr or stdout or f"Gemini CLI exited with code {proc.returncode}."

        # Only fall through to the next flag spelling when this one was rejected
        # by the CLI's argument parser; real failures are reported as-is.
        if idx < len(attempts) - 1 and is_flag_parse_error(stderr, stdout):
            continue
        break

    raise RuntimeError(last_error)
|
|
|
|
def try_parse_json(text: str) -> Any | None:
    """Best-effort decode of a model reply as JSON.

    Accepts plain JSON and JSON wrapped in a Markdown code fence (optionally
    tagged ``json``). As a fallback, it recovers JSON embedded in surrounding
    text (a preamble line, trailing notes, a fence that is not the whole reply)
    by decoding from each ``{`` in turn.

    Returns:
        The first JSON value recovered, or ``None`` when there is none.
    """
    trimmed = text.strip()
    if not trimmed:
        return None

    # Unwrap a fence that spans the whole reply: ```json\n...\n```
    if trimmed.startswith("```"):
        lines = trimmed.splitlines()
        if len(lines) >= 3 and lines[-1].strip().startswith("```"):
            trimmed = "\n".join(lines[1:-1]).strip()
    if trimmed.lower().startswith("json"):
        trimmed = trimmed[4:].strip()

    try:
        return json.loads(trimmed)
    except json.JSONDecodeError:
        pass

    # Fallback: decode starting at each "{" and return the first complete value;
    # raw_decode ignores whatever follows it.
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            value, _end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
            continue
        return value
    return None
|
|
|
|
def safe_filename_stem(raw_name: str) -> str:
    """Turn a client-supplied filename into a short, filesystem-safe stem.

    The extension is dropped. Each run of characters outside
    ``[A-Za-z0-9_-]`` becomes one ``-``, and edge dashes are trimmed. The
    result is capped at 40 characters, with ``"ad-image"`` as the fallback
    when nothing usable remains.
    """
    fallback = "ad-image"
    base = Path(raw_name).stem if raw_name else fallback
    slug = re.sub(r"[^A-Za-z0-9_-]+", "-", base).strip("-")
    return slug[:40] if slug else fallback
|
|
|
|
def save_image_from_data_url(image_data_url: str, image_filename: str) -> str:
    """Decode a base64 ``data:`` URL and write the image under ``UPLOADS_DIR``.

    The file is named ``<sanitized stem>-<10 random hex chars>.<ext>``, with
    the extension taken from the declared MIME type. The bytes themselves are
    not checked against that type.

    Returns:
        The relative reference ``"uploads/<file name>"``.

    Raises:
        ValueError: the URL is malformed, the MIME type is not allowed, the
            base64 is invalid, or the decoded image is empty or exceeds
            ``MAX_IMAGE_BYTES``.
    """
    parsed = DATA_URL_RE.match(image_data_url.strip())
    if parsed is None:
        raise ValueError("Image must be a valid base64 data URL (data:image/...;base64,...).")

    declared_mime = parsed.group("mime").lower()
    suffix = ALLOWED_IMAGE_MIME_TO_EXT.get(declared_mime)
    if not suffix:
        supported = ", ".join(sorted(ALLOWED_IMAGE_MIME_TO_EXT))
        raise ValueError(f"Unsupported image type '{declared_mime}'. Allowed: {supported}.")

    # The regex tolerates whitespace inside the payload; strict decoding does not.
    compact_payload = re.sub(r"\s+", "", parsed.group("data"))
    try:
        decoded = base64.b64decode(compact_payload, validate=True)
    except (ValueError, binascii.Error):
        raise ValueError("Image base64 payload is invalid.") from None

    if not decoded:
        raise ValueError("Image payload is empty.")
    if len(decoded) > MAX_IMAGE_BYTES:
        raise ValueError(f"Image is too large. Max size is {MAX_IMAGE_BYTES} bytes.")

    UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
    stored_name = f"{safe_filename_stem(image_filename)}-{uuid.uuid4().hex[:10]}.{suffix}"
    (UPLOADS_DIR / stored_name).write_bytes(decoded)
    return "uploads/" + stored_name
|
|
|
|
def normalize_image_inputs(payload: dict[str, Any]) -> list[dict[str, str]]:
    """Collect the request's images into a uniform ``[{data_url, filename}]`` list.

    A non-empty ``images`` list takes precedence. Each entry must be an object
    with a non-empty ``data_url`` and an optional ``filename`` (default
    ``image-<n>.png``). Otherwise the single-image fields ``image_data_url`` /
    ``image_filename`` are used (default filename ``image.png``). JSON ``null``
    values are treated as absent rather than as the text "None".

    Returns:
        The normalized entries; empty when the request carries no image.

    Raises:
        ValueError: ``images`` is present but not a list, has more than
            ``MAX_BATCH_IMAGES`` entries, or contains a malformed entry.
    """

    def text_of(value: Any) -> str:
        # str(None) would yield "None", which would then look like real input.
        return "" if value is None else str(value).strip()

    images_field = payload.get("images")
    if images_field is not None and not isinstance(images_field, list):
        raise ValueError("images must be a list.")

    if images_field:
        if len(images_field) > MAX_BATCH_IMAGES:
            raise ValueError(f"Too many images. Max is {MAX_BATCH_IMAGES}.")
        normalized: list[dict[str, str]] = []
        for idx, item in enumerate(images_field):
            if not isinstance(item, dict):
                raise ValueError(f"images[{idx}] must be an object.")
            data_url = text_of(item.get("data_url"))
            if not data_url:
                raise ValueError(f"images[{idx}].data_url is required.")
            filename = text_of(item.get("filename")) or f"image-{idx + 1}.png"
            normalized.append({"data_url": data_url, "filename": filename})
        return normalized

    single_data_url = text_of(payload.get("image_data_url"))
    if not single_data_url:
        return []
    return [
        {
            "data_url": single_data_url,
            "filename": text_of(payload.get("image_filename")) or "image.png",
        }
    ]
|
|
|
|
def run_single_check(prompt: str) -> tuple[bool, int, dict[str, Any]]:
    """Run one Gemini screening and map every failure to an HTTP-style result.

    Returns:
        ``(ok, status, result)``. On success ``status`` is 200 and ``result``
        holds ``parsed_output`` (``None`` when no JSON was recovered) and
        ``raw_output``. On failure ``result`` holds only ``error``.
    """
    try:
        raw_output = run_gemini(prompt)
    except FileNotFoundError:
        return (
            False,
            500,
            {"error": f"Gemini CLI not found. Install it and ensure '{GEMINI_CLI_BINARY}' is on PATH."},
        )
    except subprocess.TimeoutExpired:
        return False, 504, {"error": f"Gemini CLI timed out after {GEMINI_TIMEOUT_SEC}s."}
    except RuntimeError as err:
        return False, 500, {"error": str(err)}
    except UnicodeDecodeError:
        # The CLI's output could not be decoded as text.
        return False, 502, {"error": "Gemini CLI output could not be decoded as text."}
    except ValueError as err:
        # subprocess rejects argv containing NUL characters ("embedded null
        # byte"), and the prompt embeds user-supplied text.
        return False, 400, {"error": f"Prompt cannot be passed to the Gemini CLI: {err}"}
    except OSError as err:
        # e.g. the binary exists but cannot be executed (PermissionError).
        return False, 500, {"error": f"Could not launch Gemini CLI: {err}"}
    return True, 200, {"parsed_output": try_parse_json(raw_output), "raw_output": raw_output}
|
|
|
|
def run_single_image_check(
    index: int,
    total: int,
    image_ref: str,
    ad_text: str,
    extra_context: str,
    system_prompt: str,
) -> dict[str, Any]:
    """Screen one stored image (plus the shared ad copy) and package the outcome.

    Submitted to a thread pool by the batch path of ``AppHandler.do_POST``.
    Start and finish lines (with elapsed seconds) are printed and flushed
    immediately.

    Returns:
        A dict with ``index``, ``ok``, ``image_reference``, ``parsed_output``,
        ``raw_output`` and ``error``; fields that do not apply are ``None``.
    """
    label = f"[batch {index}/{total}]"
    print(f"{label} starting check for {image_ref}", flush=True)
    t0 = monotonic()
    ok, _, outcome = run_single_check(
        build_prompt(
            ad_text=ad_text,
            extra_context=extra_context,
            system_prompt=system_prompt,
            image_at_path=image_ref,
        )
    )
    print(f"{label} {'ok' if ok else 'failed'} in {monotonic() - t0:.1f}s", flush=True)

    record: dict[str, Any] = {"index": index, "ok": ok, "image_reference": image_ref}
    record.update((key, outcome.get(key)) for key in ("parsed_output", "raw_output", "error"))
    return record
|
|
|
|
class AppHandler(SimpleHTTPRequestHandler):
    """Serve the static UI and the ``POST /api/check`` screening endpoint.

    GET/HEAD requests are handled by ``SimpleHTTPRequestHandler`` with
    ``STATIC_DIR`` as the document root. ``POST /api/check`` takes a JSON
    object with optional ``ad_text``, ``extra_context``, ``system_prompt`` and
    image fields, screens it through the Gemini CLI, and replies with JSON.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, directory=str(STATIC_DIR), **kwargs)

    @staticmethod
    def _max_body_bytes() -> int:
        """Return the largest request body the handler will buffer in memory."""
        # Base64 inflates each image by 4/3; the slack covers data-URL prefixes,
        # filenames and the text fields.
        return MAX_BATCH_IMAGES * (MAX_IMAGE_BYTES * 4 // 3 + 4096) + 1024 * 1024

    @staticmethod
    def _text_field(payload: dict[str, Any], key: str) -> str:
        """Return ``payload[key]`` as stripped text; missing or JSON ``null`` gives ``""``."""
        value = payload.get(key)
        return "" if value is None else str(value).strip()

    def _send_json(self, status: int, payload: dict[str, Any]) -> None:
        """Send *payload* as an ASCII-escaped JSON response with *status*."""
        data = json.dumps(payload, ensure_ascii=True).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Content-Length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def _read_json_object(self) -> dict[str, Any] | None:
        """Read the request body and decode it as a JSON object.

        Returns:
            The decoded object, or ``None`` after an error response (400 or
            413) has been sent. That happens for an invalid, missing or
            oversized body, a wrong Content-Type, bad UTF-8, bad JSON, or a
            JSON value that is not an object.
        """
        try:
            content_length = int(self.headers.get("Content-Length", "0"))
        except ValueError:
            self._send_json(400, {"ok": False, "error": "Content-Length header is invalid."})
            return None
        if content_length <= 0:
            self._send_json(400, {"ok": False, "error": "Request body is required."})
            return None
        max_body = self._max_body_bytes()
        if content_length > max_body:
            # Reject before reading so an oversized body is never buffered.
            self._send_json(
                413,
                {"ok": False, "error": f"Request body is too large. Max size is {max_body} bytes."},
            )
            return None

        content_type = self.headers.get("Content-Type", "")
        if "application/json" not in content_type.lower():
            self._send_json(400, {"ok": False, "error": "Content-Type must be application/json."})
            return None

        raw_body = self.rfile.read(content_length)
        try:
            body_str = raw_body.decode("utf-8")
        except UnicodeDecodeError:
            self._send_json(400, {"ok": False, "error": "Body contains invalid UTF-8 data."})
            return None
        try:
            payload = json.loads(body_str)
        except json.JSONDecodeError:
            self._send_json(400, {"ok": False, "error": "Body must be valid JSON."})
            return None

        # Arrays, strings and numbers are valid JSON but have no fields to read.
        if not isinstance(payload, dict):
            self._send_json(400, {"ok": False, "error": "Body must be a JSON object."})
            return None
        return payload

    def do_POST(self) -> None:
        """Handle ``POST /api/check``; any other POST path gets a 404."""
        if self.path != "/api/check":
            self._send_json(404, {"ok": False, "error": "Not found"})
            return

        payload = self._read_json_object()
        if payload is None:
            return  # an error response has already been sent

        ad_text = self._text_field(payload, "ad_text")
        extra_context = self._text_field(payload, "extra_context")
        # A missing, null or blank system prompt falls back to the default.
        system_prompt = self._text_field(payload, "system_prompt") or DEFAULT_SYSTEM_PROMPT

        try:
            image_inputs = normalize_image_inputs(payload)
        except ValueError as err:
            self._send_json(400, {"ok": False, "error": str(err)})
            return

        if not ad_text and not image_inputs:
            self._send_json(400, {"ok": False, "error": "Provide 'ad_text' or an image."})
            return

        if image_inputs:
            self._respond_images(image_inputs, ad_text, extra_context, system_prompt)
        else:
            self._respond_text_only(ad_text, extra_context, system_prompt)

    def _respond_text_only(self, ad_text: str, extra_context: str, system_prompt: str) -> None:
        """Run one text-only check and send the single-result response."""
        prompt = build_prompt(
            ad_text=ad_text,
            extra_context=extra_context,
            system_prompt=system_prompt,
            image_at_path=None,
        )
        ok, status, result = run_single_check(prompt)
        if not ok:
            self._send_json(status, {"ok": False, "error": result["error"]})
            return
        self._send_json(
            200,
            {
                "ok": True,
                "mode": "single",
                "parallel_workers": 1,
                "all_success": True,
                "total": 1,
                "success_count": 1,
                "failure_count": 0,
                "results": [
                    {
                        "index": 1,
                        "ok": True,
                        "image_reference": None,
                        "parsed_output": result["parsed_output"],
                        "raw_output": result["raw_output"],
                        "error": None,
                    }
                ],
                "parsed_output": result["parsed_output"],
                "raw_output": result["raw_output"],
                "image_reference": None,
            },
        )

    def _respond_images(
        self,
        image_inputs: list[dict[str, str]],
        ad_text: str,
        extra_context: str,
        system_prompt: str,
    ) -> None:
        """Store the uploaded images, screen them in parallel, and send the batch response.

        NOTE(review): stored files under ``uploads/`` are not removed by this
        handler. Confirm that keeping them is intended.
        """
        image_refs: list[str] = []
        for image in image_inputs:
            try:
                image_ref = save_image_from_data_url(
                    image_data_url=image["data_url"], image_filename=image["filename"]
                )
            except ValueError as err:
                self._send_json(400, {"ok": False, "error": str(err)})
                return
            except OSError:
                # Disk/permission failure; don't leak server paths to the client.
                self._send_json(500, {"ok": False, "error": "Could not store the uploaded image."})
                return
            image_refs.append(image_ref)

        total = len(image_refs)
        worker_count = max(1, min(MAX_PARALLEL_WORKERS, total))
        print(
            f"Starting bulk Gemini checks: total_images={total}, parallel_workers={worker_count}",
            flush=True,
        )

        # Results are collected in completion order but stored by slot, so the
        # response keeps the request's image order.
        results: list[dict[str, Any] | None] = [None] * total
        completed = 0
        with ThreadPoolExecutor(max_workers=worker_count) as executor:
            future_to_slot = {
                executor.submit(
                    run_single_image_check,
                    idx,
                    total,
                    image_ref,
                    ad_text,
                    extra_context,
                    system_prompt,
                ): (idx - 1, image_ref)
                for idx, image_ref in enumerate(image_refs, start=1)
            }
            for future in as_completed(future_to_slot):
                slot, image_ref = future_to_slot[future]
                try:
                    results[slot] = future.result()
                except Exception as err:
                    # A crashed worker becomes a failed entry instead of
                    # aborting the whole batch.
                    results[slot] = {
                        "index": slot + 1,
                        "ok": False,
                        "image_reference": image_ref,
                        "parsed_output": None,
                        "raw_output": None,
                        "error": f"Unexpected worker error: {err}",
                    }
                completed += 1
                print(f"Bulk progress: {completed}/{total} completed", flush=True)

        finalized_results = [item for item in results if item is not None]
        finalized_results.sort(key=lambda item: int(item["index"]))

        success_count = sum(1 for item in finalized_results if item["ok"])
        failure_count = len(finalized_results) - success_count
        first = finalized_results[0]
        self._send_json(
            200,
            {
                "ok": True,
                "mode": "bulk" if len(finalized_results) > 1 else "single",
                "parallel_workers": worker_count,
                "all_success": failure_count == 0,
                "total": len(finalized_results),
                "success_count": success_count,
                "failure_count": failure_count,
                "results": finalized_results,
                # Top-level fields mirror the first result, matching the
                # text-only response shape.
                "parsed_output": first.get("parsed_output"),
                "raw_output": first.get("raw_output"),
                "image_reference": first.get("image_reference"),
            },
        )
|
|
|
|
def main() -> None:
    """Ensure the static/upload directories exist, then serve HTTP until interrupted."""
    STATIC_DIR.mkdir(parents=True, exist_ok=True)
    UPLOADS_DIR.mkdir(parents=True, exist_ok=True)
    # The context manager closes the listening socket when serve_forever exits
    # (including on Ctrl+C, which still propagates as before).
    with ThreadingHTTPServer((HOST, PORT), AppHandler) as server:
        # Flush like the worker logs do, so the line shows up promptly when
        # stdout is piped to a log collector.
        print(f"Server running at http://{HOST}:{PORT}", flush=True)
        server.serve_forever()


if __name__ == "__main__":
    main()
|
|