# Source code for veil.evaluator

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from veil.config.evaluator import EvaluatorConfig
from veil.core.base_entity_type import EntityTypeBase
from veil.core.document import Document
from veil.core.mask_result import MaskResult
from veil.core.span import Span
from veil.logger import init_logger

# Module-level logger obtained from the project's logger factory.
logger = init_logger(__name__)

# MetricStore is imported only for static type checking; at runtime it is
# referenced solely through the string annotation Optional["MetricStore"].
# NOTE(review): presumably guarded to avoid a runtime import cycle — confirm.
if TYPE_CHECKING:  # pragma: no cover
    from veil.metric_store import MetricStore


@dataclass
class ConfusionCounts:
    """Mutable tally of span-matching outcomes for one evaluation slice."""

    tp: int = 0  # predicted spans matched to a gold span
    fp: int = 0  # predicted spans left without a gold match
    fn: int = 0  # gold spans left without a predicted match

    def to_dict(self) -> Dict[str, int]:
        """Return the three counters as plain ints keyed "tp", "fp", "fn"."""
        return {field_name: int(getattr(self, field_name)) for field_name in ("tp", "fp", "fn")}
class Evaluator:
    """Score predicted entity spans against ground-truth spans.

    ``evaluate_document`` is the public entry point. It produces TP/FP/FN
    counts per matching variant, per component and per canonical entity type,
    and forwards them to an optional metric store.

    Matching modes used internally:

    - ``"exact"``: identical ``start``/``end`` boundaries.
    - ``"iou"``: intersection-over-union at or above a threshold.
    - ``"exact_by_id"`` / ``"iou_by_id"``: as above, and the predicted span's
      ``id`` must equal the gold span's ``id`` after predicted IDs have been
      relabelled onto gold IDs (see ``_build_id_mapping``).
    """

    def __init__(self, config: EvaluatorConfig):
        self.config = config

    # -----------------------------
    # Span matching helpers
    # -----------------------------

    def _span_brief(self, span: Span) -> Dict[str, Any]:
        """Return a concise dict for logging: start, end, canonical type, id."""
        try:
            type_name = self._canonical_type(self._get_type_name(span))
        except Exception:
            # Logging helper must never fail; fall back to an unknown type.
            type_name = None
        return {
            "start": getattr(span, "start", None),
            "end": getattr(span, "end", None),
            "type": type_name,
            "id": getattr(span, "id", None),
        }

    @staticmethod
    def _span_iou(a: Span, b: Span) -> float:
        """Return the intersection-over-union of two spans.

        Each span is treated as covering ``end - start`` positions. Returns
        0.0 when the spans do not overlap or the union length is not positive.
        """
        inter = max(0, min(a.end, b.end) - max(a.start, b.start))
        if inter <= 0:
            return 0.0
        union = (a.end - a.start) + (b.end - b.start) - inter
        if union <= 0:
            return 0.0
        return float(inter) / float(union)

    def _canonical_type(self, raw: Optional[str]) -> str:
        """Upper-case ``raw`` ("UNKNOWN" when None/empty) and resolve enum aliases."""
        name = (raw or "UNKNOWN").upper()
        alias_map = EntityTypeBase.global_alias_map()
        return alias_map.get(name, name)

    def _is_type_match(self, a: Optional[str], b: Optional[str]) -> bool:
        """Return True if two raw type names are compatible.

        When ``config.strict_type`` is falsy every pair is compatible;
        otherwise the canonical names must be equal.
        """
        if not self.config.strict_type:
            return True
        return self._canonical_type(a) == self._canonical_type(b)

    def _get_type_name(self, span: Span) -> Optional[str]:
        """Return the raw type name of ``span.entity_type``, or None if absent.

        Supports enum-like objects exposing a string ``.name``, plain strings,
        and, as a last resort, ``str(entity_type)``.
        """
        et = getattr(span, "entity_type", None)
        if et is None:
            return None
        name = getattr(et, "name", None)
        if isinstance(name, str):
            return name
        if isinstance(et, str):
            return et
        try:
            return str(et)
        except Exception:
            return None

    def _build_id_mapping(
        self,
        preds: List[Span],
        golds: List[Span],
        *,
        mode: str,
        iou_threshold: float | None,
    ) -> Dict[Tuple[str, str], Optional[str]]:
        """Map each ``(canonical type, predicted id)`` to a gold id or None.

        Spans are grouped into clusters keyed by ``(canonical type, str(id))``;
        spans whose ``id`` is None are ignored. For each pair of predicted and
        gold clusters of the same type, the score is the number of span pairs
        whose positions agree. For ``"exact_by_id"`` that means identical
        boundaries; for any other mode it means IoU >= threshold (0.0 if none
        is given).

        Clusters are then paired one-to-one, choosing between two strategies:

        - Greedy (the default): highest score first, ties broken by the gold
          cluster's earliest span start.
        - Optimal total score via ``_hungarian_min_cost``: used when
          ``config.use_hungarian_id_mapping`` is truthy.

        Every predicted cluster left unpaired maps to None, so it can never
        match a gold id. This makes the by-id metrics invariant to arbitrary
        permutations of predicted ID labels.
        """

        def _norm_type(span: Span) -> str:
            return self._canonical_type(self._get_type_name(span))

        preds_by_cluster: Dict[Tuple[str, str], List[int]] = {}
        golds_by_cluster: Dict[Tuple[str, str], List[int]] = {}
        for pi, p in enumerate(preds):
            pid_raw = getattr(p, "id", None)
            if pid_raw is not None:
                preds_by_cluster.setdefault((_norm_type(p), str(pid_raw)), []).append(pi)
        for gi, g in enumerate(golds):
            gid_raw = getattr(g, "id", None)
            if gid_raw is not None:
                golds_by_cluster.setdefault((_norm_type(g), str(gid_raw)), []).append(gi)

        # Loop-invariant positional threshold (used only outside exact_by_id).
        thr = float(iou_threshold) if iou_threshold is not None else 0.0

        def _positions_agree(p: Span, g: Span) -> bool:
            if mode == "exact_by_id":
                return p.start == g.start and p.end == g.end
            return self._span_iou(p, g) >= thr

        pred_clusters_by_type: Dict[str, List[Tuple[str, List[int]]]] = {}
        for (etype, pid), idxs in preds_by_cluster.items():
            pred_clusters_by_type.setdefault(etype, []).append((pid, idxs))

        use_hungarian = getattr(self.config, "use_hungarian_id_mapping", False)
        mapping: Dict[Tuple[str, str], Optional[str]] = {}

        for etype, pred_clusters in pred_clusters_by_type.items():
            gold_clusters: Dict[str, List[int]] = {
                gid: idxs
                for (gtype, gid), idxs in golds_by_cluster.items()
                if gtype == etype
            }

            # Candidates: (score, earliest_gold_start, pid, gid)
            candidates: List[Tuple[int, int, str, str]] = []
            for pid, p_indices in pred_clusters:
                for gid, g_indices in gold_clusters.items():
                    score = sum(
                        1
                        for pi in p_indices
                        for gi in g_indices
                        if _positions_agree(preds[pi], golds[gi])
                    )
                    if score > 0:
                        earliest_g = min(
                            (golds[gi].start for gi in g_indices), default=10**9
                        )
                        candidates.append((score, earliest_g, pid, gid))

            if use_hungarian:
                pid_list = [pid for pid, _ in pred_clusters]
                gid_list = list(gold_clusters)
                p_index = {pid: i for i, pid in enumerate(pid_list)}
                g_index = {gid: j for j, gid in enumerate(gid_list)}
                # Pairs with no positional agreement get a prohibitive cost;
                # real candidates get -score so minimizing cost maximizes score.
                INF = 10**9
                cost = [[INF] * len(gid_list) for _ in pid_list]
                for score, _pos, pid, gid in candidates:
                    cost[p_index[pid]][g_index[gid]] = -score
                if pid_list and gid_list:
                    for i, j in _hungarian_min_cost(cost):
                        if i is None or j is None or cost[i][j] == INF:
                            continue
                        mapping[(etype, pid_list[i])] = gid_list[j]
            else:
                candidates.sort(key=lambda c: (-c[0], c[1]))
                used_gids: set[str] = set()
                assigned_pids: set[str] = set()
                for _score, _pos, pid, gid in candidates:
                    if pid in assigned_pids or gid in used_gids:
                        continue
                    mapping[(etype, pid)] = gid
                    assigned_pids.add(pid)
                    used_gids.add(gid)

            # Unassigned predicted clusters map to None (never match a gold id).
            for pid, _ in pred_clusters:
                mapping.setdefault((etype, pid), None)

        return mapping

    def _match_pred_to_gold_mode(
        self,
        preds: List[Span],
        golds: List[Span],
        *,
        mode: str,
        iou_threshold: float | None = None,
    ) -> Tuple[List[Tuple[int, int]], List[int], List[int]]:
        """Greedily match predicted spans to ground-truth spans for one mode.

        mode: "exact", "exact_by_id", "iou", or "iou_by_id". IoU modes only
        accept matches when ``iou_threshold`` is provided.

        Predictions are processed in order. Each prediction takes the
        still-unused, type-compatible gold span with the highest score; the
        first such gold wins ties. The match is kept only when the score meets
        the mode's acceptance rule: 1.0 for exact modes, or >= ``iou_threshold``
        for IoU modes.

        Returns (matches, unmatched_pred_indices, unmatched_gold_indices), where
        ``matches`` holds ``(pred_index, gold_index)`` pairs.
        """
        by_id = "by_id" in mode
        exact_mode = mode in ("exact", "exact_by_id")
        iou_mode = mode in ("iou", "iou_by_id")
        min_iou = (
            float(iou_threshold) if iou_mode and iou_threshold is not None else None
        )

        pred_id_to_gold_id: Dict[Tuple[str, str], Optional[str]] = {}
        if by_id:
            logger.debug(
                "By-id matching start: mode=%s, preds=%d, golds=%d, iou_thr=%s",
                mode,
                len(preds),
                len(golds),
                str(iou_threshold),
            )
            pred_id_to_gold_id = self._build_id_mapping(
                preds, golds, mode=mode, iou_threshold=iou_threshold
            )

        matches: List[Tuple[int, int]] = []
        used_gold: set[int] = set()

        for pi, p in enumerate(preds):
            p_type = self._get_type_name(p)

            # The prediction's id translated into the gold id namespace; it
            # depends only on the prediction, so compute it once per span.
            mapped_pid: Optional[str] = None
            if by_id:
                pid_raw = getattr(p, "id", None)
                if pid_raw is not None:
                    mapped_pid = pred_id_to_gold_id.get(
                        (self._canonical_type(p_type), str(pid_raw))
                    )

            best_gi = -1
            best_score = -1.0
            for gi, g in enumerate(golds):
                if gi in used_gold:
                    continue
                if not self._is_type_match(p_type, self._get_type_name(g)):
                    continue

                if by_id:
                    gid = getattr(g, "id", None)
                    ids_agree = (
                        mapped_pid is not None
                        and gid is not None
                        and mapped_pid == str(gid)
                    )
                else:
                    ids_agree = True

                if exact_mode:
                    score = (
                        1.0
                        if ids_agree and p.start == g.start and p.end == g.end
                        else 0.0
                    )
                else:
                    score = self._span_iou(p, g) if ids_agree else 0.0

                if score > best_score:
                    best_score = score
                    best_gi = gi

            if best_gi != -1:
                if exact_mode:
                    accepted = best_score >= 1.0
                elif min_iou is not None:
                    accepted = best_score >= min_iou
                else:
                    accepted = False
                if accepted:
                    matches.append((pi, best_gi))
                    used_gold.add(best_gi)

            if by_id:
                logger.debug(
                    "By-id selection: pred_idx=%d %s -> best_gold_idx=%s score=%.4f",
                    pi,
                    self._span_brief(p),
                    str(best_gi),
                    float(best_score),
                )

        # Build the matched set once instead of once per index.
        matched_preds = {pi for pi, _ in matches}
        unmatched_preds = [i for i in range(len(preds)) if i not in matched_preds]
        unmatched_golds = [i for i in range(len(golds)) if i not in used_gold]
        return matches, unmatched_preds, unmatched_golds

    # -----------------------------
    # Public API
    # -----------------------------

    def evaluate_document(
        self,
        *,
        document: Document,
        mask_result: MaskResult,
        component_spans: Dict[str, List[Span]] | None,
        metric_store: Optional["MetricStore"],
        component_supported_types: Optional[Dict[str, List[str]]] = None,
    ) -> Optional[Dict[str, Any]]:
        """Compute confusion counts for one document and send them to a store.

        Returns None when evaluation is disabled or the document has no ground
        truth. Otherwise returns ``{"variants": {...}}``. Each variant key maps
        to a dict with "per_component_by_type", "per_component_all_types",
        "global_by_type" and "global_all_types" counts. Variant keys:

        - "exact" when ``config.report_exact``
        - "exact_by_id" when ``config.report_exact_by_id``
        - "iou@T" (T with two decimals) for each numeric threshold in
          ``config.report_iou_thresholds``
        - "iou@T_by_id" for each such threshold when
          ``config.report_exact_by_id``

        Per-component counts are restricted to
        ``component_supported_types[comp]`` (compared after canonicalization)
        when that entry exists. Global counts use all of
        ``mask_result.entities``.

        If ``metric_store`` exposes a callable ``record_evaluation_variant``,
        it is called once per variant.
        """
        if not self.config.enabled:
            return None
        if document.ground_truth is None:
            return None

        gold_spans: List[Span] = document.ground_truth or []
        pred_spans_all: List[Span] = mask_result.entities or []
        report_by_id = bool(getattr(self.config, "report_exact_by_id", False))

        def _filter_supported(
            comp_key: str, preds: List[Span]
        ) -> Tuple[List[Span], List[Span]]:
            """Restrict preds and golds to the component's supported types, if known."""
            supported_raw = (
                component_supported_types.get(comp_key)
                if component_supported_types is not None
                else None
            )
            if supported_raw is None:
                return preds, gold_spans
            supported_set = {self._canonical_type(s) for s in supported_raw}

            def _is_supported(span: Span) -> bool:
                return self._canonical_type(self._get_type_name(span)) in supported_set

            return (
                [p for p in (preds or []) if _is_supported(p)],
                [g for g in gold_spans if _is_supported(g)],
            )

        # Defined after _filter_supported so the filtered views can be logged.
        if report_by_id:
            self._log_by_id_inputs(
                document, gold_spans, pred_spans_all, component_spans, _filter_supported
            )

        def _compute_counts_for_mode(
            comp_spans: Dict[str, List[Span]] | None,
            *,
            mode: str,
            iou_thr: float | None,
        ) -> Tuple[
            Dict[str, Dict[str, ConfusionCounts]],
            Dict[str, ConfusionCounts],
            Dict[str, ConfusionCounts],
            ConfusionCounts,
        ]:
            per_comp_by_type: Dict[str, Dict[str, ConfusionCounts]] = {}
            per_comp_all_types: Dict[str, ConfusionCounts] = {}
            for comp_key, comp_preds in (comp_spans or {}).items():
                # `or []` guards against a None prediction list reaching
                # _evaluate_pairwise when no supported-type filter applies.
                filtered_preds, filtered_golds = _filter_supported(
                    comp_key, comp_preds or []
                )
                by_type, all_types = self._evaluate_pairwise(
                    filtered_preds, filtered_golds, mode=mode, iou_threshold=iou_thr
                )
                per_comp_by_type[comp_key] = by_type
                per_comp_all_types[comp_key] = all_types
            global_by_type, global_all_types = self._evaluate_pairwise(
                pred_spans_all, gold_spans, mode=mode, iou_threshold=iou_thr
            )
            return per_comp_by_type, per_comp_all_types, global_by_type, global_all_types

        # (variant key, matching mode, IoU threshold) in reporting order.
        variant_specs: List[Tuple[str, str, Optional[float]]] = []
        if self.config.report_exact:
            variant_specs.append(("exact", "exact", None))
        if report_by_id:
            variant_specs.append(("exact_by_id", "exact_by_id", None))
        for thr in self.config.report_iou_thresholds or []:
            try:
                t = float(thr)
            except (TypeError, ValueError, OverflowError):
                continue  # skip non-numeric thresholds
            variant_specs.append((f"iou@{t:.2f}", "iou", t))
            if report_by_id:
                # ID-aware IoU variant (entity resolution)
                variant_specs.append((f"iou@{t:.2f}_by_id", "iou_by_id", t))

        variants: Dict[str, Dict[str, Any]] = {}
        for v_key, v_mode, v_thr in variant_specs:
            variants[v_key] = self._variant_payload(
                *_compute_counts_for_mode(component_spans, mode=v_mode, iou_thr=v_thr)
            )

        evaluation: Dict[str, Any] = {"variants": variants}

        if metric_store is not None and variants:
            record_variant = getattr(metric_store, "record_evaluation_variant", None)
            if callable(record_variant):
                for v_key, payload in variants.items():
                    record_variant(
                        variant=v_key,
                        per_component_by_type=payload["per_component_by_type"],
                        per_component_all_types=payload["per_component_all_types"],
                        global_by_type=payload["global_by_type"],
                        global_all_types=payload["global_all_types"],
                    )
        return evaluation

    def _log_by_id_inputs(
        self,
        document: Document,
        gold_spans: List[Span],
        pred_spans_all: List[Span],
        component_spans: Dict[str, List[Span]] | None,
        filter_supported: Any,
    ) -> None:
        """Emit DEBUG-level summaries of the inputs to by-id evaluation.

        Best-effort: any failure is logged as a warning and never aborts
        evaluation. Span summaries are only built when DEBUG is enabled.
        """
        import logging

        is_enabled_for = getattr(logger, "isEnabledFor", None)
        if callable(is_enabled_for) and not is_enabled_for(logging.DEBUG):
            return
        try:
            doc_id = getattr(document, "doc_id", None)
            logger.debug(
                "Eval by-id inputs (doc_id=%s) gold_spans=%s",
                doc_id,
                [self._span_brief(s) for s in gold_spans],
            )
            logger.debug(
                "Eval by-id inputs (doc_id=%s) pred_spans_all=%s",
                doc_id,
                [self._span_brief(s) for s in pred_spans_all],
            )
            for comp_key, comp_preds in (component_spans or {}).items():
                raw_preds = comp_preds or []
                filtered_preds, filtered_golds = filter_supported(comp_key, raw_preds)
                logger.debug(
                    "Eval by-id component '%s' (doc_id=%s) raw_preds=%s "
                    "filtered_preds=%s filtered_golds=%s",
                    comp_key,
                    doc_id,
                    [self._span_brief(s) for s in raw_preds],
                    [self._span_brief(s) for s in (filtered_preds or [])],
                    [self._span_brief(s) for s in (filtered_golds or [])],
                )
        except Exception as e:
            logger.warning("Failed logging by-id inputs: %s", e)

    @staticmethod
    def _variant_payload(
        per_comp_by_type: Dict[str, Dict[str, ConfusionCounts]],
        per_comp_all_types: Dict[str, ConfusionCounts],
        global_by_type: Dict[str, ConfusionCounts],
        global_all_types: ConfusionCounts,
    ) -> Dict[str, Any]:
        """Serialize one variant's counts into plain dicts via ``to_dict()``."""
        return {
            "per_component_by_type": {
                comp: {etype: cc.to_dict() for etype, cc in counts.items()}
                for comp, counts in per_comp_by_type.items()
            },
            "per_component_all_types": {
                comp: ct.to_dict() for comp, ct in per_comp_all_types.items()
            },
            "global_by_type": {
                etype: cc.to_dict() for etype, cc in global_by_type.items()
            },
            "global_all_types": global_all_types.to_dict(),
        }

    def _evaluate_pairwise(
        self,
        preds: List[Span],
        golds: List[Span],
        *,
        mode: str,
        iou_threshold: float | None,
    ) -> Tuple[Dict[str, ConfusionCounts], ConfusionCounts]:
        """Match spans per canonical type and return (by_type, total) counts.

        For each type: tp = matched pairs, fp = unmatched predictions,
        fn = unmatched gold spans. Types are processed in sorted order.
        """

        def type_name(s: Span) -> str:
            return self._canonical_type(self._get_type_name(s))

        preds_by_type: Dict[str, List[Span]] = {}
        golds_by_type: Dict[str, List[Span]] = {}
        for p in preds:
            preds_by_type.setdefault(type_name(p), []).append(p)
        for g in golds:
            golds_by_type.setdefault(type_name(g), []).append(g)

        by_type_counts: Dict[str, ConfusionCounts] = {}
        total = ConfusionCounts()
        for etype in sorted(set(preds_by_type) | set(golds_by_type)):
            matches, unmatched_p, unmatched_g = self._match_pred_to_gold_mode(
                preds_by_type.get(etype, []),
                golds_by_type.get(etype, []),
                mode=mode,
                iou_threshold=iou_threshold,
            )
            tp, fp, fn = len(matches), len(unmatched_p), len(unmatched_g)
            by_type_counts[etype] = ConfusionCounts(tp=tp, fp=fp, fn=fn)
            total.tp += tp
            total.fp += fp
            total.fn += fn
        return by_type_counts, total
def _hungarian_min_cost(
    cost: List[List[int]],
) -> List[Tuple[Optional[int], Optional[int]]]:
    """Solve the rectangular assignment problem by minimizing total cost.

    Uses a compact potentials-based Hungarian (Kuhn-Munkres) algorithm with
    O(n^3) time in ``n = max(rows, cols)``. The matrix is padded to ``n x n``
    with zero-cost dummy rows/columns; pairing with a dummy means "left
    unassigned". Costs may be any finite numbers, including negative ones.
    ``Evaluator`` passes ``-score`` so that minimizing cost maximizes score.

    Args:
        cost: Rectangular ``rows x cols`` matrix; ``cost[i][j]`` is the cost of
            assigning row ``i`` to column ``j``.

    Returns:
        ``[]`` for an empty matrix. Otherwise a list of length ``n`` indexed by
        (padded) row. Entry ``i`` is ``(i, j)`` when real row ``i`` is assigned
        real column ``j``, and ``(None, None)`` when the pair involves a
        padding row or column.
    """
    if not cost or not cost[0]:
        return []

    import math

    n_rows = len(cost)
    n_cols = len(cost[0])
    n = max(n_rows, n_cols)
    a = [list(row) + [0] * (n - n_cols) for row in cost]
    a.extend([0] * n for _ in range(n - n_rows))

    # 1-based arrays; index 0 is a virtual column used to start each search.
    u = [0] * (n + 1)  # row potentials
    v = [0] * (n + 1)  # column potentials
    p = [0] * (n + 1)  # p[j] = row (1-based) assigned to column j, 0 if none
    way = [0] * (n + 1)  # previous column on the augmenting path

    for i in range(1, n + 1):
        p[0] = i
        j0 = 0
        minv = [math.inf] * (n + 1)
        used = [False] * (n + 1)
        while True:
            used[j0] = True
            i0 = p[j0]
            delta = math.inf
            j1 = 0
            for j in range(1, n + 1):
                if used[j]:
                    continue
                cur = a[i0 - 1][j - 1] - u[i0] - v[j]
                if cur < minv[j]:
                    minv[j] = cur
                    way[j] = j0
                if minv[j] < delta:
                    delta = minv[j]
                    j1 = j
            for j in range(n + 1):
                if used[j]:
                    u[p[j]] += delta
                    v[j] -= delta
                else:
                    minv[j] -= delta
            j0 = j1
            if p[j0] == 0:
                break
        # Flip the alternating path back to the virtual column 0.
        while True:
            j1 = way[j0]
            p[j0] = p[j1]
            j0 = j1
            if j0 == 0:
                break

    # Keep only pairs inside the original (unpadded) matrix.
    result: List[Tuple[Optional[int], Optional[int]]] = [(None, None)] * n
    for j in range(1, n + 1):
        row, col = p[j] - 1, j - 1
        if p[j] != 0 and row < n_rows and col < n_cols:
            result[row] = (row, col)
    return result