Source code for veil.entity_detectors.gliner.gliner_entity_detector

from __future__ import annotations

from typing import List, Optional, Tuple

from veil.config.entity_detectors import GlinerEntityDetectorConfig
from veil.core.base_entity_detector import BaseEntityDetector
from veil.core.document import Document
from veil.core.span import Span
from veil.logger import init_logger

from .gliner_entity_type import GlinerEntityType

# Module-level logger obtained from the project's init_logger helper,
# keyed by this module's name.
logger = init_logger(__name__)


class GlinerEntityDetector(BaseEntityDetector[GlinerEntityType]):
    """Adapter over Gliner entity detection models with support for chunking long documents.

    Documents whose token count exceeds ``config.max_length`` are split into
    overlapping token windows, each window is run through the model
    separately, and detections that were found more than once (because they
    sit in an overlap region) are de-duplicated before being returned.
    """

    # Every member of GlinerEntityType.
    ENTITY_TYPES = set(GlinerEntityType)

    def __init__(self, config: GlinerEntityDetectorConfig):
        """Load the GLiNER model and the tokenizer used for chunking.

        Args:
            config: Detector configuration. This method reads ``model`` and
                ``cuda_device``; other methods read ``labels``, ``threshold``,
                ``max_length``, ``chunk_overlap``, ``nms_iou_threshold``,
                ``min_span_chars``, ``max_span_chars`` and ``top_k_per_chunk``.

        Raises:
            RuntimeError: If the ``gliner`` package cannot be imported.
        """
        super().__init__(config)

        # Imported lazily so the package is only required when this detector
        # is actually instantiated.
        try:
            from gliner import GLiNER
        except ImportError as e:
            raise RuntimeError(
                "GLiNER library not available. Please install it using: pip install gliner"
            ) from e

        # A negative cuda_device selects the CPU. A non-negative value is
        # treated as the CUDA device index so that e.g. cuda_device=1 really
        # lands on the second GPU instead of the default one.
        if config.cuda_device >= 0:
            device = f"cuda:{config.cuda_device}"
        else:
            device = "cpu"

        logger.info(f"Loading Gliner model from {config.model} on {device}")
        self.model = GLiNER.from_pretrained(config.model, map_location=device)

        # Tokenizer of the underlying transformer; used to measure and split
        # long documents in _create_chunks.
        self.tokenizer = self.model.data_processor.transformer_tokenizer

        # Alias -> canonical enum member name, used to convert model labels.
        self._alias_map = GlinerEntityType._build_alias_map_for_subclass()

    def _map_label_to_entity(self, label: str) -> Optional[GlinerEntityType]:
        """Map a GLiNER output label to a GlinerEntityType using the alias system.

        Args:
            label: Label string as emitted by the model. It is upper-cased
                before being looked up in the alias map.

        Returns:
            The matching GlinerEntityType, or None when the label has no
            alias entry or the alias does not name an enum member.
        """
        canonical_name = self._alias_map.get(label.upper())
        if not canonical_name:
            return None
        try:
            return GlinerEntityType[canonical_name]
        except KeyError:
            return None

    def _create_chunks(self, text: str) -> List[Tuple[str, int, int]]:
        """Create overlapping chunks of text with character offsets.

        Args:
            text: Input text to chunk

        Returns:
            List of tuples (chunk_text, start_char, end_char) where start_char
            and end_char are the character positions in the original text

        Raises:
            ValueError: If the text needs chunking but ``config.max_length``
                is smaller than 1.
        """
        # Tokenize the full text; only the per-token character offsets are needed.
        encoding = self.tokenizer(
            text,
            return_offsets_mapping=True,
            add_special_tokens=False,
            truncation=False,
        )
        offsets = encoding["offset_mapping"]
        num_tokens = len(offsets)  # one (start, end) pair per token

        chunk_size = self.config.max_length

        # If text fits within max_length, return it as a single chunk
        if num_tokens <= chunk_size:
            return [(text, 0, len(text))]

        if chunk_size < 1:
            raise ValueError(
                f"max_length must be at least 1 to chunk text, got {chunk_size}"
            )

        # Clamp the overlap into [0, chunk_size - 1]. An overlap >= chunk_size
        # would stop the window from advancing (endless loop); a negative
        # overlap would silently skip tokens between consecutive chunks.
        overlap = min(max(self.config.chunk_overlap, 0), chunk_size - 1)
        step = chunk_size - overlap

        chunks = []
        first_token = 0
        while True:
            last_token = min(first_token + chunk_size, num_tokens)

            # Character span covered by tokens [first_token, last_token)
            start_char = offsets[first_token][0]
            end_char = offsets[last_token - 1][1]
            chunks.append((text[start_char:end_char], start_char, end_char))

            if last_token >= num_tokens:
                break
            first_token += step

        return chunks

    def _merge_overlapping_entities(self, spans: List[Span]) -> List[Span]:
        """Deduplicate entities that appear in overlapping chunks.

        Performs greedy non-maximum suppression per entity type: spans are
        visited from highest to lowest confidence, and a span is kept only if
        its IoU with every already-kept span of the same type does not exceed
        ``config.nms_iou_threshold``. Spans of different types never suppress
        each other.

        Args:
            spans: Candidate spans, typically gathered from several chunks.

        Returns:
            The surviving spans, sorted by (start, end).
        """
        if not spans:
            return []

        iou_threshold = self.config.nms_iou_threshold

        # sorted() is stable, so spans with equal confidence keep input order.
        by_confidence = sorted(
            spans, key=lambda s: s.confidence or 0.0, reverse=True
        )

        kept_by_type = {}  # entity type -> spans already accepted for it
        for candidate in by_confidence:
            accepted = kept_by_type.setdefault(candidate.entity_type, [])
            is_duplicate = any(
                self._calculate_iou(candidate, winner) > iou_threshold
                for winner in accepted
            )
            if not is_duplicate:
                accepted.append(candidate)

        survivors = [
            span for accepted in kept_by_type.values() for span in accepted
        ]
        survivors.sort(key=lambda s: (s.start, s.end))
        return survivors

    def _calculate_iou(self, span1: Span, span2: Span) -> float:
        """Calculate Intersection over Union for two spans.

        Args:
            span1: First span.
            span2: Second span.

        Returns:
            Overlap length divided by the length of the union of both spans,
            or 0.0 when the union is empty.
        """
        overlap_start = max(span1.start, span2.start)
        overlap_end = min(span1.end, span2.end)
        intersection = max(0, overlap_end - overlap_start)

        union = (span1.end - span1.start) + (span2.end - span2.start) - intersection
        if union == 0:
            return 0.0
        return intersection / union

    def _entities_to_spans(self, entities, offset: int) -> List[Span]:
        """Convert raw GLiNER predictions for one chunk into document-level spans.

        Args:
            entities: Prediction dicts for one chunk. Each is read via the
                keys "label", "start", "end", "text" and (optionally) "score".
            offset: Character position of the chunk within the full document;
                added to every start/end so spans index the original text.

        Returns:
            Spans whose label maps to a known entity type and whose length is
            positive and within [min_span_chars, max_span_chars].
        """
        min_chars = self.config.min_span_chars
        max_chars = self.config.max_span_chars

        spans: List[Span] = []
        for entity in entities:
            entity_type = self._map_label_to_entity(entity["label"])
            if entity_type is None:
                # Label is not one of our recognized types.
                continue

            start = entity["start"] + offset
            end = entity["end"] + offset
            span_len = end - start
            if span_len <= 0 or not (min_chars <= span_len <= max_chars):
                continue

            spans.append(
                Span(
                    start=start,
                    end=end,
                    entity_type=entity_type,
                    replacement=entity["text"],
                    confidence=entity.get("score"),
                )
            )
        return spans

    def detect_entities(self, doc: Document) -> List[Span]:
        """Detect entities in the document using GLiNER model with chunking support.

        Long documents are automatically split into overlapping chunks to
        handle the model's maximum sequence length limitation.

        Args:
            doc: Document to process

        Returns:
            List of Span objects with detected entities
        """
        text = doc.text or ""
        if not text:
            return []

        labels = self.config.labels
        threshold = self.config.threshold
        top_k = self.config.top_k_per_chunk

        all_spans: List[Span] = []
        for chunk_text, chunk_start, _ in self._create_chunks(text):
            # Run GLiNER inference on this chunk
            entities = self.model.predict_entities(
                chunk_text,
                labels,
                threshold=threshold,
            )
            chunk_spans = self._entities_to_spans(entities, chunk_start)

            # Keep top-K per chunk to limit over-detection
            if top_k and len(chunk_spans) > top_k:
                chunk_spans = sorted(
                    chunk_spans, key=lambda s: s.confidence or 0.0, reverse=True
                )[:top_k]

            all_spans.extend(chunk_spans)

        # Merge overlapping entities from different chunks
        return self._merge_overlapping_entities(all_spans)