Source code for veil.metric_store

import json
import math
import threading
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple
from uuid import uuid4

from veil.config.metric_store import MetricStoreConfig
from veil.config.utils import dataclass_to_dict
from veil.core.span import Span
from veil.logger import init_logger

logger = init_logger(__name__)


[docs] class MetricStore: """File-system backed metric store initialization. Creates a unique run directory under the configured output directory to store all metric files for a single run, e.g.: <output_dir>/run-20250814T102234-ab12cd/ """ def __init__(self, config: MetricStoreConfig): self.config = config # ensure base output directory exists self.base_dir: Path = Path(self.config.output_dir).expanduser().resolve() self.base_dir.mkdir(parents=True, exist_ok=True) # unique per-run directory timestamp = datetime.now().strftime("%Y%m%dT%H%M%S") self.run_id: str = f"run-{timestamp}-{uuid4().hex[:6]}" self.run_dir: Path = self.base_dir / self.run_id # uuid-based run_id should be unique; avoid overwriting by failing if exists self.run_dir.mkdir(parents=False, exist_ok=False) # --- In-memory metric state --- # Counters self.documents_processed_total: int = 0 self.documents_failed_total: int = 0 self.spans_detected_total: int = 0 self.spans_masked_total: int = 0 # Errors per component/error_type self.errors_total: Dict[str, int] = {} # Entity-type breakdowns self.detected_by_entity_type: Dict[str, int] = {} self.masked_by_entity_type: Dict[str, int] = {} # Component timings: key=f"{component_type}:{component}" -> list[seconds] self.component_durations_seconds: Dict[str, List[float]] = {} # Normalized by current document length (seconds per char) self.component_durations_per_char: Dict[str, List[float]] = {} # Per-component span counts (raw and per-char normalized totals) # key: component_key, value: {entity_type -> total_count} self.component_span_counts_raw: Dict[str, Dict[str, int]] = {} # key: component_key, value: {entity_type -> sum(count/doc_length)} self.component_span_counts_per_char: Dict[str, Dict[str, float]] = {} # Supported entity types per component (names) self.component_supported_entities: Dict[str, List[str]] = {} # Concurrency primitives self._lock: threading.Lock = threading.Lock() # Per-thread document context self._doc_ctx = threading.local() # Per-document durations to compute throughput stats self.document_durations_seconds: List[float] = [] self.document_lengths_chars: List[int] = [] # Run timing baseline self.run_start_ts: float = time.time() # --- Evaluation (confusion matrix) aggregates --- # Variant evaluations only (e.g., exact, iou@0.50, ...) # Structure: variant -> same four maps: per_component_by_type, per_component_all_types, # global_by_type, global_all_types self.eval_variants: Dict[ str, Dict[ str, Dict[str, Dict[str, int]] | Dict[str, int], ], ] = {} # Persist a minimal manifest self._write_json( self.path_in_run_dir("manifest.json"), { "run_id": self.run_id, "created_at_epoch": self.run_start_ts, }, )
[docs] def path_in_run_dir(self, *relative: str) -> Path: """Return a path within the current run directory. Example: metric_store.path_in_run_dir("component_times.jsonl") """ return self.run_dir.joinpath(*relative)
[docs] def save_config(self, config_obj: Any, filename: str = "config.json") -> Path: """Save a JSON-serialisable view of the pipeline config into the run dir.""" config_dict = dataclass_to_dict(config_obj) target_path = self.path_in_run_dir(filename) with target_path.open("w", encoding="utf-8") as f: json.dump(config_dict, f, ensure_ascii=False, indent=2) f.write("\n") return target_path
# Public recording API
[docs] def start_document(self, doc_length_chars: int) -> None: # store per-thread context setattr(self._doc_ctx, "start_ts", time.time()) setattr(self._doc_ctx, "doc_len", int(doc_length_chars))
[docs] def end_document(self) -> None: start_ts = getattr(self._doc_ctx, "start_ts", None) doc_len = getattr(self._doc_ctx, "doc_len", None) if start_ts is None: return elapsed = time.time() - float(start_ts) # clear thread-local setattr(self._doc_ctx, "start_ts", None) setattr(self._doc_ctx, "doc_len", None) with self._lock: self.document_durations_seconds.append(elapsed) self.documents_processed_total += 1 if doc_len is not None: self.document_lengths_chars.append(int(doc_len)) self._write_metrics()
[docs] def record_component_step( self, *, component: str, component_type: str, duration_seconds: float, detected_spans: Optional[Iterable[Span]] = None, masked_spans: Optional[Iterable[Span]] = None, supported_entity_types: Optional[Iterable[str]] = None, ) -> None: key = f"{component_type}:{component}" doc_len = getattr(self._doc_ctx, "doc_len", None) with self._lock: self.component_durations_seconds.setdefault(key, []).append( float(duration_seconds) ) # Normalized duration per character if current document length is known and > 0 if doc_len and int(doc_len) > 0: norm = float(duration_seconds) / float(int(doc_len)) self.component_durations_per_char.setdefault(key, []).append(norm) if supported_entity_types is not None: sup_list = [str(s) for s in supported_entity_types] self.component_supported_entities[key] = sup_list if detected_spans is not None: total_count = 0 per_type_counts: Dict[str, int] = {} for span in detected_spans: etype = ( getattr(getattr(span, "entity_type", None), "name", None) or "UNKNOWN" ) # Only count supported types if we know them for this component if ( key in self.component_supported_entities and etype not in self.component_supported_entities[key] ): continue per_type_counts[etype] = per_type_counts.get(etype, 0) + 1 self.detected_by_entity_type[etype] = ( self.detected_by_entity_type.get(etype, 0) + 1 ) total_count += 1 # Update global total self.spans_detected_total += total_count # Update per-component totals raw_map = self.component_span_counts_raw.setdefault(key, {}) norm_map = self.component_span_counts_per_char.setdefault(key, {}) for etype, cnt in per_type_counts.items(): raw_map[etype] = raw_map.get(etype, 0) + cnt if doc_len and int(doc_len) > 0: norm_map[etype] = norm_map.get(etype, 0.0) + ( float(cnt) / float(int(doc_len)) ) if masked_spans is not None: count = 0 for span in masked_spans: etype = ( getattr(getattr(span, "entity_type", None), "name", None) or "UNKNOWN" ) self.masked_by_entity_type[etype] = ( self.masked_by_entity_type.get(etype, 0) + 1 ) count += 1 self.spans_masked_total += count
[docs] def record_error( self, *, component: str, component_type: str, error_type: str ) -> None: key = f"{component_type}:{component}:{error_type}" with self._lock: self.errors_total[key] = self.errors_total.get(key, 0) + 1
# Derived stats and snapshotting
[docs] def overall_duration_seconds(self) -> float: return max(0.0, time.time() - self.run_start_ts)
[docs] def mean_docs_per_second(self) -> float: total_time = sum(self.document_durations_seconds) if total_time <= 0.0: return 0.0 return float(self.documents_processed_total) / total_time
def _compute_quantiles( self, values: List[float], points: List[float] ) -> Dict[str, float]: if not values: return {f"p{int(p)}": 0.0 for p in points} sorted_vals = sorted(values) n = len(sorted_vals) out: Dict[str, float] = {} for p in points: # Nearest-rank method rank = max(1, math.ceil((p / 100.0) * n)) out[f"p{int(p)}"] = float(sorted_vals[rank - 1]) return out def _compute_min_max(self, values: List[float]) -> Dict[str, float]: if not values: return {"min": 0.0, "max": 0.0} return {"min": float(min(values)), "max": float(max(values))} def _compute_stats_pack( self, values: List[float], percentiles: List[float] ) -> Dict[str, float]: stats = {} stats.update(self._compute_min_max(values)) stats.update(self._compute_quantiles(values, percentiles)) return stats def _write_metrics(self) -> None: # Build component span metrics ensuring only supported entities are included when known component_spans: Dict[str, Dict[str, Dict[str, float | int]]] = {} for key, raw_counts in self.component_span_counts_raw.items(): supported = self.component_supported_entities.get(key) # Filter maps based on supported entities if provided if supported is not None: raw_filtered = { e: raw_counts.get(e, 0) for e in supported if e in raw_counts } norm_src = self.component_span_counts_per_char.get(key, {}) norm_filtered = { e: norm_src.get(e, 0.0) for e in supported if e in norm_src } else: raw_filtered = dict(raw_counts) norm_filtered = dict(self.component_span_counts_per_char.get(key, {})) component_spans[key] = { "raw_totals": raw_filtered, "per_char_totals": norm_filtered, } # Helper to compute precision/recall/F1 from counts def _derive_metrics(counts: Dict[str, int]) -> Dict[str, float | int]: tp = int(counts.get("tp", 0)) fp = int(counts.get("fp", 0)) fn = int(counts.get("fn", 0)) precision = float(tp) / float(tp + fp) if (tp + fp) > 0 else 0.0 recall = float(tp) / float(tp + fn) if (tp + fn) > 0 else 0.0 f1 = ( 2.0 * precision * recall / (precision + recall) if (precision + recall) > 0.0 else 0.0 ) enriched = dict(counts) enriched.update( { "precision": precision, "recall": recall, "f1": f1, } ) return enriched # Variants enriched eval_variants_enriched: Dict[ str, Dict[ str, Dict[str, Dict[str, float | int]] | Dict[str, float | int], ], ] = {} for variant, maps in self.eval_variants.items(): per_comp_by_type_raw = maps.get("per_component_by_type", {}) per_comp_all_types_raw = maps.get("per_component_all_types", {}) global_by_type_raw = maps.get("global_by_type", {}) global_all_types_raw = maps.get( "global_all_types", {"tp": 0, "fp": 0, "fn": 0} ) per_comp_by_type_enr: Dict[str, Dict[str, Dict[str, float | int]]] = {} for comp, by_type in per_comp_by_type_raw.items(): per_comp_by_type_enr[comp] = { etype: _derive_metrics(counts) for etype, counts in by_type.items() } per_comp_all_types_enr: Dict[str, Dict[str, float | int]] = { comp: _derive_metrics(counts) for comp, counts in per_comp_all_types_raw.items() } global_by_type_enr: Dict[str, Dict[str, float | int]] = { etype: _derive_metrics(counts) for etype, counts in global_by_type_raw.items() } global_all_types_enr: Dict[str, float | int] = _derive_metrics( global_all_types_raw # type: ignore[arg-type] ) eval_variants_enriched[variant] = { "per_component_by_type": per_comp_by_type_enr, "per_component_all_types": per_comp_all_types_enr, "global_by_type": global_by_type_enr, "global_all_types": global_all_types_enr, } metrics = { # Counters "veil_pipeline_documents_processed_total": self.documents_processed_total, "veil_pipeline_documents_failed_total": self.documents_failed_total, "veil_pipeline_spans_detected_total": self.spans_detected_total, "veil_pipeline_spans_masked_total": self.spans_masked_total, # Durations "veil_pipeline_overall_duration_seconds": self.overall_duration_seconds(), # Entity-type breakdowns "veil_pipeline_detected_spans_by_type_total": self.detected_by_entity_type, "veil_pipeline_masked_spans_by_type_total": self.masked_by_entity_type, # Throughput "veil_pipeline_mean_docs_per_second": self.mean_docs_per_second(), # Component duration aggregates: raw seconds and per-char seconds "veil_pipeline_component_duration_seconds": { key: { "raw_seconds": self._compute_stats_pack(values, [50, 95, 99]), "per_char_seconds": self._compute_stats_pack( self.component_durations_per_char.get(key, []), [50, 95, 99] ), } for key, values in self.component_durations_seconds.items() }, # Detected spans by component (entity detector): raw and per-char normalized totals "veil_pipeline_component_spans": component_spans, # Document length statistics (chars) "veil_pipeline_document_length_chars": self._compute_stats_pack( [float(x) for x in self.document_lengths_chars], [50, 90, 99] ), # Timestamping "run_id": self.run_id, "ts_epoch": time.time(), # Evaluation variants (only) "veil_eval_variants": eval_variants_enriched, } target = self.path_in_run_dir("metrics.json") with target.open("w", encoding="utf-8") as f: json.dump(metrics, f, ensure_ascii=False, indent=2) f.write("\n") # Plotting / finalization
[docs] def finalize(self) -> None: """Write final metrics and plots to the run directory.""" logger.info("Finalizing metrics and plots...") # Always emit the latest metrics snapshot with self._lock: self._write_metrics() # Then generate plots from in-memory buffers self._generate_plots() # And finally, generate a compact evaluation summary document from metrics.json try: self._generate_evaluation_summary_docs() except Exception: logger.exception("Failed to write evaluation summary docs.")
def _generate_plots(self) -> None: # Lazy-import plotting stack and use a non-interactive backend try: import matplotlib matplotlib.use("Agg", force=True) import matplotlib.pyplot as plt except Exception: # If plotting stack is unavailable, skip silently return plots_dir = self.path_in_run_dir("plots") plots_dir.mkdir(parents=True, exist_ok=True) # Subfolders for organization durations_dir = plots_dir / "durations" durations_components_dir = durations_dir / "components" entity_spans_dir = plots_dir / "entity_spans" entity_spans_components_dir = entity_spans_dir / "components" evaluation_dir = plots_dir / "evaluation" evaluation_aggregated_dir = evaluation_dir / "aggregated" evaluation_components_dir = evaluation_dir / "components" durations_dir.mkdir(parents=True, exist_ok=True) durations_components_dir.mkdir(parents=True, exist_ok=True) entity_spans_dir.mkdir(parents=True, exist_ok=True) entity_spans_components_dir.mkdir(parents=True, exist_ok=True) evaluation_dir.mkdir(parents=True, exist_ok=True) evaluation_aggregated_dir.mkdir(parents=True, exist_ok=True) evaluation_components_dir.mkdir(parents=True, exist_ok=True) # Organized subfolders durations_dir = plots_dir / "durations" durations_components_dir = durations_dir / "components" entity_spans_dir = plots_dir / "entity_spans" entity_spans_components_dir = entity_spans_dir / "components" evaluation_dir = plots_dir / "evaluation" evaluation_aggregated_dir = evaluation_dir / "aggregated" evaluation_components_dir = evaluation_dir / "components" durations_components_dir.mkdir(parents=True, exist_ok=True) entity_spans_components_dir.mkdir(parents=True, exist_ok=True) evaluation_aggregated_dir.mkdir(parents=True, exist_ok=True) evaluation_components_dir.mkdir(parents=True, exist_ok=True) def _sanitize(name: str) -> str: return "".join(c if c.isalnum() or c in ("_", "-") else "_" for c in name) def _plot_cdf( values: List[float], title: str, xlabel: str, outfile: Path ) -> None: if not values: return xs = sorted(values) n = len(xs) ys = [i / n for i in range(1, n + 1)] plt.figure(figsize=(6, 4)) plt.plot(xs, ys, drawstyle="steps-post") plt.xlabel(xlabel) plt.ylabel("CDF") plt.title(title) plt.grid(True, alpha=0.3) plt.tight_layout() plt.savefig(outfile) plt.close() def _plot_bars( mapping: Dict[str, float], title: str, xlabel: str, outfile: Path ) -> None: if not mapping: return labels = sorted(mapping.keys()) values = [mapping[k] for k in labels] plt.figure(figsize=(max(6, len(labels) * 0.6), 4)) plt.bar(labels, values) plt.xlabel(xlabel) plt.ylabel("value") plt.title(title) plt.xticks(rotation=45, ha="right") plt.grid(True, axis="y", alpha=0.3) plt.tight_layout() plt.savefig(outfile) plt.close() # 1) Document-level CDFs _plot_cdf( self.document_durations_seconds, title="Document processing time CDF", xlabel="seconds", outfile=durations_dir / "doc_duration_cdf.png", ) _plot_cdf( [float(x) for x in self.document_lengths_chars], title="Document length CDF", xlabel="characters", outfile=durations_dir / "doc_length_cdf.png", ) # 2) Component duration CDFs (per component and aggregated) # Aggregated raw durations (overlay per component) if self.component_durations_seconds: try: import matplotlib.pyplot as plt # type: ignore[no-redef] plt.figure(figsize=(7, 5)) for key, values in self.component_durations_seconds.items(): if not values: continue xs = sorted(values) n = len(xs) ys = [i / n for i in range(1, n + 1)] plt.plot(xs, ys, drawstyle="steps-post", label=key) plt.xlabel("seconds") plt.ylabel("CDF") plt.title("Component duration CDFs (raw seconds)") plt.grid(True, alpha=0.3) plt.legend(fontsize="small") plt.tight_layout() plt.savefig(durations_dir / "components_duration_cdf.png") plt.close() except Exception: pass # Aggregated per-char durations (overlay per component) if self.component_durations_per_char: try: import matplotlib.pyplot as plt # type: ignore[no-redef] plt.figure(figsize=(7, 5)) for key, values in self.component_durations_per_char.items(): if not values: continue xs = sorted(values) n = len(xs) ys = [i / n for i in range(1, n + 1)] plt.plot(xs, ys, drawstyle="steps-post", label=key) plt.xlabel("seconds per char") plt.ylabel("CDF") plt.title("Component duration CDFs (normalized per char)") plt.grid(True, alpha=0.3) plt.legend(fontsize="small") plt.tight_layout() plt.savefig(durations_dir / "components_per_char_duration_cdf.png") plt.close() except Exception: pass # Per-component individual CDFs for key, values in self.component_durations_seconds.items(): _plot_cdf( values, title=f"{key} duration CDF", xlabel="seconds", outfile=durations_components_dir / f"comp_{_sanitize(key)}_duration_cdf.png", ) for key, values in self.component_durations_per_char.items(): _plot_cdf( values, title=f"{key} duration CDF (per char)", xlabel="seconds per char", outfile=durations_components_dir / f"comp_{_sanitize(key)}_per_char_duration_cdf.png", ) # 2b) Horizontal stacked bars: end-to-end by component for different statistics if self.component_durations_seconds: try: import matplotlib.pyplot as plt # type: ignore[no-redef] # Determine component order by numeric prefix in component id if present def _order_key(comp_key: str) -> int: try: # comp_key format: "type:NN-Name" -> extract NN after_colon = comp_key.split(":", 1)[1] idx_str = after_colon.split("-", 1)[0] return int(idx_str) except Exception: return 10_000 ordered_components = sorted( self.component_durations_seconds.keys(), key=_order_key ) # Compute stats per component comp_stats: Dict[str, Dict[str, float]] = {} for comp in ordered_components: vals = self.component_durations_seconds.get(comp, []) if not vals: comp_stats[comp] = { "min": 0.0, "p50": 0.0, "p95": 0.0, "p99": 0.0, "max": 0.0, "mean": 0.0, } continue stats_pack = self._compute_stats_pack(vals, [50, 95, 99]) comp_stats[comp] = { "min": stats_pack.get("min", 0.0), "p50": stats_pack.get("p50", 0.0), "p95": stats_pack.get("p95", 0.0), "p99": stats_pack.get("p99", 0.0), "max": stats_pack.get("max", 0.0), "mean": float(sum(vals)) / float(len(vals)), } # Create stacked bars for each requested statistic stat_names = ["min", "p50", "p95", "p99", "max"] y_positions = list(range(len(stat_names))) plt.figure(figsize=(8, max(4, len(stat_names) * 0.8))) left = [0.0 for _ in stat_names] for comp in ordered_components: widths = [comp_stats[comp][s] for s in stat_names] plt.barh(y_positions, widths, left=left, label=comp) left = [l + w for l, w in zip(left, widths)] plt.xlabel("seconds") plt.yticks(y_positions, stat_names) plt.title("Pipeline stacked component durations (per statistic)") plt.grid(True, axis="x", alpha=0.3) plt.legend( fontsize="small", loc="upper center", bbox_to_anchor=(0.5, -0.1), ncol=2, ) plt.tight_layout() plt.savefig(durations_dir / "pipeline_stacked_component_durations.png") plt.close() except Exception: pass # 3) Entity totals bars – global and per component _plot_bars( {k: float(v) for k, v in self.detected_by_entity_type.items()}, title="Detected spans by entity type (global)", xlabel="entity type", outfile=entity_spans_dir / "detected_spans_by_type.png", ) _plot_bars( {k: float(v) for k, v in self.masked_by_entity_type.items()}, title="Masked spans by entity type (global)", xlabel="entity type", outfile=entity_spans_dir / "masked_spans_by_type.png", ) for key, raw_counts in self.component_span_counts_raw.items(): _plot_bars( {k: float(v) for k, v in raw_counts.items()}, title=f"{key} detected spans (raw totals)", xlabel="entity type", outfile=entity_spans_components_dir / f"comp_{_sanitize(key)}_span_totals.png", ) for key, norm_counts in self.component_span_counts_per_char.items(): _plot_bars( norm_counts, title=f"{key} detected spans (per char totals)", xlabel="entity type", outfile=entity_spans_components_dir / f"comp_{_sanitize(key)}_span_per_char_totals.png", ) # 4) Quality (evaluation) plots def _derive(counts: Dict[str, int]) -> Dict[str, float]: tp = int(counts.get("tp", 0)) fp = int(counts.get("fp", 0)) fn = int(counts.get("fn", 0)) precision = float(tp) / float(tp + fp) if (tp + fp) > 0 else 0.0 recall = float(tp) / float(tp + fn) if (tp + fn) > 0 else 0.0 f1 = ( 2.0 * precision * recall / (precision + recall) if (precision + recall) > 0.0 else 0.0 ) return {"precision": precision, "recall": recall, "f1": f1} # Helper: grouped bars (series overlay per group) def _plot_grouped_bars( group_labels: List[str], series_to_values: Dict[str, List[float]], title: str, xlabel: str, ylabel: str, outfile: Path, ) -> None: if not group_labels or not series_to_values: return try: import matplotlib.pyplot as plt # type: ignore[no-redef] num_groups = len(group_labels) series_names = list(series_to_values.keys()) num_series = len(series_names) x = list(range(num_groups)) total_group_width = 0.8 bar_width = total_group_width / max(1, num_series) offsets = [ (-total_group_width / 2) + (i + 0.5) * bar_width for i in range(num_series) ] plt.figure(figsize=(max(6, len(group_labels) * 0.6), 4)) for idx, sname in enumerate(series_names): vals = series_to_values.get(sname, []) if len(vals) != num_groups: # pad or trim to fit vals = (vals + [0.0] * num_groups)[:num_groups] xs = [xi + offsets[idx] for xi in x] plt.bar(xs, vals, width=bar_width, label=sname) plt.xlabel(xlabel) plt.ylabel(ylabel) plt.title(title) plt.xticks(x, group_labels, rotation=45, ha="right") plt.grid(True, axis="y", alpha=0.3) plt.legend(fontsize="small") plt.tight_layout() plt.savefig(outfile) plt.close() except Exception: pass # Order variants: IoU thresholds ascending (by-id after type-only), then exact_by_id, then exact last def _variant_sort_key(vname: str) -> Tuple[int, float, int]: name = vname.lower() try: if name.startswith("iou@"): suffix = name.split("@", 1)[1] if "_by_id" in suffix: thr = float(suffix.split("_", 1)[0]) return (0, thr, 1) thr = float(suffix) return (0, thr, 0) except Exception: pass if name == "exact_by_id": return (1, 1e8, 1) if name == "exact": return (1, 1e9, 0) return (2, 0.0, 0) variant_names = sorted(list(self.eval_variants.keys()), key=_variant_sort_key) # Split into type-only vs by-id groups type_only_variants = [ v for v in variant_names if not v.lower().endswith("_by_id") and v != "exact_by_id" ] by_id_variants = [ v for v in variant_names if v.lower().endswith("_by_id") or v == "exact_by_id" ] # Global by type, combined across variants # Build union of entity types all_types_set: set[str] = set() for v in variant_names: by_type_map = self.eval_variants.get(v, {}).get("global_by_type", {}) all_types_set.update(list(by_type_map.keys())) type_labels = sorted(list(all_types_set)) # For each metric, build series per variant for metric_name in ("f1", "precision", "recall"): # Type-only plot series: Dict[str, List[float]] = {} for v in type_only_variants: by_type_map = self.eval_variants.get(v, {}).get("global_by_type", {}) vals: List[float] = [] for et in type_labels: counts = by_type_map.get(et) if not counts: vals.append(0.0) else: derived = _derive(counts) vals.append(float(derived.get(metric_name, 0.0))) series[v] = vals _plot_grouped_bars( type_labels, series, title=f"Global {metric_name.upper()} by entity type (type-only)", xlabel="entity type", ylabel=metric_name, outfile=evaluation_aggregated_dir / f"global_{metric_name}_by_type_type_only.png", ) # By-ID plot if by_id_variants: series_id: Dict[str, List[float]] = {} for v in by_id_variants: by_type_map = self.eval_variants.get(v, {}).get( "global_by_type", {} ) vals: List[float] = [] for et in type_labels: counts = by_type_map.get(et) if not counts: vals.append(0.0) else: derived = _derive(counts) vals.append(float(derived.get(metric_name, 0.0))) series_id[v] = vals _plot_grouped_bars( type_labels, series_id, title=f"Global {metric_name.upper()} by entity type (by-id)", xlabel="entity type", ylabel=metric_name, outfile=evaluation_aggregated_dir / f"global_{metric_name}_by_type_by_id.png", ) # Global all types: separate plots for type-only and by-id metrics = ("precision", "recall", "f1") group_labels = list(metrics) # Type-only series_values: Dict[str, List[float]] = {v: [] for v in type_only_variants} for v in type_only_variants: global_all = self.eval_variants.get(v, {}).get( "global_all_types", {"tp": 0, "fp": 0, "fn": 0} ) derived = _derive(global_all) # type: ignore[arg-type] for m in metrics: series_values[v].append(float(derived.get(m, 0.0))) _plot_grouped_bars( group_labels, series_values, title="Global quality metrics (type-only)", xlabel="metric", ylabel="value", outfile=evaluation_aggregated_dir / "global_quality_all_types_type_only.png", ) # By-ID if by_id_variants: series_values_id: Dict[str, List[float]] = {v: [] for v in by_id_variants} for v in by_id_variants: global_all = self.eval_variants.get(v, {}).get( "global_all_types", {"tp": 0, "fp": 0, "fn": 0} ) derived = _derive(global_all) # type: ignore[arg-type] for m in metrics: series_values_id[v].append(float(derived.get(m, 0.0))) _plot_grouped_bars( group_labels, series_values_id, title="Global quality metrics (by-id)", xlabel="metric", ylabel="value", outfile=evaluation_aggregated_dir / "global_quality_all_types_by_id.png", ) # Per component plots, combined across variants # 1) By type: F1, precision, recall comp_keys: set[str] = set() for v in variant_names: comp_keys.update( self.eval_variants.get(v, {}).get("per_component_by_type", {}).keys() ) for comp in sorted(comp_keys): # union of types for this component type_labels_comp: set[str] = set() for v in variant_names: by_type = ( self.eval_variants.get(v, {}) .get("per_component_by_type", {}) .get(comp, {}) ) type_labels_comp.update(list(by_type.keys())) type_labels_list = sorted(list(type_labels_comp)) for metric_name in ("f1", "precision", "recall"): # Type-only series: Dict[str, List[float]] = {} for v in type_only_variants: by_type = ( self.eval_variants.get(v, {}) .get("per_component_by_type", {}) .get(comp, {}) ) vals: List[float] = [] for et in type_labels_list: counts = by_type.get(et) if not counts: vals.append(0.0) else: derived = _derive(counts) vals.append(float(derived.get(metric_name, 0.0))) series[v] = vals _plot_grouped_bars( type_labels_list, series, title=f"{comp} {metric_name.upper()} by entity type (type-only)", xlabel="entity type", ylabel=metric_name, outfile=evaluation_components_dir / f"comp_{_sanitize(comp)}_{metric_name}_by_type_type_only.png", ) # By-ID if by_id_variants: series_id: Dict[str, List[float]] = {} for v in by_id_variants: by_type = ( self.eval_variants.get(v, {}) .get("per_component_by_type", {}) .get(comp, {}) ) vals: List[float] = [] for et in type_labels_list: counts = by_type.get(et) if not counts: vals.append(0.0) else: derived = _derive(counts) vals.append(float(derived.get(metric_name, 0.0))) series_id[v] = vals _plot_grouped_bars( type_labels_list, series_id, title=f"{comp} {metric_name.upper()} by entity type (by-id)", xlabel="entity type", ylabel=metric_name, outfile=evaluation_components_dir / f"comp_{_sanitize(comp)}_{metric_name}_by_type_by_id.png", ) # 2) Per component all types (group by metric, series = variants with exact last) comp_keys_all: set[str] = set() for v in variant_names: comp_keys_all.update( self.eval_variants.get(v, {}).get("per_component_all_types", {}).keys() ) for comp in sorted(comp_keys_all): # Type-only series_values: Dict[str, List[float]] = { vn: [] for vn in type_only_variants } for v in type_only_variants: counts = ( self.eval_variants.get(v, {}) .get("per_component_all_types", {}) .get(comp, {"tp": 0, "fp": 0, "fn": 0}) ) derived = _derive(counts) for m in metrics: # Will be appended per variant series for each metric group later series_values[v].append(float(derived.get(m, 0.0))) # Transpose series_values to have groups=metrics, series=variants # series_values currently: variant -> [prec, rec, f1] # _plot_grouped_bars expects: series -> list aligned with group_labels _plot_grouped_bars( list(metrics), series_values, title=f"{comp} quality metrics (type-only)", xlabel="metric", ylabel="value", outfile=evaluation_components_dir / f"comp_{_sanitize(comp)}_quality_all_types_type_only.png", ) # By-ID if by_id_variants: series_values_id: Dict[str, List[float]] = { vn: [] for vn in by_id_variants } for v in by_id_variants: counts = ( self.eval_variants.get(v, {}) .get("per_component_all_types", {}) .get(comp, {"tp": 0, "fp": 0, "fn": 0}) ) derived = _derive(counts) for m in metrics: # Will be appended per variant series for each metric group later series_values_id[v].append(float(derived.get(m, 0.0))) _plot_grouped_bars( list(metrics), series_values_id, title=f"{comp} quality metrics (by-id)", xlabel="metric", ylabel="value", outfile=evaluation_components_dir / f"comp_{_sanitize(comp)}_quality_all_types_by_id.png", ) def _generate_evaluation_summary_docs(self) -> None: """Create small Markdown and HTML summaries of global evaluation metrics. Reads the persisted metrics.json to ensure summaries reflect saved output. Produces: - evaluation_summary.md - evaluation_summary.html in the run directory. """ metrics_path = self.path_in_run_dir("metrics.json") if not metrics_path.exists(): return try: with metrics_path.open("r", encoding="utf-8") as f: payload = json.load(f) except Exception: return run_id = str(payload.get("run_id", "")) variants = payload.get("veil_eval_variants") or {} if not isinstance(variants, dict) or not variants: return def _vsort_key(vname: str) -> Tuple[int, float, int]: name = str(vname).lower() try: if name.startswith("iou@"): suffix = name.split("@", 1)[1] if "_by_id" in suffix: thr = float(suffix.split("_", 1)[0]) return (0, thr, 1) thr = float(suffix) return (0, thr, 0) except Exception: pass if name == "exact_by_id": return (1, 1e8, 1) if name == "exact": return (1, 1e9, 0) return (2, 0.0, 0) ordered_variants = sorted(list(variants.keys()), key=_vsort_key) type_only_variants = [ v for v in ordered_variants if not str(v).lower().endswith("_by_id") and str(v).lower() != "exact_by_id" ] by_id_variants = [ v for v in ordered_variants if str(v).lower().endswith("_by_id") or str(v).lower() == "exact_by_id" ] # Collect union of entity types present across variants all_types: set[str] = set() for v in ordered_variants: by_type = (variants.get(v) or {}).get("global_by_type") or {} if isinstance(by_type, dict): all_types.update([str(k) for k in by_type.keys()]) type_labels = sorted(all_types) def _num(x: Any, dflt: float = 0.0) -> float: try: return float(x) except Exception: return dflt # Build Markdown md_lines: List[str] = [] md_lines.append(f"# Evaluation summary\n") if run_id: md_lines.append(f"Run: `{run_id}`\n") md_lines.append("\n") # Global F1 by entity type (type-only) md_lines.append("## Global F1 by entity type (type-only)\n") if type_only_variants and type_labels: headers = ["Entity type"] + [str(v) for v in type_only_variants] md_lines.append("| " + " | ".join(headers) + " |") md_lines.append("| " + " | ".join(["---"] * len(headers)) + " |") for et in type_labels: row: List[str] = [et] for v in type_only_variants: f1 = _num( ((variants.get(v) or {}).get("global_by_type") or {}) .get(et, {}) .get("f1", 0.0) ) row.append(f"{f1:.3f}") md_lines.append("| " + " | ".join(row) + " |") else: md_lines.append("_No type-only variants available._\n") md_lines.append("\n") # Global F1 by entity type (by-id) md_lines.append("## Global F1 by entity type (by-id)\n") if by_id_variants and type_labels: headers = ["Entity type"] + [str(v) for v in by_id_variants] md_lines.append("| " + " | ".join(headers) + " |") md_lines.append("| " + " | ".join(["---"] * len(headers)) + " |") for et in type_labels: row = [et] for v in by_id_variants: f1 = _num( ((variants.get(v) or {}).get("global_by_type") or {}) .get(et, {}) .get("f1", 0.0) ) row.append(f"{f1:.3f}") md_lines.append("| " + " | ".join(row) + " |") else: md_lines.append("_No by-id variants available._\n") md_lines.append("\n") # Global quality metrics (all types) - type-only md_lines.append("## Global quality metrics (all types, type-only)\n") if type_only_variants: md_lines.append("| Variant | TP | FP | FN | P | R | F1 |") md_lines.append("| --- | ---:| ---:| ---:| ---:| ---:| ---:|") for v in type_only_variants: counts = (variants.get(v) or {}).get("global_all_types") or {} tp = int(counts.get("tp", 0)) fp = int(counts.get("fp", 0)) fn = int(counts.get("fn", 0)) p = _num(counts.get("precision", 0.0)) r = _num(counts.get("recall", 0.0)) f1 = _num(counts.get("f1", 0.0)) md_lines.append( f"| {v} | {tp} | {fp} | {fn} | {p:.3f} | {r:.3f} | {f1:.3f} |" ) else: md_lines.append("_No type-only variants available._\n") md_lines.append("\n") # Global quality metrics (all types) - by-id md_lines.append("## Global quality metrics (all types, by-id)\n") if by_id_variants: md_lines.append("| Variant | TP | FP | FN | P | R | F1 |") md_lines.append("| --- | ---:| ---:| ---:| ---:| ---:| ---:|") for v in by_id_variants: counts = (variants.get(v) or {}).get("global_all_types") or {} tp = int(counts.get("tp", 0)) fp = int(counts.get("fp", 0)) fn = int(counts.get("fn", 0)) p = _num(counts.get("precision", 0.0)) r = _num(counts.get("recall", 0.0)) f1 = _num(counts.get("f1", 0.0)) md_lines.append( f"| {v} | {tp} | {fp} | {fn} | {p:.3f} | {r:.3f} | {f1:.3f} |" ) else: md_lines.append("_No by-id variants available._\n") md_lines.append("\n") md_text = "\n".join(md_lines) + "\n" md_path = self.path_in_run_dir("evaluation_summary.md") with md_path.open("w", encoding="utf-8") as f: f.write(md_text) # Basic HTML rendition def _html_escape(s: str) -> str: try: import html as _html return _html.escape(s) except Exception: return s def _md_table_to_html(md_table_lines: List[str]) -> str: rows: List[str] = [] for idx, line in enumerate(md_table_lines): if not line.startswith("| ") or " |" not in line: continue parts = [p.strip() for p in line.strip().strip("|").split("|")] if idx == 0: # header ths = "".join([f"<th>{_html_escape(p)}</th>" for p in parts]) rows.append(f"<thead><tr>{ths}</tr></thead>") elif idx == 1: # separator row, skip continue else: tds = "".join([f"<td>{_html_escape(p)}</td>" for p in parts]) rows.append(f"<tr>{tds}</tr>") body_rows = [r for r in rows if not r.startswith("<thead>")] thead = next((r for r in rows if r.startswith("<thead>")), "") return f"<table>{thead}<tbody>{''.join(body_rows)}</tbody></table>" # Extract individual table blocks from markdown sections def _extract_section(title: str) -> List[str]: out: List[str] = [] started = False for ln in md_lines: if ln.strip() == title: started = True continue if started and ln.startswith("## "): break if started: out.append(ln) return [l for l in out if l.strip()] html_parts: List[str] = [] html_parts.append( "<html><head><meta charset='utf-8'><style>body{font-family:-apple-system,Segoe UI,Roboto,Helvetica,Arial,sans-serif;margin:16px}h1,h2{margin:0 0 8px}table{border-collapse:collapse;margin:8px 0}th,td{border:1px solid #ddd;padding:6px 8px;font-size:13px}th{background:#f7f7f7;text-align:left}</style></head><body>" ) html_parts.append(f"<h1>Evaluation summary</h1>") if run_id: html_parts.append( f"<div><strong>Run:</strong> <code>{_html_escape(run_id)}</code></div>" ) def _section_to_html(sec_title_md: str, heading_html: str) -> None: tbl = _extract_section(sec_title_md) if not any(l.startswith("| ") for l in tbl): html_parts.append( f"<h2>{heading_html}</h2><div><em>No data.</em></div>" ) return html_parts.append(f"<h2>{heading_html}</h2>") html_parts.append(_md_table_to_html([l for l in tbl if l.startswith("| ")])) _section_to_html( "## Global F1 by entity type (type-only)", "Global F1 by entity type (type-only)", ) _section_to_html( "## Global F1 by entity type (by-id)", "Global F1 by entity type (by-id)" ) _section_to_html( "## Global quality metrics (all types, type-only)", "Global quality metrics (all types, type-only)", ) _section_to_html( "## Global quality metrics (all types, by-id)", "Global quality metrics (all types, by-id)", ) html_parts.append("</body></html>") html_text = "".join(html_parts) html_path = self.path_in_run_dir("evaluation_summary.html") with html_path.open("w", encoding="utf-8") as f: f.write(html_text) # Utils def _write_json(self, path: Path, payload: Any) -> None: with path.open("w", encoding="utf-8") as f: json.dump(payload, f, ensure_ascii=False, indent=2) f.write("\n") # ----------------------------- # Evaluation recording API # ----------------------------- def _merge_confusion(self, a: Dict[str, int], b: Dict[str, int]) -> Dict[str, int]: return { "tp": int(a.get("tp", 0)) + int(b.get("tp", 0)), "fp": int(a.get("fp", 0)) + int(b.get("fp", 0)), "fn": int(a.get("fn", 0)) + int(b.get("fn", 0)), }
[docs] def record_evaluation_variant( self, *, variant: str, per_component_by_type: Dict[str, Dict[str, Dict[str, int]]], per_component_all_types: Dict[str, Dict[str, int]], global_by_type: Dict[str, Dict[str, int]], global_all_types: Dict[str, int], ) -> None: """Record evaluation counts for a named variant (e.g., 'exact', 'iou@0.50').""" with self._lock: v = self.eval_variants.setdefault( str(variant), { "per_component_by_type": {}, "per_component_all_types": {}, "global_by_type": {}, "global_all_types": {"tp": 0, "fp": 0, "fn": 0}, }, ) # Per-component by-type v_pcbt = v["per_component_by_type"] # type: ignore[index] for comp, mapping in per_component_by_type.items(): comp_map = v_pcbt.setdefault(comp, {}) # type: ignore[assignment] for etype, counts in mapping.items(): existing = comp_map.get(etype, {"tp": 0, "fp": 0, "fn": 0}) comp_map[etype] = self._merge_confusion(existing, counts) # Per-component all-types v_pcat = v["per_component_all_types"] # type: ignore[index] for comp, counts in per_component_all_types.items(): existing = v_pcat.get(comp, {"tp": 0, "fp": 0, "fn": 0}) v_pcat[comp] = self._merge_confusion(existing, counts) # Global by-type v_gbt = v["global_by_type"] # type: ignore[index] for etype, counts in global_by_type.items(): existing = v_gbt.get(etype, {"tp": 0, "fp": 0, "fn": 0}) v_gbt[etype] = self._merge_confusion(existing, counts) # Global all-types v_gat = v["global_all_types"] # type: ignore[index] v["global_all_types"] = self._merge_confusion(v_gat, global_all_types) # type: ignore[index]