class SentenceTransformersCrossEncoder:
    """Wrapper for a sentence-transformers cross-encoder model.

    Instances are callable: given a query and a list of passages, they
    return one score per passage, in the same order as the passages.
    """

    def __init__(
        self, model_name_or_path: str = "cross-encoder/ms-marco-MiniLM-L-12-v2"
    ):
        """Load the cross-encoder identified by ``model_name_or_path``.

        Args:
            model_name_or_path: Passed unchanged to
                ``sentence_transformers.cross_encoder.CrossEncoder``.

        Raises:
            ModuleNotFoundError: If ``sentence-transformers`` cannot be
                imported. The original import error is chained as the cause.
        """
        # Imported lazily so the optional dependency is only required when
        # this wrapper is actually instantiated.
        try:
            from sentence_transformers.cross_encoder import CrossEncoder
        except ImportError as err:
            raise ModuleNotFoundError(
                "You need to install sentence-transformers library to use SentenceTransformersCrossEncoder."
            ) from err
        self.model = CrossEncoder(model_name_or_path)

    def __call__(self, query: str, passage: list[str]) -> list[float]:
        """Score every passage against ``query``.

        Args:
            query: The query text paired with each passage.
            passage: The passages to score.

        Returns:
            The model's predictions converted with ``.tolist()``, one entry
            per passage; an empty list when ``passage`` is empty.
        """
        pairs = [[query, p] for p in passage]
        if not pairs:
            # Nothing to score. Skip the model call: predict() may index the
            # first pair to detect the input shape and fail on an empty batch.
            return []
        return self.model.predict(pairs).tolist()