import gradio as gr
from huggingface_hub import InferenceClient
from datasets import load_dataset
import torch
from transformers import pipeline


class ContentFilter:
    def __init__(self):
        # Initialize the toxic content detection model; top_k=None returns a
        # score for every label (return_all_scores=True is deprecated)
        self.toxicity_classifier = pipeline(
            'text-classification',
            model='unitary/toxic-bert',
            top_k=None
        )
        # Keyword blacklist
        self.blacklist = [
            'hate', 'discriminate', 'violent',
            'offensive', 'inappropriate', 'racist',
            'sexist', 'homophobic', 'transphobic'
        ]

    def filter_toxicity(self, text, toxicity_threshold=0.5):
        """
        Detect toxic content using the pre-trained model.

        Args:
            text (str): Input text to check
            toxicity_threshold (float): Score above which a label is flagged

        Returns:
            dict: Filtering results
        """
        # With top_k=None, the pipeline returns a flat list of
        # {'label': ..., 'score': ...} dicts for a single input string
        results = self.toxicity_classifier(text)
        # Convert results to a {label: score} dictionary
        toxicity_scores = {
            result['label']: result['score']
            for result in results
        }
        # Check whether any toxicity category exceeds the threshold
        is_toxic = any(
            score > toxicity_threshold
            for score in toxicity_scores.values()
        )
        return {
            'is_toxic': is_toxic,
            'toxicity_scores': toxicity_scores
        }

    def filter_keywords(self, text):
        """
        Check text against the keyword blacklist.

        Args:
            text (str): Input text to check

        Returns:
            list: Matched blacklisted keywords
        """
        matched_keywords = [
            keyword for keyword in self.blacklist
            if keyword.lower() in text.lower()
        ]
        return matched_keywords
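
    # Note: plain substring matching over-flags, e.g. 'hate' also matches
    # inside 'chateau'. A word-boundary variant (illustrative sketch only,
    # not used by this app) could look like:
    #
    #     import re
    #     def filter_keywords_strict(self, text):
    #         return [kw for kw in self.blacklist
    #                 if re.search(rf'\b{re.escape(kw)}\b', text, re.IGNORECASE)]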

    def comprehensive_filter(self, text):
        """
        Run both toxicity-model and keyword-blacklist filtering.

        Args:
            text (str): Input text to filter

        Returns:
            dict: Combined filtering results
        """
        # Toxicity model filtering
        toxicity_results = self.filter_toxicity(text)
        # Keyword blacklist filtering
        blacklisted_keywords = self.filter_keywords(text)
        # Combine results
        return {
            'toxicity': toxicity_results,
            'blacklisted_keywords': blacklisted_keywords,
            'is_safe': not toxicity_results['is_toxic'] and len(blacklisted_keywords) == 0
        }
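

# Expected result shape (illustrative; unitary/toxic-bert emits labels such as
# 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult' and 'identity_hate'):
#
#     ContentFilter().comprehensive_filter("hello there")
#     -> {'toxicity': {'is_toxic': False, 'toxicity_scores': {...}},
#         'blacklisted_keywords': [],
#         'is_safe': True}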

# Initialize content filter
content_filter = ContentFilter()

# Initialize Hugging Face client.
# chat_completion() below requires a conversational model; google-t5/t5-small
# is an encoder-decoder model with no chat template, so zephyr is used instead.
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
# client = InferenceClient("google-t5/t5-small")

# Load dataset (optional)
dataset = load_dataset("JustKiddo/KiddosVault")


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p
):
    # First, filter the incoming user message
    message_filter_result = content_filter.comprehensive_filter(message)

    # If the message is not safe, emit a warning instead of a completion
    if not message_filter_result['is_safe']:
        toxicity_details = message_filter_result['toxicity']['toxicity_scores']
        blacklisted_keywords = message_filter_result['blacklisted_keywords']

        warning_message = "Message flagged for inappropriate content. "
        warning_message += "Detected issues: "
        # Add toxicity details
        warning_message += ", ".join(
            f"{category} (Score: {score:.2f})"
            for category, score in toxicity_details.items()
            if score > 0.5
        )
        # Add blacklisted keywords
        if blacklisted_keywords:
            warning_message += f". Blacklisted keywords: {', '.join(blacklisted_keywords)}"
        # respond() is a generator, so the warning must be yielded, not returned
        yield warning_message
        return

    # Prepare messages for chat completion
    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": message})

    # Generate response, streaming partial text back to the UI
    response = ""
    for chunk in client.chat_completion(  # 'chunk', so the message argument is not shadowed
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p
    ):
        token = chunk.choices[0].delta.content
        if token:  # the final stream chunk may carry no content
            response += token
            yield response


# Create Gradio interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a professional and friendly assistant.",
            label="System message"
        ),
        gr.Slider(
            minimum=1,
            maximum=6144,
            value=6144,
            step=1,
            label="Max new tokens"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=1.0,
            step=0.1,
            label="Temperature"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)"
        ),
    ]
)

if __name__ == "__main__":
    demo.launch(debug=True)
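
# To run this Space locally (assumes the file is saved as app.py and an
# HF_TOKEN with Inference API access is available in the environment):
#     pip install gradio huggingface_hub transformers datasets torch
#     python app.py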