import gc
import os
import io
import time
import tempfile
import logging
import spaces
import torch
import gradio as gr
from transformers import Mistral3ForConditionalGeneration, AutoProcessor
from mistral_text_encoding_core import encode_prompt
# ------------------------------------------------------
# Logging
# ------------------------------------------------------
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO"),
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
)
logger = logging.getLogger("mistral-text-encoding-gradio")
# ------------------------------------------------------
# Config
# ------------------------------------------------------
TEXT_ENCODER_ID = os.getenv("TEXT_ENCODER_ID", "/repository")
TOKENIZER_ID = os.getenv(
    "TOKENIZER_ID", "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
)
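# bfloat16 halves memory use versus float32 for the encoder weights and activations.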
DTYPE = torch.bfloat16
# ------------------------------------------------------
# Global model references
# ------------------------------------------------------
logger.info("Loading models...")
t0 = time.time()
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
    TEXT_ENCODER_ID,
    dtype=DTYPE,
).to("cuda")
logger.info(
    "Loaded Mistral text encoder (%.2fs) dtype=%s device=%s",
    time.time() - t0,
    text_encoder.dtype,
    text_encoder.device,
)
t1 = time.time()
tokenizer = AutoProcessor.from_pretrained(TOKENIZER_ID)
logger.info("Loaded tokenizer in %.2fs", time.time() - t1)
torch.set_grad_enabled(False)
def get_vram_info():
    """Get current VRAM usage info."""
    if torch.cuda.is_available():
        return {
            "vram_allocated_mb": round(torch.cuda.memory_allocated() / 1024 / 1024, 2),
            "vram_reserved_mb": round(torch.cuda.memory_reserved() / 1024 / 1024, 2),
            "vram_max_allocated_mb": round(torch.cuda.max_memory_allocated() / 1024 / 1024, 2),
        }
    return {"vram": "CUDA not available"}
@spaces.GPU()
def encode_text(prompt: str):
    """Encode text and return a downloadable pytorch file."""
    global text_encoder, tokenizer
    if text_encoder is None or tokenizer is None:
        return None, "Model not loaded"
    t0 = time.time()
    # Handle multiple prompts (one per line)
    prompts = [p.strip() for p in prompt.strip().split("\n") if p.strip()]
    if not prompts:
        return None, "Please enter at least one prompt"
    num_prompts = len(prompts)
    prompt_input = prompts[0] if num_prompts == 1 else prompts
    logger.info("Encoding %d prompt(s)", num_prompts)
    prompt_embeds, text_ids = encode_prompt(
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        prompt=prompt_input,
    )
    duration = (time.time() - t0) * 1000.0
    logger.info(
        "Encoded in %.2f ms | prompt_embeds.shape=%s | text_ids.shape=%s",
        duration,
        tuple(prompt_embeds.shape),
        tuple(text_ids.shape),
    )
    # Save tensor to a temporary file
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pt")
    torch.save(prompt_embeds.cpu(), temp_file.name)
    # Clean up GPU tensors
    del prompt_embeds, text_ids
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    vram = get_vram_info()
    status = (
        f"Encoded {num_prompts} prompt(s) in {duration:.2f}ms\n"
        f"VRAM: {vram.get('vram_allocated_mb', 'N/A')} MB allocated, "
        f"{vram.get('vram_max_allocated_mb', 'N/A')} MB peak"
    )
    return temp_file.name, status
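# The returned .pt file can be reloaded downstream with, for example:
#   prompt_embeds = torch.load("prompt_embeds.pt", map_location="cpu")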
# ------------------------------------------------------
# Gradio Interface
# ------------------------------------------------------
with gr.Blocks(title="Mistral Text Encoder") as demo:
    gr.Markdown("# Mistral Text Encoder")
    gr.Markdown("Enter text to encode. For multiple prompts, put each on a new line.")
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt(s)",
                placeholder="Enter your prompt here...\nOr multiple prompts, one per line",
                lines=5,
            )
            encode_btn = gr.Button("Encode", variant="primary")
        with gr.Column():
            output_file = gr.File(label="Download Embeddings (.pt)")
            status_output = gr.Textbox(label="Status", interactive=False)
    encode_btn.click(
        fn=encode_text,
        inputs=[prompt_input],
        outputs=[output_file, status_output],
    )
if __name__ == "__main__":
    # Models are already loaded at import time above, so the app is ready to launch.
    demo.launch(server_name="0.0.0.0", server_port=7860)