Source code for veil.overlap_resolver

from __future__ import annotations

from typing import Dict, List, Optional, Tuple

from veil.config.overlap_resolver import OverlapResolverConfig
from veil.core.base_entity_type import EntityTypeBase
from veil.core.span import Span
from veil.logger import init_logger

# Module-level logger; OverlapResolver uses it for debug-level tracing of
# candidate construction and of every conflict decision.
logger = init_logger(__name__)


class OverlapResolver:
    """Resolve overlaps between spans coming from multiple detectors.

    Two spans conflict when their intersection-over-union exceeds
    ``iou_threshold`` and either they share the same canonical entity type or
    ``overlap_policy`` is ``"cross_type"``. Under any other policy, overlaps
    between spans of different entity types are left untouched.

    Conflicts are settled in this order:

    1. Lower numeric priority value wins (0 is highest); priorities are looked
       up per component and canonical entity type.
    2. Earlier hierarchy position of the component wins (smaller number),
       when a hierarchy is provided.
    3. Higher confidence, then longer span, then earlier start.
    4. Otherwise the span that was selected first is kept.
    """

    # Rank given to entity types / components without an explicit priority or
    # hierarchy position, so that they sort after everything that is ranked.
    _UNRANKED = 1_000_000

    def __init__(self, config: OverlapResolverConfig) -> None:
        """Read the IoU threshold and the (lower-cased) overlap policy from *config*."""
        self.iou_threshold = config.iou_threshold
        self.overlap_policy = config.overlap_policy.lower()

    @staticmethod
    def _debug(message: str, *args: object) -> None:
        """Emit a debug record; logging is best-effort and must never abort resolution."""
        try:
            logger.debug(message, *args)
        except Exception:
            # Deliberate: a failing log call must not change the resolver's result.
            pass

    @staticmethod
    def _span_brief(span: Span) -> dict:
        """Return a small dict describing *span*, used only in log output."""
        try:
            type_name = OverlapResolver._canonical_type_name(span)
        except Exception:
            type_name = None
        return {
            "start": int(getattr(span, "start", 0)),
            "end": int(getattr(span, "end", 0)),
            "type": type_name,
            "id": getattr(span, "id", None),
            "confidence": getattr(span, "confidence", None),
        }

    @staticmethod
    def _canonical_type_name(span: Span) -> Optional[str]:
        """Return the upper-cased, alias-resolved entity type name of *span*.

        Returns ``None`` when the span has no entity type or no name can be
        derived from it.
        """
        et = getattr(span, "entity_type", None)
        if et is None:
            return None
        name: Optional[str] = getattr(et, "name", None)
        if not isinstance(name, str):
            if isinstance(et, str):
                name = et
            else:
                try:
                    name = str(et)
                except Exception:
                    name = None
        if name is None:
            return None
        upper_name = name.upper()
        alias_map = EntityTypeBase.global_alias_map()
        return alias_map.get(upper_name, upper_name)

    @staticmethod
    def _iou(a: Span, b: Span) -> float:
        """Return the intersection-over-union of the ``[start, end)`` ranges of *a* and *b*."""
        inter = max(0, min(int(a.end), int(b.end)) - max(int(a.start), int(b.start)))
        if inter <= 0:
            return 0.0
        len_a = int(a.end) - int(a.start)
        len_b = int(b.end) - int(b.start)
        union = len_a + len_b - inter
        if union <= 0:
            return 0.0
        return float(inter) / float(union)

    def _wins_tie(
        self,
        span: Span,
        comp_key: str,
        cand_h: int,
        kept_span: Span,
        kept_comp: str,
        kept_h: int,
    ) -> bool:
        """Return True when the candidate should displace an equal-priority kept span.

        Applied in order: same-component duplicates never displace; earlier
        hierarchy position; higher confidence; longer span; earlier start.
        When everything ties, the kept span stays.
        """
        cand_start = int(getattr(span, "start", 0))
        cand_end = int(getattr(span, "end", 0))
        kept_start = int(getattr(kept_span, "start", 0))
        kept_end = int(getattr(kept_span, "end", 0))

        # Same component, same extent: treat as a duplicate and skip it silently.
        if comp_key == kept_comp and cand_start == kept_start and cand_end == kept_end:
            self._debug("Decision: duplicate from same component -> skip candidate")
            return False

        if cand_h < kept_h:
            self._debug("Decision: cand wins by hierarchy (%d < %d)", cand_h, kept_h)
            return True

        if cand_h == kept_h:
            # A missing or None confidence counts as 0.
            cand_conf = float(getattr(span, "confidence", 0) or 0)
            kept_conf = float(getattr(kept_span, "confidence", 0) or 0)
            if cand_conf > kept_conf:
                self._debug(
                    "Decision: cand wins by confidence (%.4f > %.4f)",
                    cand_conf,
                    kept_conf,
                )
                return True
            if cand_conf == kept_conf:
                cand_len = cand_end - cand_start
                kept_len = kept_end - kept_start
                if cand_len > kept_len:
                    self._debug(
                        "Decision: cand wins by length (%d > %d)", cand_len, kept_len
                    )
                    return True
                if cand_len == kept_len and cand_start < kept_start:
                    self._debug(
                        "Decision: cand wins by earlier start (%d < %d)",
                        cand_start,
                        kept_start,
                    )
                    return True

        self._debug("Decision: kept wins (higher hierarchy or tiebreakers)")
        return False

    def resolve(
        self,
        *,
        component_keys_in_order: List[str],
        component_to_spans: Dict[str, List[Span]],
        component_to_priority: Dict[str, Dict[str, int]],
        component_to_hierarchy: Dict[str, int] | None = None,
    ) -> List[Span]:
        """Select a conflict-free subset of the spans produced by all components.

        Args:
            component_keys_in_order: Component keys to read spans from; their
                order only matters as the stable-sort order among candidates
                with equal priority and hierarchy position.
            component_to_spans: Spans produced by each component.
            component_to_priority: Per component, a map from canonical entity
                type name to numeric priority (lower wins). Spans without a
                canonical type are looked up under the key ``""``; anything
                unlisted gets the lowest rank.
            component_to_hierarchy: Optional position of each component
                (lower wins); unlisted components get the lowest rank.

        Returns:
            The selected spans, in the order in which they were accepted.

        A candidate is accepted only if it beats *every* already selected span
        it conflicts with; all of those spans are then dropped. Replacing just
        the first conflicting span would let conflicting spans reach the
        output and would make the result depend on selection order.
        """
        hierarchy: Dict[str, int] = component_to_hierarchy or {}
        unranked = self._UNRANKED

        # (priority, hierarchy position, component key, canonical type, span)
        candidates: List[Tuple[int, int, str, Optional[str], Span]] = []
        for key in component_keys_in_order:
            priorities = component_to_priority.get(key, {})
            hier = int(hierarchy.get(key, unranked))
            # Drop identical spans (same start, end, canonical type) within a component.
            seen: set[tuple[int, int, Optional[str]]] = set()
            for s in component_to_spans.get(key, []):
                ctype = self._canonical_type_name(s)
                dedupe_key = (
                    int(getattr(s, "start", 0)),
                    int(getattr(s, "end", 0)),
                    ctype,
                )
                if dedupe_key in seen:
                    continue
                seen.add(dedupe_key)
                prio = int(priorities.get(ctype or "", unranked))
                candidates.append((prio, hier, key, ctype, s))

        self._debug(
            "OverlapResolver: built %d candidates from %d components (policy=%s, iou_threshold=%.3f)",
            len(candidates),
            len(component_keys_in_order),
            self.overlap_policy,
            self.iou_threshold,
        )

        # Stable sort: lower priority value first, then earlier hierarchy
        # position. Detector order and span attributes are not sort keys.
        candidates.sort(key=lambda c: (c[0], c[1]))

        # (span, priority, component key, canonical type, hierarchy position)
        selected: List[Tuple[Span, int, str, Optional[str], int]] = []
        for prio, cand_h, comp_key, cand_type, span in candidates:
            self._debug(
                "Candidate: %s prio=%d comp=%s",
                self._span_brief(span),
                prio,
                comp_key,
            )

            displaced: List[int] = []  # indices in `selected` beaten by the candidate
            rejected = False
            for idx, (kept_span, kept_prio, kept_comp, kept_type, kept_h) in enumerate(
                selected
            ):
                iou_val = self._iou(span, kept_span)
                same_type_conflict = (
                    cand_type is not None
                    and cand_type == kept_type
                    and iou_val > self.iou_threshold
                )
                cross_type_conflict = (
                    self.overlap_policy == "cross_type"
                    and iou_val > self.iou_threshold
                )
                self._debug(
                    "Check vs kept: kept=%s prio=%d comp=%s | types=(%s vs %s) iou=%.3f -> same_type_conflict=%s cross_type_conflict=%s",
                    self._span_brief(kept_span),
                    kept_prio,
                    kept_comp,
                    cand_type,
                    kept_type,
                    iou_val,
                    same_type_conflict,
                    cross_type_conflict,
                )
                if not (same_type_conflict or cross_type_conflict):
                    continue

                if prio != kept_prio:
                    # Candidates are visited in ascending priority order, so a
                    # kept span with a different priority always outranks this one.
                    self._debug("Decision: kept wins (higher hierarchy or tiebreakers)")
                    rejected = True
                    break
                if not self._wins_tie(
                    span, comp_key, cand_h, kept_span, kept_comp, kept_h
                ):
                    rejected = True
                    break
                displaced.append(idx)

            if rejected:
                continue
            if displaced:
                # Delete by index, back to front, so earlier indices stay valid
                # and exactly the compared entries are removed.
                for idx in reversed(displaced):
                    del selected[idx]
            else:
                self._debug("No conflict -> append candidate to selected")
            selected.append((span, prio, comp_key, cand_type, cand_h))

        result = [entry[0] for entry in selected]
        self._debug("OverlapResolver: selected %d spans", len(result))
        return result