"""Embedding service: lazy-loading sentence-transformers wrapper."""

import logging
import os
from typing import Dict, List, Optional, Tuple

import numpy as np
import onnxruntime as ort
from optimum.onnxruntime import ORTModelForFeatureExtraction
from transformers import AutoTokenizer, PreTrainedTokenizer

from src.utils.memory_utils import log_memory_checkpoint, memory_monitor


def mean_pooling(model_output, attention_mask: np.ndarray) -> np.ndarray:
    """Mean Pooling - Take attention mask into account for correct averaging."""
    token_embeddings = model_output.last_hidden_state
    input_mask_expanded = (
        np.expand_dims(attention_mask, axis=-1).repeat(token_embeddings.shape[-1], axis=-1).astype(float)
    )
    sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
    sum_mask = np.clip(np.sum(input_mask_expanded, axis=1), a_min=1e-9, a_max=None)
    return sum_embeddings / sum_mask
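
# Illustrative sketch of mean_pooling on a hypothetical toy input (not used by
# the service itself): with hidden states of shape (1, 2, 3) and attention mask
# [[1, 0]], the padded second token is excluded, so the pooled vector equals
# the first token's embedding.
#
#     from types import SimpleNamespace
#     toy_output = SimpleNamespace(
#         last_hidden_state=np.array([[[1.0, 2.0, 3.0], [9.0, 9.0, 9.0]]])
#     )
#     mean_pooling(toy_output, np.array([[1, 0]]))  # -> array([[1., 2., 3.]])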


class EmbeddingService:
    """HuggingFace sentence-transformers wrapper for generating embeddings.

    Uses lazy loading and a class-level cache to avoid repeated expensive model
    loads and to minimize memory footprint at startup.

    This version is optimized to use a quantized ONNX model for lower memory
    footprint.
    """

    _model_cache: Dict[str, Tuple[ORTModelForFeatureExtraction, PreTrainedTokenizer]] = {}
    _quantized_model_name = "optimum/all-MiniLM-L6-v2"

    def __init__(
        self,
        model_name: Optional[str] = None,
        device: Optional[str] = None,
        batch_size: Optional[int] = None,
    ):
        # Import config values as defaults
        from src.config import (
            EMBEDDING_BATCH_SIZE,
            EMBEDDING_DEVICE,
            EMBEDDING_MODEL_NAME,
        )

        # The original model name is kept for reference. Use quantized model only
        # when explicitly enabled via configuration (to avoid breaking tests).
        self.original_model_name = model_name or EMBEDDING_MODEL_NAME
        from src.config import EMBEDDING_USE_QUANTIZED

        if EMBEDDING_USE_QUANTIZED:
            self.model_name = self._quantized_model_name
        else:
            # Keep the model name as originally requested for compatibility
            self.model_name = self.original_model_name
        self.device = device or EMBEDDING_DEVICE or "cpu"
        self.batch_size = batch_size or EMBEDDING_BATCH_SIZE
        # Max tokens (sequence length) to bound memory; configurable via env
        # EMBEDDING_MAX_TOKENS (default 512)
        try:
            self.max_tokens = int(os.getenv("EMBEDDING_MAX_TOKENS", "512"))
        except ValueError:
            self.max_tokens = 512

        # Lazy loading - don't load model at initialization
        self.model: Optional[ORTModelForFeatureExtraction] = None
        self.tokenizer: Optional[PreTrainedTokenizer] = None

        logging.info(
            "Initialized EmbeddingService: model=%s base=%s device=%s max_tokens=%s",
            self.model_name,
            self.original_model_name,
            self.device,
            getattr(self, "max_tokens", "unset"),
        )

    def _ensure_model_loaded(
        self,
    ) -> Tuple[ORTModelForFeatureExtraction, PreTrainedTokenizer]:
        """Ensure the quantized ONNX model and tokenizer are loaded."""
        if self.model is None or self.tokenizer is None:
            import gc

            gc.collect()

            cache_key = f"{self.model_name}_{self.device}"

            if cache_key not in self._model_cache:
                log_memory_checkpoint("before_model_load")
                logging.info(
                    "Loading quantized model '%s' and tokenizer...",
                    self.model_name,
                )
                # Use the original model's tokenizer
                tokenizer = AutoTokenizer.from_pretrained(self.original_model_name)
                # Load the quantized model from Optimum Hugging Face Hub.
                # Some model repos contain multiple ONNX export files; we select a default explicitly.
                provider = "CPUExecutionProvider" if self.device == "cpu" else "CUDAExecutionProvider"
                file_name = os.getenv("EMBEDDING_ONNX_FILE", "model.onnx")
                local_dir = os.getenv("EMBEDDING_ONNX_LOCAL_DIR")
                if local_dir and os.path.isdir(local_dir):
                    # Attempt to load from a local exported directory first.
                    try:
                        logging.info(
                            "Attempting local ONNX load from %s (file=%s)",
                            local_dir,
                            file_name,
                        )
                        model = ORTModelForFeatureExtraction.from_pretrained(
                            local_dir,
                            provider=provider,
                            file_name=file_name,
                        )
                        logging.info("Loaded ONNX model from local directory '%s'", local_dir)
                    except Exception as e:
                        logging.warning(
                            "Local ONNX load failed (%s); " "falling back to hub repo '%s'",
                            e,
                            self.model_name,
                        )
                        local_dir = None  # disable local path for subsequent attempts
                if not local_dir:
                    # Configure ONNX Runtime threading for constrained CPU
                    intra = int(os.getenv("ORT_INTRA_OP_NUM_THREADS", "1"))
                    inter = int(os.getenv("ORT_INTER_OP_NUM_THREADS", "1"))
                    so = ort.SessionOptions()
                    so.intra_op_num_threads = intra
                    so.inter_op_num_threads = inter
                    try:
                        model = ORTModelForFeatureExtraction.from_pretrained(
                            self.model_name,
                            provider=provider,
                            file_name=file_name,
                            session_options=so,
                        )
                        logging.info(
                            "Loaded ONNX model file '%s' (intra=%d, inter=%d)",
                            file_name,
                            intra,
                            inter,
                        )
                    except Exception as e:
                        logging.warning(
                            "Explicit ONNX file '%s' failed (%s); retrying with auto-selection.",
                            file_name,
                            e,
                        )
                        # Fall back to letting optimum auto-select the ONNX file in the repo,
                        # since retrying the same explicit file would fail identically.
                        model = ORTModelForFeatureExtraction.from_pretrained(
                            self.model_name,
                            provider=provider,
                            session_options=so,
                        )
                        logging.info(
                            "Loaded ONNX model using auto-selection fallback (intra=%d, inter=%d)",
                            intra,
                            inter,
                        )
                self._model_cache[cache_key] = (model, tokenizer)
                logging.info("Quantized model and tokenizer loaded successfully")
                log_memory_checkpoint("after_model_load")
            else:
                logging.info("Using cached quantized model '%s'", self.model_name)

            self.model, self.tokenizer = self._model_cache[cache_key]

        return self.model, self.tokenizer
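
    # The loader above is configured entirely through environment variables
    # (EMBEDDING_ONNX_LOCAL_DIR, EMBEDDING_ONNX_FILE, ORT_INTRA_OP_NUM_THREADS,
    # ORT_INTER_OP_NUM_THREADS). A hypothetical constrained-CPU setup might look
    # like the following; the path and values are examples, not shipped defaults:
    #
    #     os.environ["EMBEDDING_ONNX_LOCAL_DIR"] = "/models/minilm-onnx-export"
    #     os.environ["EMBEDDING_ONNX_FILE"] = "model_quantized.onnx"
    #     os.environ["ORT_INTRA_OP_NUM_THREADS"] = "1"
    #     os.environ["ORT_INTER_OP_NUM_THREADS"] = "1"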

    @memory_monitor
    def embed_text(self, text: str) -> List[float]:
        """Generate embedding for a single text."""
        embeddings = self.embed_texts([text])
        return embeddings[0]

    @memory_monitor
    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for multiple texts in batches using ONNX model."""
        if not texts:
            return []

        try:
            model, tokenizer = self._ensure_model_loaded()

            log_memory_checkpoint("before_batch_embedding")

            processed_texts: List[str] = [t if t.strip() else " " for t in texts]

            all_embeddings: List[List[float]] = []
            for i in range(0, len(processed_texts), self.batch_size):
                batch_texts = processed_texts[i : i + self.batch_size]
                log_memory_checkpoint(f"batch_start_{i}//{self.batch_size}")

                # Tokenize sentences
                encoded_input = tokenizer(
                    batch_texts,
                    padding=True,
                    truncation=True,
                    max_length=self.max_tokens,
                    return_tensors="np",
                )

                # Compute token embeddings
                model_output = model(**encoded_input)

                # Perform pooling
                sentence_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])

                # Normalize embeddings (L2) using pure NumPy to avoid torch dependency
                norms = np.linalg.norm(sentence_embeddings, axis=1, keepdims=True)
                norms = np.clip(norms, 1e-12, None)
                batch_embeddings = sentence_embeddings / norms

                log_memory_checkpoint(f"batch_end_{i}//{self.batch_size}")

                for emb in batch_embeddings:
                    all_embeddings.append(emb.tolist())

                import gc

                del batch_embeddings
                del batch_texts
                del encoded_input
                del model_output
                gc.collect()

            if os.getenv("LOG_DETAIL", "verbose") == "verbose":
                logging.info("Generated embeddings for %d texts", len(texts))
            return all_embeddings
        except Exception as e:
            logging.error("Failed to generate embeddings for texts: %s", e)
            raise

    def get_embedding_dimension(self) -> int:
        """Get the dimension of embeddings produced by this model."""
        try:
            model, _ = self._ensure_model_loaded()
            # The dimension can be found in the model's config
            return int(model.config.hidden_size)
        except Exception:
            logging.debug("Failed to get embedding dimension; returning 0")
            return 0

    def encode_batch(self, texts: List[str]) -> List[List[float]]:
        """Convenience wrapper that returns embeddings for a list of texts."""
        return self.embed_texts(texts)

    def similarity(self, text1: str, text2: str) -> float:
        """Cosine similarity between embeddings of two texts."""
        try:
            embeddings = self.embed_texts([text1, text2])
            embed1 = np.array(embeddings[0])
            embed2 = np.array(embeddings[1])
            similarity = np.dot(embed1, embed2) / (np.linalg.norm(embed1) * np.linalg.norm(embed2))
            return float(similarity)
        except Exception as e:
            logging.error("Failed to calculate similarity: %s", e)
            return 0.0
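

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; assumes network access to fetch the
    # model on first use and that src.config provides the defaults imported above).
    # The texts and log level here are made up for demonstration.
    logging.basicConfig(level=logging.INFO)

    service = EmbeddingService()
    vectors = service.embed_texts(["an example sentence", "another example"])
    print("dimension:", service.get_embedding_dimension(), "vectors:", len(vectors))
    print("similarity:", service.similarity("an example sentence", "another example"))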