Source code for veil.entity_detectors.api.masker_api_entity_detector

from __future__ import annotations

import json
import logging
import re
import time
from http.client import HTTPException
from typing import Dict, List, Optional, Set, Tuple
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen

from veil.config.entity_detectors import MaskerApiEntityDetectorConfig
from veil.core.base_entity_detector import BaseEntityDetector
from veil.core.base_entity_type import EntityTypeBase
from veil.core.document import Document
from veil.core.span import Span
from veil.datahandler import DataHandler

logger = logging.getLogger(__name__)


class MaskerApiEntityDetector(BaseEntityDetector[EntityTypeBase]):
    """Detect entities by asking a remote masking API to mask a document.

    The API receives the document text and is expected to return the same
    text with each sensitive span replaced by a ``[TYPE]`` (optionally
    ``[TYPE<n>]``) token. Two request shapes are supported:

    * when ``cfg.model`` is set, a chat-completions style payload (e.g.
      Fireworks) whose system prompt lists ``ENTITY_TYPES``;
    * otherwise a plain ``{"text": ..., "entity_types": [...]}`` payload.

    The masked output is aligned back onto the original text to produce
    character-offset ``Span`` objects, which are then restricted to the
    detector's ``ENTITY_TYPES``.
    """

    ENTITY_TYPES: Set[EntityTypeBase]

    # Lower-cased finish/stop reasons meaning the provider ran out of output
    # tokens, i.e. the masked text is probably cut off.
    _LENGTH_FINISH_REASONS = frozenset({"length", "max_tokens", "maxtokens"})

    # Transport failures worth retrying. HTTPError and URLError are OSError
    # subclasses; socket timeouts raised while waiting for or reading the
    # response (which urlopen does NOT wrap in URLError) are OSErrors too.
    # HTTPException covers http.client errors such as IncompleteRead.
    _RETRYABLE_ERRORS = (OSError, HTTPException)

    # Endings that count as a "clean" boundary for the short-output heuristic.
    _CLEAN_ENDINGS = (".", "!", "?", ")", "]")

    # Mask tokens such as [PERSON] or [EMAIL2], matched case-insensitively.
    _TOKEN_RE = re.compile(r"\[(?P<etype>[A-Z_]+)(?P<num>\d*)\]", re.IGNORECASE)

    # A descriptive User-Agent avoids Cloudflare bot blocks.
    _DEFAULT_USER_AGENT = (
        "veil-masker/0.1 (+https://github.com/chus-chus/veil) Python-urllib"
    )

    def __init__(self, cfg: MaskerApiEntityDetectorConfig):
        """Initialize the masking-API entity detector with a configuration.

        Expected config attributes (subclasses should provide):
            - api_url | url | endpoint: str -> endpoint to POST masking
              requests to (the first non-empty one wins)
            - headers: dict[str, str] (optional)
            - timeout: float | int (optional, default 30)

        Optional chat-completions parameters, only sent when ``model`` is set:
        model, system_prompt, max_tokens (4000), top_p (1), top_k (40),
        presence_penalty (0), frequency_penalty (0), temperature (0.6).

        Optional resilience parameters: retries (2), retry_backoff_base (0.5),
        retry_on_truncation (True), chunk_on_truncation (True),
        chunk_char_limit (4000), truncation_min_fraction (0.6).

        Numeric options that are missing or ``None`` use the defaults above.

        Raises:
            ValueError: if none of ``api_url``, ``url`` or ``endpoint`` is set.
        """
        super().__init__(cfg)

        def opt(name: str, default: float) -> float:
            # Treat an explicit None like a missing attribute so optional
            # config fields don't crash float()/int() below.
            value = getattr(cfg, name, None)
            return default if value is None else value

        # Accept several attribute names to avoid tight coupling to one
        # config schema.
        self.api_url: str | None = (
            getattr(cfg, "api_url", None)
            or getattr(cfg, "url", None)
            or getattr(cfg, "endpoint", None)
        )
        if not self.api_url:
            raise ValueError(
                "Masking API URL not provided in config. Expected 'api_url', 'url' or 'endpoint'."
            )
        self.timeout: float = float(opt("timeout", 30))
        headers: Dict[str, str] | None = getattr(cfg, "headers", None)
        self.headers: Dict[str, str] = dict(headers) if headers else {}

        # Optional chat API parameters (e.g., Fireworks)
        self.model: str | None = getattr(cfg, "model", None)
        self.system_prompt: str | None = getattr(cfg, "system_prompt", None)
        self.max_tokens: int = int(opt("max_tokens", 4000))
        self.top_p: float = float(opt("top_p", 1))
        self.top_k: int = int(opt("top_k", 40))
        self.presence_penalty: float = float(opt("presence_penalty", 0))
        self.frequency_penalty: float = float(opt("frequency_penalty", 0))
        self.temperature: float = float(opt("temperature", 0.6))

        # Resilience controls
        self.retries: int = int(opt("retries", 2))
        self.retry_backoff_base: float = float(opt("retry_backoff_base", 0.5))
        self.retry_on_truncation: bool = bool(
            getattr(cfg, "retry_on_truncation", True)
        )
        self.chunk_on_truncation: bool = bool(
            getattr(cfg, "chunk_on_truncation", True)
        )
        self.chunk_char_limit: int = int(opt("chunk_char_limit", 4000))
        self.truncation_min_fraction: float = float(
            opt("truncation_min_fraction", 0.6)
        )

    def send_request(self, doc: Document) -> str:
        """Send one request to the masking API and return the masked text."""
        masked_text, _ = self._send_request_internal(doc)
        return masked_text

    def _entity_type_names(self) -> List[str]:
        """Return the names of ``ENTITY_TYPES`` (empty if the subclass defines none)."""
        return [
            getattr(et, "name", str(et)) for et in getattr(self, "ENTITY_TYPES", set())
        ]

    def _build_payload(self, doc: Document) -> Dict[str, object]:
        """Build the JSON request body for ``doc``.

        A configured ``model`` selects a chat-completions payload; otherwise a
        plain masking payload is sent.
        """
        type_names = self._entity_type_names()
        if not self.model:
            # The remote API may ignore entity_types; they are sent as a hint.
            return {"text": doc.text, "entity_types": type_names}

        supported_types = ", ".join(type_names)
        system_prompt = self.system_prompt or (
            "You are an anonymization engine. Mask the specified entity types by"
            " replacing each sensitive span with a token in the form [ENTITY] where"
            " ENTITY is the entity type name in UPPERCASE. Only output the fully"
            " masked text, preserving original spacing and punctuation as much as possible."
            f" Entity types to mask: {supported_types}."
        )
        return {
            "model": self.model,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "presence_penalty": self.presence_penalty,
            "frequency_penalty": self.frequency_penalty,
            "temperature": self.temperature,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": doc.text},
            ],
        }

    def _build_headers(self) -> Dict[str, str]:
        """Return request headers; user-configured headers override the defaults."""
        default_headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
            "User-Agent": self._DEFAULT_USER_AGENT,
        }
        return {**default_headers, **self.headers}

    def _send_request_internal(self, doc: Document) -> Tuple[str, bool]:
        """Send a single masking request for ``doc``.

        No retries happen here; ``send_request_with_retries`` decides whether
        to try again.

        Returns:
            ``(masked_text, truncated_flag)``.

        Raises:
            HTTPError, URLError: on HTTP / connection failures (logged first).
            OSError, HTTPException: on other transport failures such as
                response timeouts or incomplete bodies (logged first).
            ValueError: if the response carries no masked text string.
        """
        data = json.dumps(self._build_payload(doc)).encode("utf-8")
        req = Request(
            self.api_url, data=data, headers=self._build_headers(), method="POST"
        )
        logger.info(
            "Sent request for document %s to '%s', model '%s'",
            doc.doc_id,
            self.api_url,
            self.model,
        )
        try:
            with urlopen(req, timeout=self.timeout) as resp:
                resp_bytes = resp.read()
                content_type = resp.headers.get("Content-Type", "")
                content_length_header = resp.headers.get("Content-Length")
        except HTTPError as e:
            body = (
                e.read().decode("utf-8", errors="replace") if hasattr(e, "read") else ""
            )
            logger.error("HTTP error from masking API: %s %s - %s", e.code, e.reason, body)
            raise
        except URLError as e:
            logger.error("Failed to reach masking API: %s", e.reason)
            raise
        except (OSError, HTTPException) as e:
            # e.g. a socket timeout while awaiting/reading the response (not
            # wrapped in URLError by urlopen) or http.client.IncompleteRead.
            logger.error("Masking API transport error: %r", e)
            raise

        text = resp_bytes.decode("utf-8", errors="replace")
        masked_text, finish_reason = self._parse_response_text(text, content_type)
        truncated = self._looks_truncated(
            doc.text, masked_text, finish_reason, content_length_header, len(resp_bytes)
        )
        return masked_text, truncated

    @staticmethod
    def _first_present(obj: Dict[str, object], *keys: str) -> object:
        """Return the value of the first key whose value is not None (else None).

        Unlike an ``or``-chain, a legitimately empty string is kept.
        """
        for key in keys:
            value = obj.get(key)
            if value is not None:
                return value
        return None

    @staticmethod
    def _parse_chat_choice(choices: object) -> Tuple[Optional[str], Optional[str]]:
        """Extract ``(content, finish_reason)`` from ``choices[0]`` of a chat response.

        Either element is None when absent or not a string.
        """
        if not isinstance(choices, list) or not choices:
            return None, None
        first = choices[0]
        if not isinstance(first, dict):
            return None, None
        content: Optional[str] = None
        msg = first.get("message")
        if isinstance(msg, dict) and isinstance(msg.get("content"), str):
            content = msg["content"]
        # OpenAI/Fireworks style finish_reason
        fr = first.get("finish_reason") or first.get("finishReason")
        return content, (fr if isinstance(fr, str) else None)

    @classmethod
    def _parse_response_text(
        cls, text: str, content_type: str
    ) -> Tuple[str, Optional[str]]:
        """Extract the masked text and an optional finish reason from a response body.

        JSON bodies (by content type or a leading ``{``/``[``) are searched for
        ``masked_text``/``maskedText``/``text`` and then for
        ``choices[0].message.content``. Anything else, including invalid JSON,
        is treated as plain masked text.

        Raises:
            ValueError: if the extracted value is not a string.
        """
        masked: object = None
        finish_reason: Optional[str] = None
        obj: object = None
        if "json" in content_type or text.lstrip().startswith(("{", "[")):
            try:
                obj = json.loads(text)
            except json.JSONDecodeError:
                obj = None  # not JSON after all; fall back to plain text below

        if isinstance(obj, dict):
            masked = cls._first_present(obj, "masked_text", "maskedText", "text")
            if masked is None:
                # Fireworks-like chat completions: choices[0].message.content
                masked, finish_reason = cls._parse_chat_choice(obj.get("choices"))
            if finish_reason is None:
                # Some providers include a top-level stop/finish reason
                fr_top = (
                    obj.get("finish_reason")
                    or obj.get("finishReason")
                    or obj.get("stop_reason")
                )
                if isinstance(fr_top, str):
                    finish_reason = fr_top
        elif isinstance(obj, str):
            masked = obj

        if masked is None:
            masked = text
        if not isinstance(masked, str):
            raise ValueError("Masking API response did not contain a masked text string")
        return masked, finish_reason

    def _looks_truncated(
        self,
        original_text: str,
        masked_text: str,
        finish_reason: Optional[str],
        content_length_header: Optional[str],
        received_len: int,
    ) -> bool:
        """Heuristically decide whether the API output was cut off.

        The checks, in order, are:

        1. a length-limit finish reason;
        2. fewer bytes received than ``Content-Length`` announced;
        3. output much shorter than the input (below
           ``truncation_min_fraction``) that does not end on a clean boundary;
        4. an unclosed ``[`` token at the end.

        Never raises: any failure is logged at debug level and treated as
        "not truncated".
        """
        try:
            # 'length' is the common indicator of hitting max tokens
            if finish_reason and finish_reason.lower() in self._LENGTH_FINISH_REASONS:
                return True

            # Network-level incomplete transfer
            if content_length_header is not None:
                try:
                    if int(content_length_header) > received_len:
                        return True
                except ValueError:
                    pass  # malformed header: ignore this signal

            # Much shorter than the original and not ending on a clean boundary
            if len(masked_text) < self.truncation_min_fraction * len(original_text):
                tail = masked_text[-64:]
                if tail and not tail.strip().endswith(self._CLEAN_ENDINGS):
                    return True

            # An unclosed token near the tail may indicate a cut-off
            if "[" in masked_text and not masked_text.rstrip().endswith("]"):
                if masked_text.rfind("[") > masked_text.rfind("]"):
                    return True
            return False
        except Exception:  # heuristics must not fail the call
            logger.debug("Truncation heuristics failed; assuming not truncated", exc_info=True)
            return False

    def send_request_with_retries(self, doc: Document) -> Tuple[str, bool]:
        """Send a request with retries and exponential backoff.

        Transport failures (HTTP errors, connection errors, response timeouts,
        incomplete reads) are retried. A truncated result is also retried when
        ``retry_on_truncation`` is set. At least one attempt is always made,
        even if ``retries`` is negative.

        Returns:
            ``(masked_text, truncated_flag)`` from the last attempt.

        Raises:
            The last transport exception if every attempt fails.
        """
        attempts = max(0, self.retries) + 1
        for attempt in range(attempts):
            is_last = attempt == attempts - 1
            backoff = self.retry_backoff_base * (2**attempt)
            try:
                masked_text, truncated = self._send_request_internal(doc)
            except self._RETRYABLE_ERRORS as e:
                if is_last:
                    raise
                logger.warning(
                    "Masking API request for document %s failed (attempt %d/%d): %r; retrying in %.2fs",
                    doc.doc_id,
                    attempt + 1,
                    attempts,
                    e,
                    backoff,
                )
                time.sleep(backoff)
                continue
            if truncated and self.retry_on_truncation and not is_last:
                time.sleep(backoff)
                continue
            return masked_text, truncated
        # The final attempt always returns or raises above.
        raise RuntimeError("unreachable: retry loop exited without a result")

    def detect_entities(self, doc: Document) -> List[Span]:
        """Detect entities by delegating to a remote masking API and parsing its masked output.

        If the response looks truncated and ``chunk_on_truncation`` is set,
        the document is re-sent in overlapping chunks instead.
        """
        masked_text, truncated = self.send_request_with_retries(doc)
        if truncated and self.chunk_on_truncation:
            spans = self._detect_with_chunking(doc)
            logger.info(
                "Detected %d entity spans in document %s using chunking fallback",
                len(spans),
                doc.doc_id,
            )
        else:
            spans = self._diff_to_spans(doc.text, masked_text)
            logger.info("Detected %d entity spans in document %s", len(spans), doc.doc_id)
        return self._map_and_filter_supported_spans(spans)

    def _split_into_chunks(self, text: str) -> List[Tuple[int, str]]:
        """Split ``text`` into overlapping ``(start_index, chunk_text)`` pieces.

        Chunks hold at most ``max(256, chunk_char_limit)`` characters. They
        prefer to end at a space found in the last 30% of the window. They
        overlap by 32-128 characters so an entity cut at one boundary is seen
        whole in the next chunk.
        """
        limit = max(256, int(self.chunk_char_limit))
        overlap = min(128, max(32, limit // 16))
        chunks: List[Tuple[int, str]] = []
        start = 0
        n = len(text)
        while start < n:
            end = min(start + limit, n)
            if end < n:
                # try to break at whitespace near the boundary
                ws = text.rfind(" ", start, end)
                if ws != -1 and ws - start >= limit * 0.7:
                    end = ws
            chunks.append((start, text[start:end]))
            if end >= n:
                break
            # Progress is guaranteed: end - start >= 0.7 * limit > overlap.
            start = max(0, end - overlap)
        return chunks

    def _detect_with_chunking(self, doc: Document) -> List[Span]:
        """Fallback: mask the document chunk by chunk and merge the resulting spans.

        Chunk-local offsets are shifted back into document coordinates, and
        spans that overlap (e.g. from the chunk overlap region) are merged.
        """
        text = doc.text
        if not text:
            return []

        shifted: List[Span] = []
        for idx, (offset, chunk_text) in enumerate(self._split_into_chunks(text)):
            sub_doc = Document(text=chunk_text, doc_id=f"{doc.doc_id}-chunk-{idx}")
            chunk_masked, _ = self.send_request_with_retries(sub_doc)
            for sp in self._diff_to_spans(chunk_text, chunk_masked):
                shifted.append(
                    Span(
                        start=sp.start + offset,
                        end=sp.end + offset,
                        entity_type=sp.entity_type,  # type: ignore[arg-type]
                        id=sp.id,
                        replacement=sp.replacement,
                        confidence=sp.confidence,
                    )
                )
        return self._coalesce_overlapping(shifted, text)

    @staticmethod
    def _coalesce_overlapping(spans: List[Span], source_text: str) -> List[Span]:
        """Sort spans and merge overlapping ones, keeping the earlier span's metadata.

        A merged span's ``replacement`` is re-read from ``source_text`` so it
        covers the widened range. Touching spans (``start == previous end``)
        are kept separate.
        """
        coalesced: List[Span] = []
        for sp in sorted(spans, key=lambda s: (s.start, s.end)):
            if not coalesced or sp.start >= coalesced[-1].end:
                coalesced.append(sp)
                continue
            last = coalesced[-1]
            if sp.end > last.end:
                coalesced[-1] = Span(
                    start=last.start,
                    end=sp.end,
                    entity_type=last.entity_type,  # type: ignore[arg-type]
                    id=last.id,
                    replacement=source_text[last.start : sp.end],
                    confidence=last.confidence,
                )
        return coalesced

    def _map_and_filter_supported_spans(self, spans: List[Span]) -> List[Span]:
        """Convert string entity types in spans to enum members and drop unsupported types.

        Matching is case-insensitive on the enum member name. Spans whose type
        is unknown, unsupported or missing are dropped with a warning.
        """
        if not spans:
            return spans

        supported_map: Dict[str, EntityTypeBase] = {
            getattr(e, "name", str(e)).upper(): e
            for e in getattr(self, "ENTITY_TYPES", set())
        }
        filtered: List[Span] = []
        for sp in spans:
            et = getattr(sp, "entity_type", None)
            if isinstance(et, EntityTypeBase):
                # Already an enum member; keep it only if supported
                if getattr(et, "name", str(et)).upper() in supported_map:
                    filtered.append(sp)
                else:
                    logger.warning(
                        "Dropping span with unsupported entity type: %s",
                        getattr(et, "name", et),
                    )
            elif isinstance(et, str):
                mapped = supported_map.get(et.upper())
                if mapped is None:
                    logger.warning(
                        "Dropping span with unknown/unsupported entity type from API output: %s",
                        et,
                    )
                    continue
                filtered.append(
                    Span(
                        start=sp.start,
                        end=sp.end,
                        entity_type=mapped,
                        id=sp.id,
                        replacement=sp.replacement,
                        confidence=sp.confidence,
                    )
                )
            else:
                logger.warning(
                    "Dropping span with missing entity type information from API output"
                )
        return filtered

    @staticmethod
    def _diff_to_spans(original: str, masked: str) -> List[Span]:
        """Align ``masked`` against ``original`` and return one span per mask token.

        Literal (non-token) segments of the masked text are searched for in
        the original, left to right. Both sides have whitespace runs collapsed
        to a single space. The original text between a token's preceding
        anchor and the next literal segment is attributed to that token. If
        the next literal cannot be found (or there is none), the span ends at
        the next delimiter followed by a space, or after at most 512
        normalised characters. Overlapping spans are merged.

        Args:
            original: The unmasked source text.
            masked: API output containing ``[TYPE]`` / ``[TYPE<n>]`` tokens.

        Returns:
            Spans with offsets into ``original``. ``entity_type`` is the
            canonicalised type *string*; ``id`` is the token's numeric suffix
            without leading zeros, or None.
        """

        def normalise_with_map(s: str) -> Tuple[str, List[int]]:
            # Collapse each whitespace run to a single space, recording for
            # every normalised character its index in ``s``.
            norm_chars: List[str] = []
            index_map: List[int] = []
            i = 0
            n = len(s)
            while i < n:
                ch = s[i]
                if ch.isspace():
                    start_ws = i
                    while i < n and s[i].isspace():
                        i += 1
                    norm_chars.append(" ")
                    index_map.append(start_ws)
                    continue
                norm_chars.append(ch)
                index_map.append(i)
                i += 1
            return ("".join(norm_chars), index_map)

        def map_norm_to_orig(norm_index: int, index_map: List[int], orig_len: int) -> int:
            # Clamp to the original bounds; interior indices map through the table.
            if norm_index <= 0:
                return 0
            if norm_index >= len(index_map):
                return orig_len
            return index_map[norm_index]

        def normalise_segment(seg: str) -> str:
            return normalise_with_map(seg)[0]

        def find_in_norm(haystack: str, needle: str, start: int) -> int:
            if not needle:
                return start
            return haystack.find(needle, start)

        def find_fallback_end(haystack: str, start: int) -> int:
            # Cap a span whose right anchor is unknown: stop at the first
            # delimiter that is followed by a space or end of text, else
            # after at most 512 characters.
            delimiters = {",", ";", ":", ".", ")"}
            max_len = 512
            end = len(haystack)
            for i in range(start, len(haystack)):
                if haystack[i] in delimiters and (
                    i + 1 >= len(haystack) or haystack[i + 1] == " "
                ):
                    end = i
                    break
                if i - start >= max_len:
                    end = i
                    break
            return end

        tokens = list(MaskerApiEntityDetector._TOKEN_RE.finditer(masked))
        if not tokens:
            return []
        norm_orig, orig_map = normalise_with_map(original)

        # Split the masked text into alternating literal (match=None) and
        # token segments, preserving order.
        masked_segments: List[Tuple[str, Optional[re.Match[str]]]] = []
        prev = 0
        for m in tokens:
            if m.start() > prev:
                masked_segments.append((masked[prev : m.start()], None))
            masked_segments.append((masked[m.start() : m.end()], m))
            prev = m.end()
        if prev < len(masked):
            masked_segments.append((masked[prev:], None))

        spans: List[Span] = []
        pos_norm_orig = 0  # cursor into the normalised original
        i = 0
        while i < len(masked_segments):
            seg_text, seg_tok = masked_segments[i]
            if seg_tok is None:
                # Literal text: advance the cursor past it if it can be found.
                seg_norm = normalise_segment(seg_text)
                if seg_norm:
                    found = find_in_norm(norm_orig, seg_norm, pos_norm_orig)
                    if found != -1:
                        pos_norm_orig = found + len(seg_norm)
                i += 1
                continue

            # Token: the entity spans from the cursor to the next literal.
            start_norm = pos_norm_orig
            next_text_norm = ""
            for t_text, t_tok in masked_segments[i + 1 :]:
                if t_tok is None:
                    next_text_norm = normalise_segment(t_text)
                    break
            if next_text_norm:
                end_found = find_in_norm(norm_orig, next_text_norm, start_norm)
                end_norm = (
                    find_fallback_end(norm_orig, start_norm)
                    if end_found == -1
                    else end_found
                )
            else:
                end_norm = find_fallback_end(norm_orig, start_norm)

            start_idx = map_norm_to_orig(start_norm, orig_map, len(original))
            end_idx = map_norm_to_orig(end_norm, orig_map, len(original))
            if 0 <= start_idx <= end_idx <= len(original) and start_idx != end_idx:
                entity_type = DataHandler._canonicalize_type_str(
                    seg_tok.group("etype").upper()
                )
                num_suffix = seg_tok.group("num")
                if num_suffix:
                    normalized_id = num_suffix.lstrip("0")
                    span_id = normalized_id if normalized_id != "" else "0"
                else:
                    span_id = None
                spans.append(
                    Span(
                        start=start_idx,
                        end=end_idx,
                        entity_type=entity_type,  # type: ignore[arg-type]
                        id=span_id,
                        replacement=original[start_idx:end_idx],
                        confidence=1.0,
                    )
                )
            pos_norm_orig = end_norm
            i += 1

        return MaskerApiEntityDetector._coalesce_overlapping(spans, original)