# app.py (FastAPI server to host the Jina Embedding model)
# Redirect Hugging Face caches to a writable path; must be set before importing HF libs
import os
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface/hub"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface/transformers"
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List, Optional
import torch
from transformers import AutoModel, AutoTokenizer
app = FastAPI()
# -----------------------------
# Load model once on startup
# -----------------------------
MODEL_NAME = "jinaai/jina-embeddings-v4"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
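# Note: float16 halves GPU memory; on a CPU-only host, float32 is usually the safer choice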
model = AutoModel.from_pretrained(
    MODEL_NAME, trust_remote_code=True, torch_dtype=torch.float16
).to(device)
model.eval()
# -----------------------------
# Request / Response Models
# -----------------------------
class EmbedRequest(BaseModel):
    text: str
    task: str = "retrieval"  # "retrieval", "text-matching", "code", etc.
    prompt_name: Optional[str] = None
    return_token_embeddings: bool = True  # False → single pooled embedding (for queries)

class EmbedResponse(BaseModel):
    embeddings: List[List[float]]  # (num_tokens, hidden_dim) if token-level,
                                   # (1, hidden_dim) if pooled query

class TokenizeRequest(BaseModel):
    text: str

class TokenizeResponse(BaseModel):
    input_ids: List[int]

class DecodeRequest(BaseModel):
    input_ids: List[int]

class DecodeResponse(BaseModel):
    text: str
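# Illustrative /embed request bodies (field values are examples, not defaults):
#   {"text": "what is late interaction?", "task": "retrieval",
#    "return_token_embeddings": false}                       → pooled query vector
#   {"text": "<long passage>", "task": "retrieval"}          → token-level multivectors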
# -----------------------------
# Embedding Endpoint
# -----------------------------
@app.post("/embed", response_model=EmbedResponse)
def embed(req: EmbedRequest):
text = req.text
# -----------------------------
# Case 1: Query → directly pooled embedding
# -----------------------------
if not req.return_token_embeddings:
with torch.no_grad():
emb = model.encode_text(
texts=[text],
task=req.task,
prompt_name=req.prompt_name or "query",
return_multivector=False
)
return {"embeddings": emb.tolist()} # shape: (1, hidden_dim)
# -----------------------------
# Case 2: Long passages → sliding window token embeddings
# -----------------------------
enc = tokenizer(text, add_special_tokens=False, return_tensors="pt")
input_ids = enc["input_ids"].squeeze(0).to(device) # (total_tokens,)
total_tokens = input_ids.size(0)
max_len = model.config.max_position_embeddings # e.g., 32k for v4
stride = 50 # overlap for sliding window
embeddings = []
position = 0
while position < total_tokens:
end = min(position + max_len, total_tokens)
window_ids = input_ids[position:end].unsqueeze(0).to(device)
with torch.no_grad():
outputs = model.encode_text(
texts=[tokenizer.decode(window_ids[0])],
task=req.task,
prompt_name=req.prompt_name or "passage",
return_multivector=True,
)
window_embeds = outputs.squeeze(0).cpu() # (window_len, hidden_dim)
# Drop overlapping tokens except in first window
if position > 0:
window_embeds = window_embeds[stride:]
embeddings.append(window_embeds)
# Advance window
position += max_len - stride
full_embeddings = torch.cat(embeddings, dim=0) # (total_tokens, hidden_dim)
return {"embeddings": full_embeddings.tolist()}
# -----------------------------
# Tokenize Endpoint
# -----------------------------
@app.post("/tokenize", response_model=TokenizeResponse)
def tokenize(req: TokenizeRequest):
enc = tokenizer(req.text, add_special_tokens=False)
return {"input_ids": enc["input_ids"]}
# -----------------------------
# Decode Endpoint
# -----------------------------
@app.post("/decode", response_model=DecodeResponse)
def decode(req: DecodeRequest):
decoded = tokenizer.decode(req.input_ids)
return {"text": decoded} |