Spaces:
Runtime error
Runtime error
| from pymilvus import MilvusClient, DataType | |
| import numpy as np | |
| import concurrent.futures | |
class MilvusManager:
    """Manage one Milvus collection of multi-vector ("ColBERT"-style) documents.

    Every document is stored as several rows, one per embedding vector, that
    all share the same ``doc_id``.  ``search`` retrieves candidate documents
    by inner-product similarity and then re-ranks them with a MaxSim score
    computed over all stored vectors of each candidate.
    """

    # Index names that may sit on the "vector" field.  "vector_index" is the
    # name create_index() builds; "vector" is the name earlier revisions of
    # this class tried to drop, so it is still cleared for compatibility.
    _VECTOR_INDEX_NAMES = ("vector_index", "vector")
    # Hits requested per query vector in the first-stage search.
    _CANDIDATE_LIMIT = 50
    # Upper bound on rows fetched per document while re-ranking; documents
    # with more vectors than this are scored on a subset.
    _MAX_VECTORS_PER_DOC = 1000
    # Thread pool size for re-ranking queries (threads are spawned lazily).
    _MAX_RERANK_WORKERS = 300

    def __init__(self, milvus_uri, collection_name, create_collection, dim=128):
        """Connect to Milvus and prepare the collection.

        Args:
            milvus_uri: URI passed to ``MilvusClient``.
            collection_name: Name of the collection this manager operates on.
            create_collection: When true, drop any existing collection with
                that name, recreate it and build the vector index.
            dim: Dimensionality of the stored vectors.
        """
        self.client = MilvusClient(uri=milvus_uri)
        self.collection_name = collection_name
        self.dim = dim

        if create_collection:
            self.create_collection()
            self.create_index()

        # Load last: create_index() releases the collection, so loading
        # before it (as was done previously) left a freshly created
        # collection unloaded and therefore not searchable.
        if self.client.has_collection(collection_name=self.collection_name):
            self.client.load_collection(self.collection_name)

    def create_collection(self):
        """Drop the collection if it exists and recreate it with the schema."""
        if self.client.has_collection(collection_name=self.collection_name):
            self.client.drop_collection(collection_name=self.collection_name)

        # NOTE(review): pymilvus documents this keyword as
        # `enable_dynamic_field` (singular); as spelled here it is probably
        # ignored.  Nothing in this class relies on dynamic fields - confirm
        # the intent before changing it.
        schema = self.client.create_schema(
            auto_id=True,
            enable_dynamic_fields=True,
        )
        schema.add_field(field_name="pk", datatype=DataType.INT64, is_primary=True)
        schema.add_field(
            field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=self.dim
        )
        # Position of the vector inside its document.
        schema.add_field(field_name="seq_id", datatype=DataType.INT16)
        # Identifier shared by all vectors of one document.
        schema.add_field(field_name="doc_id", datatype=DataType.INT64)
        # File path of the document; only filled on the seq_id == 0 row.
        schema.add_field(field_name="doc", datatype=DataType.VARCHAR, max_length=65535)
        self.client.create_collection(
            collection_name=self.collection_name, schema=schema
        )

    def create_index(self):
        """(Re)build the inner-product index on the ``vector`` field.

        The collection is released first and is NOT reloaded here; call
        ``load_collection`` before searching.
        """
        self.client.release_collection(collection_name=self.collection_name)
        # Clear any index already on the field so the create below can
        # rebuild it.  Previously only "vector" was dropped, which never
        # matched the "vector_index" name created here.
        for index_name in self._VECTOR_INDEX_NAMES:
            self.client.drop_index(
                collection_name=self.collection_name, index_name=index_name
            )

        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="vector",
            index_name="vector_index",
            index_type="FLAT",
            metric_type="IP",
            # NOTE(review): M / efConstruction look like HNSW build
            # parameters; a FLAT index presumably does not use them - confirm.
            params={
                "M": 16,
                "efConstruction": 500,
            },
        )
        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

    def create_scalar_index(self):
        """Build an INVERTED index on ``doc_id``.

        The collection is released first and is NOT reloaded here; call
        ``load_collection`` before searching.
        """
        self.client.release_collection(collection_name=self.collection_name)
        index_params = self.client.prepare_index_params()
        index_params.add_index(
            field_name="doc_id",
            # Name kept as-is for compatibility even though doc_id is INT64.
            index_name="int32_index",
            index_type="INVERTED",
        )
        self.client.create_index(
            collection_name=self.collection_name, index_params=index_params, sync=True
        )

    @staticmethod
    def _doc_filter(doc_id):
        """Return the filter expression selecting the rows of exactly one document."""
        return f"doc_id == {doc_id}"

    @staticmethod
    def _maxsim_score(query_vecs, doc_vecs):
        """Return the MaxSim score of a document for a multi-vector query.

        For every query vector the largest inner product against any document
        vector is taken, and those maxima are summed.
        """
        return np.dot(query_vecs, np.asarray(doc_vecs).T).max(1).sum()

    def search(self, data, topk):
        """Return up to ``topk`` ``(score, doc_id)`` pairs, best score first.

        Args:
            data: Query vectors, one row per query token.
            topk: Maximum number of documents to return.

        Returns:
            A list of ``(score, doc_id)`` tuples sorted by descending score;
            ties are ordered by ascending ``doc_id``.  Empty when the
            first-stage search yields no hits.
        """
        search_params = {"metric_type": "IP", "params": {}}
        results = self.client.search(
            self.collection_name,
            data,
            limit=self._CANDIDATE_LIMIT,
            # Only doc_id is read from the hits; skipping the vectors avoids
            # transferring them for every candidate row.
            output_fields=["doc_id"],
            search_params=search_params,
        )
        # Sorted so that re-ranking order, and therefore tie order, is stable.
        doc_ids = sorted({hit["entity"]["doc_id"] for hits in results for hit in hits})
        if not doc_ids:
            return []

        def rerank_single_doc(doc_id):
            # Score against this document's vectors only.  The previous
            # filter also matched doc_id + 1, which mixed the next
            # document's vectors into this document's score.
            rows = self.client.query(
                collection_name=self.collection_name,
                filter=self._doc_filter(doc_id),
                output_fields=["vector"],
                limit=self._MAX_VECTORS_PER_DOC,
            )
            doc_vecs = np.vstack([row["vector"] for row in rows])
            return (self._maxsim_score(data, doc_vecs), doc_id)

        with concurrent.futures.ThreadPoolExecutor(
            max_workers=self._MAX_RERANK_WORKERS
        ) as executor:
            # map() keeps input order and re-raises worker exceptions here.
            scores = list(executor.map(rerank_single_doc, doc_ids))

        scores.sort(key=lambda pair: pair[0], reverse=True)
        return scores[:topk]

    def insert(self, data):
        """Insert one document, one row per vector.

        Args:
            data: Mapping with ``"colbert_vecs"`` (iterable of vectors),
                ``"doc_id"`` and ``"filepath"``.  The file path is stored on
                the first row only; the other rows get an empty string.

        A document without vectors is skipped instead of raising.
        """
        colbert_vecs = list(data["colbert_vecs"])
        if not colbert_vecs:
            return

        doc_id = data["doc_id"]
        filepath = data["filepath"]
        rows = [
            {
                "vector": vec,
                "seq_id": seq_id,
                "doc_id": doc_id,
                "doc": filepath if seq_id == 0 else "",
            }
            for seq_id, vec in enumerate(colbert_vecs)
        ]
        self.client.insert(self.collection_name, rows)

    def get_images_as_doc(self, images_with_vectors: list, start_doc_id: int = 0):
        """Convert image records into the dicts expected by ``insert``.

        Args:
            images_with_vectors: Records with ``"colbert_vecs"`` and
                ``"filepath"`` keys.
            start_doc_id: ``doc_id`` given to the first record; later records
                count up from it.  Pass a value past the ids already stored
                when adding a further batch, otherwise ids restart at 0 and
                collide with existing documents.

        Returns:
            A list of dicts with ``"colbert_vecs"``, ``"doc_id"`` and
            ``"filepath"`` keys, in input order.
        """
        return [
            {
                "colbert_vecs": image["colbert_vecs"],
                "doc_id": doc_id,
                "filepath": image["filepath"],
            }
            for doc_id, image in enumerate(images_with_vectors, start=start_doc_id)
        ]

    def insert_images_data(self, image_data, start_doc_id: int = 0):
        """Insert every image record, numbering documents from ``start_doc_id``."""
        for doc in self.get_images_as_doc(image_data, start_doc_id):
            self.insert(doc)