Spaces: Runtime error
Boning c committed: Update app.py
app.py CHANGED
@@ -1,21 +1,23 @@
 import gradio as gr
-from
-from
-from transformers import pipeline
-import re
-
-
-
-
-
 MAX_TOKENS = 500
 TEMPERATURE = 0.7
 TOP_P = 0.9
 STOP_SEQUENCES = ["USER:", "\n\n"] # Model will stop generating when it encounters these

 # --- Safety Configuration ---
-# Initialize the toxicity classifier pipeline
-# This model identifies if text is 'toxic' with a confidence score.
 print("Loading safety model (unitary/toxic-bert)...")
 try:
     safety_classifier = pipeline(
@@ -26,11 +28,8 @@ try:
     print("Safety model loaded successfully.")
 except Exception as e:
     print(f"Error loading safety model: {e}")
-    # Consider handling this error more gracefully, e.g., run without safety if model fails to load
     exit(1)

-# Threshold for flagging content as unsafe (0.0 to 1.0)
-# A higher threshold means it's stricter (less likely to flag non-toxic content).
 TOXICITY_THRESHOLD = 0.9

 def is_text_safe(text: str) -> tuple[bool, str | None]:
@@ -39,14 +38,10 @@ def is_text_safe(text: str) -> tuple[bool, str | None]:
     Returns (True, None) if safe, or (False, detected_label) if unsafe.
     """
     if not text.strip():
-        return True, None

     try:
-        # Classify the text. The model typically returns [{'label': 'toxic', 'score': X.XX}]
-        # or [{'label': 'nontoxic', 'score': X.XX}] depending on thresholding in the model.
-        # For unitary/toxic-bert, 'toxic' is the positive label.
         results = safety_classifier(text)
-
         if results and results[0]['label'] == 'toxic' and results[0]['score'] > TOXICITY_THRESHOLD:
             print(f"Detected unsafe content: '{text.strip()}' (Score: {results[0]['score']:.4f})")
             return False, results[0]['label']
@@ -56,89 +51,165 @@ def is_text_safe(text: str) -> tuple[bool, str | None]:
     except Exception as e:
         print(f"Error during safety check: {e}")
         # If the safety check fails, consider it unsafe by default or log and let it pass.
-        # For a robust solution, you might want to re-raise or yield an error message.
         return False, "safety_check_failed"


-# --- Main Model Loading (
-print(f"
 try:
-
-
 except Exception as e:
-    print(f"Error
     exit(1)

-print("
 try:
-
-
-
-
-
-    )
-    print("Llama model initialized successfully.")
 except Exception as e:
-    print(f"Error
     exit(1)

-#
 [removed lines 86-102: content not preserved in the page capture]

-
-
-        current_sentence_buffer += token
-        full_output_so_far += token # Keep track of full output for comprehensive check if needed

-
-        if re.search(r'[.!?]\s*$', current_sentence_buffer) or len(current_sentence_buffer) > 100: # Max sentence length fallback
-            is_safe, detected_label = is_text_safe(current_sentence_buffer)
             if not is_safe:
-                print(f"Safety check failed for
-
-
-                #
-
-                # return
             else:
 [removed lines 120-130: content not preserved in the page capture]


 # --- Gradio Blocks Interface ---
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown(
         """
-        # SmilyAI: Sam-reason-
-        Enter a prompt and get a word-by-word response from the Sam-reason-
-
-
-        Running on Hugging Face Spaces' free CPU tier.
         """
     )
@@ -153,15 +224,13 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:

     send_button = gr.Button("Send", variant="primary")

-    # Connect the button click to the inference function with safety check
     send_button.click(
-        fn=generate_word_by_word_with_safety,
         inputs=user_prompt,
         outputs=generated_text,
         api_name="predict",
     )

-    # Launch the Gradio application
 if __name__ == "__main__":
     print("Launching Gradio app...")
     demo.launch(server_name="0.0.0.0", server_port=7860)
The updated app.py (the diff's right-hand side; lines added by this commit are prefixed with +, unchanged context lines are unprefixed):

 import gradio as gr
+from huggingface_hub import hf_hub_download # Still useful if model is private and needs custom token
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, GenerationConfig
+from transformers.pipelines import pipeline
+import re
+import os
+import torch # Required for transformers models
+import threading
+import time # For short sleeps in streamer
+
+# --- Model Configuration ---
+# Your SmilyAI model ID on Hugging Face Hub
+MODEL_REPO_ID = "Smilyai-labs/Sam-reason-S3"
+N_CTX = 2048 # Context window for the model (applies more to LLMs)
 MAX_TOKENS = 500
 TEMPERATURE = 0.7
 TOP_P = 0.9
 STOP_SEQUENCES = ["USER:", "\n\n"] # Model will stop generating when it encounters these

 # --- Safety Configuration ---
 print("Loading safety model (unitary/toxic-bert)...")
 try:
     safety_classifier = pipeline(
 (unchanged lines 24-27 are not shown in the diff)
     print("Safety model loaded successfully.")
 except Exception as e:
     print(f"Error loading safety model: {e}")
     exit(1)

 TOXICITY_THRESHOLD = 0.9

 def is_text_safe(text: str) -> tuple[bool, str | None]:
 (unchanged lines 36-37 are not shown in the diff)
     Returns (True, None) if safe, or (False, detected_label) if unsafe.
     """
     if not text.strip():
+        return True, None

     try:
         results = safety_classifier(text)
         if results and results[0]['label'] == 'toxic' and results[0]['score'] > TOXICITY_THRESHOLD:
             print(f"Detected unsafe content: '{text.strip()}' (Score: {results[0]['score']:.4f})")
             return False, results[0]['label']
 (unchanged lines 48-50 are not shown in the diff)
     except Exception as e:
         print(f"Error during safety check: {e}")
         # If the safety check fails, consider it unsafe by default or log and let it pass.
         return False, "safety_check_failed"
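
The keyword arguments to pipeline(...) sit in unchanged lines that the diff hides, so they are not visible here. For orientation only, a typical text-classification setup around unitary/toxic-bert looks like the sketch below (an assumption, not a quote of the hidden lines), and is_text_safe can then be exercised directly:

# Hypothetical initialization; the real arguments are in the hidden, unchanged lines.
safety_classifier = pipeline(
    "text-classification",
    model="unitary/toxic-bert",
    device=-1,  # -1 = CPU
)

# The pipeline returns a list such as [{'label': 'toxic', 'score': 0.98}],
# which is exactly what is_text_safe compares against TOXICITY_THRESHOLD.
print(is_text_safe("hello there"))  # expected: (True, None) for benign text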


+# --- Main Model Loading (using Transformers) ---
+print(f"Loading tokenizer for {MODEL_REPO_ID}...")
 try:
+    # AutoTokenizer fetches the correct tokenizer for the model
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID)
+    print("Tokenizer loaded.")
 except Exception as e:
+    print(f"Error loading tokenizer: {e}")
+    print("Make sure the model ID is correct and, if it's a private repo, you've set the HF_TOKEN secret in your Space.")
     exit(1)

+print(f"Loading model {MODEL_REPO_ID} (this will be VERY slow on CPU and might take a long time)...")
 try:
+    # AutoModelForCausalLM loads the language model.
+    # device_map="cpu" ensures all model layers are loaded onto the CPU.
+    # torch_dtype=torch.float32 is standard for CPU; float16 can save memory but might not be faster on all CPUs.
+    model = AutoModelForCausalLM.from_pretrained(MODEL_REPO_ID, device_map="cpu", torch_dtype=torch.float32)
+    model.eval()  # Set model to evaluation mode for inference
+    print("Model loaded successfully.")
 except Exception as e:
+    print(f"Error loading model: {e}")
+    print("Ensure it's a standard Transformers model and you have HF_TOKEN secret if private.")
     exit(1)
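
After loading, a quick sanity check (not part of the commit) confirms where the weights live and in which dtype:

# Optional check: every parameter should report device "cpu" and dtype torch.float32.
param = next(model.parameters())
print(param.device, param.dtype)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")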

+# Configure generation for streaming
+# Use GenerationConfig from the model for default parameters, then override as needed.
+generation_config = GenerationConfig.from_pretrained(MODEL_REPO_ID)
+generation_config.max_new_tokens = MAX_TOKENS
+generation_config.temperature = TEMPERATURE
+generation_config.top_p = TOP_P
+generation_config.do_sample = True  # Enable sampling for temperature/top_p
+# Set EOS and PAD token IDs for proper generation stopping and padding
+generation_config.eos_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else -1
+generation_config.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id if tokenizer.eos_token_id is not None else -1
+# Fallback for pad_token_id if not explicitly set
+if generation_config.pad_token_id == -1:
+    generation_config.pad_token_id = 0  # Fallback to 0, though not ideal for all models
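
The -1/0 fallbacks above can leave an id that the vocabulary does not actually contain. A common alternative (shown only as a comparison, not as what this commit does) is to reuse the EOS token as padding:

# Common idiom when a tokenizer ships without a dedicated pad token:
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
generation_config.pad_token_id = tokenizer.pad_token_id
generation_config.eos_token_id = tokenizer.eos_token_id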

+# --- Custom Streamer for Gradio and Safety Check ---
+class GradioSafetyStreamer(TextIteratorStreamer):
+    def __init__(self, tokenizer, safety_checker_fn, toxicity_threshold, skip_special_tokens=True, **kwargs):
+        super().__init__(tokenizer, skip_special_tokens=skip_special_tokens, **kwargs)
+        self.safety_checker_fn = safety_checker_fn
+        self.toxicity_threshold = toxicity_threshold
+        self.current_sentence_buffer = ""
+        self.output_queue = []  # Queue to store safety-checked sentences to be yielded by Gradio
+        self.sentence_regex = re.compile(r'[.!?]\s*')  # Regex for sentence end, simple version
+        self.text_done = threading.Event()  # Event to signal when internal text processing is complete
+
+    def on_finalized_text(self, text: str, stream_end: bool = False):
+        # This method is called by the superclass when a decoded token chunk is ready.
+        self.current_sentence_buffer += text
+
+        # Split buffer into sentences. Keep the last part in buffer if it's incomplete.
+        sentences = self.sentence_regex.split(self.current_sentence_buffer)
+
+        sentences_to_process = []
+        if not stream_end and sentences and self.sentence_regex.search(sentences[-1]) is None:
+            # If not end of stream and last part is not a complete sentence, buffer it for next time
+            sentences_to_process = sentences[:-1]
+            self.current_sentence_buffer = sentences[-1]
+        else:
+            # Otherwise, process all segments and clear buffer
+            sentences_to_process = sentences
+            self.current_sentence_buffer = ""

+        for sentence in sentences_to_process:
+            if not sentence.strip(): continue  # Skip empty strings from splitting

+            is_safe, detected_label = self.safety_checker_fn(sentence)
             if not is_safe:
+                print(f"Safety check failed for: '{sentence.strip()}' (Detected: {detected_label})")
+                self.output_queue.append("[Content removed due to safety guidelines]")
+                self.output_queue.append("__STOP_GENERATION__")  # Special signal to stop LLM generation
+                return  # Stop processing further sentences from this chunk if unsafe
+
             else:
+                self.output_queue.append(sentence)
+
+        if stream_end:
+            # If stream ends and there's leftover text in buffer, process it
+            if self.current_sentence_buffer.strip():
+                is_safe, detected_label = self.safety_checker_fn(self.current_sentence_buffer)
+                if not is_safe:
+                    self.output_queue.append("[Content removed due to safety guidelines]")
+                else:
+                    self.output_queue.append(self.current_sentence_buffer)
+                self.current_sentence_buffer = ""  # Clear after final check
+            self.text_done.set()  # Signal that all text processing is complete
+
+    def __iter__(self):
+        # This method allows Gradio to iterate over the safety-checked output.
+        while True:
+            if self.output_queue:
+                item = self.output_queue.pop(0)
+                if item == "__STOP_GENERATION__":
+                    # Signal to the outer Gradio loop to stop yielding.
+                    raise StopIteration
+                yield item
+            elif self.text_done.is_set():  # Check if internal generation and safety processing is truly finished
+                raise StopIteration  # End of generation and safety check
+            else:
+                time.sleep(0.01)  # Small sleep to prevent busy-waiting while waiting for new tokens
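
One detail worth flagging: __iter__ above is a generator (it contains yield), and since Python 3.7 a raise StopIteration inside a generator is converted into a RuntimeError (PEP 479) rather than ending iteration, so the two raise statements will surface as runtime errors once triggered. A generator ends cleanly with a plain return; an illustrative fix, not part of the commit:

    # Inside GradioSafetyStreamer: a StopIteration-free variant of __iter__.
    def __iter__(self):
        while True:
            if self.output_queue:
                item = self.output_queue.pop(0)
                if item == "__STOP_GENERATION__":
                    return  # a bare return ends a generator cleanly
                yield item
            elif self.text_done.is_set():
                return  # all text has been processed; stop iterating
            else:
                time.sleep(0.01)  # wait for the generation thread to produce more text

Passing skip_prompt=True through to TextIteratorStreamer (the **kwargs already allow it) would additionally keep the echoed USER:/ASSISTANT: prompt out of the streamed output.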
+
+
+# --- Inference Function with Safety and Streaming ---
+def generate_word_by_word_with_safety(prompt_text: str):
+    formatted_prompt = f"USER: {prompt_text}\nASSISTANT:"
+    # Encode input on the model's device (CPU)
+    input_ids = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(model.device)
+
+    # Initialize the custom streamer
+    streamer = GradioSafetyStreamer(tokenizer, is_text_safe, TOXICITY_THRESHOLD)
+
+    # Use a separate thread for model generation because model.generate is a blocking call.
+    # This allows the streamer to continuously fill its queue while Gradio yields.
+    generate_kwargs = {
+        "input_ids": input_ids,
+        "streamer": streamer,
+        "generation_config": generation_config,
+        # Explicitly pass these for clarity, even if in generation_config
+        "do_sample": True,
+        "temperature": TEMPERATURE,
+        "top_p": TOP_P,
+        "max_new_tokens": MAX_TOKENS,
+        "eos_token_id": generation_config.eos_token_id,
+        "pad_token_id": generation_config.pad_token_id,
+    }
+
+    # Start generation in a separate thread
+    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
+    thread.start()
+
+    # Yield tokens from the streamer's output queue for Gradio to display progressively
+    full_generated_text = ""
+    try:
+        for new_sentence_or_chunk in streamer:
+            full_generated_text += new_sentence_or_chunk
+            yield full_generated_text  # Gradio expects accumulated string for streaming display
+    except StopIteration:
+        pass  # Streamer signaled end
+    except Exception as e:
+        print(f"Error during streaming: {e}")
+        yield full_generated_text + f"\n\n[Error during streaming: {e}]"  # Show error in output
+    finally:
+        thread.join()  # Ensure the generation thread finishes gracefully
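
For comparison, the thread-plus-iterate shape above matches the pattern documented for TextIteratorStreamer; without the per-sentence safety pass it reduces to a sketch like this (same model/tokenizer assumptions as above):

def stream_plain(prompt_text: str):
    # Minimal streaming generator: no safety filtering, prompt tokens skipped.
    inputs = tokenizer(f"USER: {prompt_text}\nASSISTANT:", return_tensors="pt").to(model.device)
    plain_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    worker = threading.Thread(
        target=model.generate,
        kwargs=dict(**inputs, streamer=plain_streamer, max_new_tokens=MAX_TOKENS, do_sample=True),
    )
    worker.start()
    text = ""
    for chunk in plain_streamer:  # TextIteratorStreamer yields decoded text chunks as they arrive
        text += chunk
        yield text
    worker.join()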


 # --- Gradio Blocks Interface ---
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown(
         """
+        # SmilyAI: Sam-reason-S3 Inference (Transformers on CPU with Safety Filter)
+        Enter a prompt and get a word-by-word response from the **Smilyai-labs/Sam-reason-S3** model.
+        **⚠️ WARNING: This model is running on a free CPU tier via the `transformers` library. Inference will be VERY slow.**
+        All generated sentences are checked for safety using an AI filter; unsafe content will be replaced.
         """
     )

 (unchanged lines 216-223 are not shown in the diff)

     send_button = gr.Button("Send", variant="primary")

     send_button.click(
+        fn=generate_word_by_word_with_safety,
         inputs=user_prompt,
         outputs=generated_text,
         api_name="predict",
     )

 if __name__ == "__main__":
     print("Launching Gradio app...")
     demo.launch(server_name="0.0.0.0", server_port=7860)
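
Because the click handler is exposed with api_name="predict", the running Space can also be called programmatically. A small sketch with gradio_client; the Space id below is a placeholder, not something stated in this commit:

from gradio_client import Client

client = Client("OWNER/SPACE-NAME")  # placeholder Space id for wherever this app is deployed
result = client.predict("Hello, who are you?", api_name="/predict")
print(result)  # for a streaming endpoint, predict() returns the final accumulated text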