multimodalart (HF Staff) committed
Commit bfa8055 · verified · 1 Parent(s): 20a769e

Create app.py

Files changed (1):
  1. app.py (+150 -0)
app.py ADDED
@@ -0,0 +1,150 @@
+ import gc
+ import os
+ import io
+ import time
+ import tempfile
+ import logging
+
+ import torch
+ import spaces
+ import gradio as gr
+ from transformers import Mistral3ForConditionalGeneration, AutoProcessor
+
+ from mistral_text_encoding_core import encode_prompt
+
+ # ------------------------------------------------------
+ # Logging
+ # ------------------------------------------------------
+ logging.basicConfig(
+     level=os.getenv("LOG_LEVEL", "INFO"),
+     format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
+ )
+ logger = logging.getLogger("mistral-text-encoding-gradio")
+
+ # ------------------------------------------------------
+ # Config
+ # ------------------------------------------------------
+ TEXT_ENCODER_ID = os.getenv("TEXT_ENCODER_ID", "/repository")
+ TOKENIZER_ID = os.getenv(
+     "TOKENIZER_ID", "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
+ )
+ DTYPE = torch.bfloat16
+
+ # ------------------------------------------------------
+ # Global model references
+ # ------------------------------------------------------
+ logger.info("Loading models...")
+
+ t0 = time.time()
+ text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
+     TEXT_ENCODER_ID,
+     dtype=DTYPE,
+ ).to("cuda")
+ logger.info(
+     "Loaded Mistral text encoder (%.2fs) dtype=%s device=%s",
+     time.time() - t0,
+     text_encoder.dtype,
+     text_encoder.device,
+ )
+
+ t1 = time.time()
+ tokenizer = AutoProcessor.from_pretrained(TOKENIZER_ID)
+ logger.info("Loaded tokenizer in %.2fs", time.time() - t1)
+
+ torch.set_grad_enabled(False)
+
+
+ def get_vram_info():
+     """Get current VRAM usage info."""
+     if torch.cuda.is_available():
+         return {
+             "vram_allocated_mb": round(torch.cuda.memory_allocated() / 1024 / 1024, 2),
+             "vram_reserved_mb": round(torch.cuda.memory_reserved() / 1024 / 1024, 2),
+             "vram_max_allocated_mb": round(torch.cuda.max_memory_allocated() / 1024 / 1024, 2),
+         }
+     return {"vram": "CUDA not available"}
+
+
+ @spaces.GPU()
+ def encode_text(prompt: str):
+     """Encode text and return a downloadable PyTorch file."""
+     global text_encoder, tokenizer
+
+     if text_encoder is None or tokenizer is None:
+         return None, "Model not loaded"
+
+     t0 = time.time()
+
+     # Handle multiple prompts (one per line)
+     prompts = [p.strip() for p in prompt.strip().split("\n") if p.strip()]
+     if not prompts:
+         return None, "Please enter at least one prompt"
+
+     num_prompts = len(prompts)
+     prompt_input = prompts[0] if num_prompts == 1 else prompts
+
+     logger.info("Encoding %d prompt(s)", num_prompts)
+
+     prompt_embeds, text_ids = encode_prompt(
+         text_encoder=text_encoder,
+         tokenizer=tokenizer,
+         prompt=prompt_input,
+     )
+
+     duration = (time.time() - t0) * 1000.0
+
+     logger.info(
+         "Encoded in %.2f ms | prompt_embeds.shape=%s | text_ids.shape=%s",
+         duration,
+         tuple(prompt_embeds.shape),
+         tuple(text_ids.shape),
+     )
+
+     # Save tensor to a temporary file
+     temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pt")
+     torch.save(prompt_embeds.cpu(), temp_file.name)
+
+     # Clean up GPU tensors
+     del prompt_embeds, text_ids
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     vram = get_vram_info()
+     status = (
+         f"Encoded {num_prompts} prompt(s) in {duration:.2f}ms\n"
+         f"VRAM: {vram.get('vram_allocated_mb', 'N/A')} MB allocated, "
+         f"{vram.get('vram_max_allocated_mb', 'N/A')} MB peak"
+     )
+
+     return temp_file.name, status
+
+
+ # ------------------------------------------------------
+ # Gradio Interface
+ # ------------------------------------------------------
+ with gr.Blocks(title="Mistral Text Encoder") as demo:
+     gr.Markdown("# Mistral Text Encoder")
+     gr.Markdown("Enter text to encode. For multiple prompts, put each on a new line.")
+
+     with gr.Row():
+         with gr.Column():
+             prompt_input = gr.Textbox(
+                 label="Prompt(s)",
+                 placeholder="Enter your prompt here...\nOr multiple prompts, one per line",
+                 lines=5,
+             )
+             encode_btn = gr.Button("Encode", variant="primary")
+
+         with gr.Column():
+             output_file = gr.File(label="Download Embeddings (.pt)")
+             status_output = gr.Textbox(label="Status", interactive=False)
+
+     encode_btn.click(
+         fn=encode_text,
+         inputs=[prompt_input],
+         outputs=[output_file, status_output],
+     )
+
+
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860)
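
Note on the output: the .pt file returned by encode_text contains only the prompt_embeds tensor (moved to CPU before saving), so it can be reloaded downstream without the model or a GPU. A minimal sketch of consuming it, assuming the downloaded file is saved locally as embeddings.pt (a hypothetical name):

import torch

# The app saved a single tensor with torch.save, so torch.load returns it directly;
# map_location="cpu" keeps the load GPU-free.
prompt_embeds = torch.load("embeddings.pt", map_location="cpu")
print(prompt_embeds.shape, prompt_embeds.dtype)  # e.g. (batch, seq_len, hidden), likely torch.bfloat16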