Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import requests | |
| from bs4 import BeautifulSoup | |
| import pandas as pd | |
| import torch | |
| from transformers import pipeline | |
| from sentence_transformers import SentenceTransformer, util | |
| import concurrent.futures | |
| import time | |
| import sys | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from transformers import AutoTokenizer, AutoModel | |
| import numpy as np | |
| from scipy import stats | |
| from PyDictionary import PyDictionary | |
| import matplotlib.pyplot as plt | |
| from scipy import stats | |
| import litellm | |
| import re | |
| import sentencepiece | |
| import random | |
| from global_vars import t, translations | |
| from app import Plugin | |
| from embeddings_ft import finetune as finetune_embeddings | |
| from bart_ft import finetune as finetune_bart | |
| from webrankings_helper import * | |
| from plugins.scansite import ScansitePlugin | |
| #from data import reference_data_valid, reference_data_rejected | |
| #reference_data = reference_data_valid + reference_data_rejected | |
# Plugin-specific UI strings, merged into the shared English translation table.
translations["en"].update({
    # Fixed typo in the tab/page title (was "Comparative os sorter").
    "webrankings_title": "Sorter comparison",
    "clear_memory": "Clear Memory",
    "enter_topic": "Enter the topic you're interested in (e.g. longevity):",
    "use_keyword_expansion": "Use keyword expansion",
    "test_content": "Also test link content in addition to titles",
    "select_llm_models": "Select LLM models to use",
    "select_zero_shot_models": "Select zero-shot models to use",
    "select_embedding_models": "Select embedding models to use",
    "analyze_button": "Analyze",
    "loading_models": "Loading models and analyzing links...",
    "expanded_keywords": "Expanded keywords:",
    "analysis_completed": "Analysis completed in {:.2f} seconds",
    "evaluation_results": "Evaluation results with optimal thresholds:",
    "summary_table": "Summary table of scores",
    "optimal_thresholds": "Optimal thresholds:",
    "spearman_comparison": "Comparison of Spearman correlations",
    "methods": "Methods",
    "spearman_correlation": "Spearman correlation coefficient",
    "results_for": "Results for {}",
    "device_info": "Device used for inference: {}",
    "finetune_bart_title": "BART Fine-tuning Interface",
    "finetune_embeddings_title": "Embeddings Fine-tuning Interface",
})
# French UI strings for this plugin, merged into the shared translation table.
# Keyword-argument form: every key is a plain identifier, so this is the same
# mapping (in the same order) as a dict literal.
translations["fr"].update(
    webrankings_title="Analyseur de classeurs",
    clear_memory="Vider la mémoire",
    enter_topic="Entrez le sujet qui vous intéresse (ex: longévité):",
    use_keyword_expansion="Utiliser l'expansion des mots-clés",
    test_content="Tester aussi le contenu des liens en plus des titres",
    select_llm_models="Sélectionnez les modèles LLM à utiliser",
    select_zero_shot_models="Sélectionnez les modèles zero-shot à utiliser",
    select_embedding_models="Sélectionnez les modèles d'embedding à utiliser",
    analyze_button="Analyser",
    loading_models="Chargement des modèles et analyse des liens...",
    expanded_keywords="Mots-clés étendus :",
    analysis_completed="Analyse terminée en {:.2f} secondes",
    evaluation_results="Résultats d'évaluation avec les seuils optimaux :",
    summary_table="Tableau récapitulatif des scores",
    optimal_thresholds="Seuils optimaux :",
    spearman_comparison="Comparaison des corrélations de Spearman",
    methods="Méthodes",
    spearman_correlation="Coefficient de corrélation de Spearman",
    results_for="Résultats pour {}",
    device_info="Dispositif utilisé pour l'inférence : {}",
    finetune_bart_title="Interface de Fine-tuning BART",
    finetune_embeddings_title="Interface de Fine-tuning des Embeddings",
)
# LLM model identifiers offered for keyword expansion. None are enabled at the
# moment; entries previously listed here were ollama/llama3, ollama/llama3.1,
# ollama/qwen2, ollama/phi3:medium-128k and ollama/openhermes.
llm_models = []

# Zero-shot classifiers as (display name, model id or local path) pairs.
# "./bart-large-ft" is the output directory of the BART fine-tuning tab.
# cross-encoder/nli-deberta-v3-base was also listed previously (disabled).
zero_shot_models = [
    ("facebook/bart-large-mnli", "facebook/bart-large-mnli"),
    ("bart-large-ft", "./bart-large-ft"),
]

# Sentence-embedding models as (display name, model id or local path) pairs.
# "./embeddings-ft" is the output directory of the embeddings fine-tuning tab.
embedding_models = [
    ("paraphrase-MiniLM-L6-v2", "paraphrase-MiniLM-L6-v2"),
    ("all-MiniLM-L6-v2", "all-MiniLM-L6-v2"),
    ("nomic-embed-text-v1", "nomic-ai/nomic-embed-text-v1"),
    ("embeddings-ft", "./embeddings-ft"),
]
class WebrankingsPlugin(Plugin):
    """Streamlit plugin comparing link-relevance scoring methods.

    ``run`` renders three tabs:
      1. a ranking comparison over the scansite reference links,
      2. a BART fine-tuning form (output saved to ./bart-large-ft),
      3. an embeddings fine-tuning form (output saved to ./embeddings-ft).

    Scoring, threshold search and evaluation are delegated to helpers imported
    from ``webrankings_helper`` (analyze_link, find_optimal_threshold,
    evaluate_ranking, release_vram, ...).
    """

    def __init__(self, name, plugin_manager):
        super().__init__(name, plugin_manager)
        # Supplies the labelled reference links (valid and rejected).
        self.scansite_plugin = ScansitePlugin('scansite', plugin_manager)

    def get_tabs(self):
        """Return the tab descriptors this plugin contributes."""
        return [
            {"name": t("webrankings_title"), "plugin": "webrankings"}
        ]

    def run(self, config):
        """Render the plugin UI.

        Args:
            config: Plugin configuration; not used by this plugin.
        """
        tab1, tab2, tab3 = st.tabs([t("webrankings_title"), t("finetune_bart_title"), t("finetune_embeddings_title")])
        reference_data_valid, reference_data_rejected = self.scansite_plugin.get_reference_data()
        # Rejected entries are 2-tuples; give them an explicit note of 0 so every
        # entry unpacks as (title, link, note) in the analysis loop.
        # NOTE(review): element 0 is treated as the title throughout (it is matched
        # against the "Titre" result column) — confirm against ScansitePlugin.
        reference_data = reference_data_valid + [(title, link, 0) for title, link in reference_data_rejected]

        # Each tab is rendered by its own helper so an early return inside one
        # tab can never prevent the other tabs from being drawn.
        with tab1:
            self._render_rankings_tab(reference_data, reference_data_valid, reference_data_rejected)
        with tab2:
            self._render_bart_finetune_tab()
        with tab3:
            self._render_embeddings_finetune_tab()

        # Tell the user which device inference runs on.
        device = "GPU (CUDA)" if torch.cuda.is_available() else "CPU"
        st.sidebar.info(t("device_info").format(device))

    def _render_rankings_tab(self, reference_data, reference_data_valid, reference_data_rejected):
        """Render the ranking-comparison tab: inputs, analysis run and results."""
        st.title(t("webrankings_title"))
        if st.button(t("clear_memory")):
            # torch.cuda.synchronize() raises when no CUDA device is available,
            # so the CUDA-specific calls are guarded.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
            clear_globals()
            reset_cuda_context()

        topic = st.text_input(t("enter_topic"), value="longevity, health, healthspan, lifespan")
        use_synonyms = st.checkbox(t("use_keyword_expansion"), value=False)
        check_content = st.checkbox(t("test_content"), value=False)
        selected_llm_models = st.multiselect(t("select_llm_models"), llm_models, default=llm_models)
        zero_shot_names = [name for name, _ in zero_shot_models]
        selected_zero_shot_models = st.multiselect(t("select_zero_shot_models"), zero_shot_names, default=zero_shot_names)
        embedding_names = [name for name, _ in embedding_models]
        selected_embedding_models = st.multiselect(t("select_embedding_models"), embedding_names, default=embedding_names)

        if not st.button(t("analyze_button")):
            return

        with st.spinner(t("loading_models")):
            results, elapsed = self._analyze_links(
                reference_data, topic, use_synonyms, check_content,
                selected_llm_models, selected_zero_shot_models, selected_embedding_models,
            )
        self._render_results(results, elapsed, reference_data_valid, reference_data_rejected)

    def _analyze_links(self, reference_data, topic, use_synonyms, check_content,
                       selected_llm_models, selected_zero_shot_models, selected_embedding_models):
        """Load the selected models, score every reference link, then release the models.

        Returns:
            (results, elapsed): the non-None rows returned by ``analyze_link`` and
            the wall-clock seconds spent in the scoring loop.
        """
        import os

        device = "cuda" if torch.cuda.is_available() else "cpu"
        zero_shot_classifiers = {}
        embedding_models_dict = {}
        bert_models = []
        tfidf_objects = []
        try:
            for name, model in zero_shot_models:
                if name not in selected_zero_shot_models:
                    continue
                # Local fine-tuned checkpoints only exist once the fine-tuning tab has run.
                if model.startswith("./") and not os.path.exists(model):
                    st.warning(f"Model not found, skipped: {model}")
                    continue
                zero_shot_classifiers[name] = pipeline("zero-shot-classification", model=model, device=device)

            for name, model in embedding_models:
                if name not in selected_embedding_models:
                    continue
                if model.startswith("./") and not os.path.exists(model):
                    st.warning(f"Model not found, skipped: {model}")
                    continue
                embedding_models_dict[name] = SentenceTransformer(model, trust_remote_code=True).to(device)

            # NOTE(review): these objects are only handed to release_vram below and
            # are not used for scoring in this method — confirm they are still needed.
            bert_models = [AutoModel.from_pretrained('bert-base-uncased').to(device)]
            tfidf_objects = [TfidfVectorizer()]

            # Optional keyword expansion using the first selected LLM.
            if use_synonyms and selected_llm_models:
                expanded_terms = []
                for word in topic.split():
                    expanded_terms.extend(expand_keywords_llm(word, llm_model=selected_llm_models[0]))
                expanded_query = " ".join(expanded_terms)
                st.write(t("expanded_keywords"), expanded_query)
            else:
                expanded_query = topic

            start_time = time.perf_counter()
            results = []
            for title, link, _note in reference_data:
                result = analyze_link(
                    title, link, topic, zero_shot_classifiers, embedding_models_dict,
                    expanded_query, selected_llm_models, check_content
                )
                if result is not None:
                    results.append(result)
            elapsed = time.perf_counter() - start_time
        finally:
            # Free VRAM and model objects even when loading or scoring fails.
            release_vram(zero_shot_classifiers, embedding_models_dict, bert_models, tfidf_objects)
        return results, elapsed

    def _render_results(self, results, elapsed, reference_data_valid, reference_data_rejected):
        """Evaluate each scoring method against the reference labels and display it.

        Every result column other than "Titre" is treated as one scoring method.
        """
        df = pd.DataFrame(results)
        print(f"Analyse terminée en {elapsed:.2f} secondes")
        st.success(t("analysis_completed").format(elapsed))
        if df.empty:
            st.warning("No link could be analyzed.")
            return

        valid_titles = [item[0] for item in reference_data_valid]
        rejected_titles = [item[0] for item in reference_data_rejected]
        score_columns = [column for column in df.columns if column != "Titre"]
        scores_by_title = df.set_index("Titre")

        evaluation_results = {}
        optimal_thresholds = {}
        for column in score_columns:
            method_scores = scores_by_title[column].to_dict()
            threshold = find_optimal_threshold(valid_titles, rejected_titles, method_scores)
            optimal_thresholds[column] = threshold
            evaluation_results[column] = evaluate_ranking(
                valid_titles, rejected_titles, method_scores, threshold, False
            )

        st.write(t("evaluation_results"))
        st.dataframe(pd.DataFrame(evaluation_results).T)
        st.subheader(t("summary_table"))
        st.dataframe(df)
        st.write(t("optimal_thresholds"))
        st.json(optimal_thresholds)

        # Bar chart of each method's Spearman correlation, drawn on an explicit
        # Figure rather than pyplot's global state.
        spearman_scores = [metrics['spearman_correlation'] for metrics in evaluation_results.values()]
        fig, ax = plt.subplots(figsize=(15, 8))
        ax.bar(list(evaluation_results.keys()), spearman_scores)
        ax.set_title(t("spearman_comparison"))
        ax.set_xlabel(t("methods"))
        ax.set_ylabel(t("spearman_correlation"))
        plt.setp(ax.get_xticklabels(), rotation=90, ha='right')
        fig.tight_layout()
        st.pyplot(fig)
        # Close the figure so repeated analyses do not accumulate open figures.
        plt.close(fig)

        # Per-method ranking, keeping only titles scored above the threshold
        # already computed for that method.
        for column in score_columns:
            st.subheader(t("results_for").format(column))
            df_method = df[["Titre", column]].sort_values(column, ascending=False)
            st.dataframe(df_method[df_method[column] > optimal_thresholds[column]])

    def _render_bart_finetune_tab(self):
        """Render the BART fine-tuning form; the model is saved to ./bart-large-ft."""
        st.title(t("finetune_bart_title"))
        num_epochs = st.number_input("Nombre d'époques", min_value=1, max_value=10, value=2)
        lr = st.number_input("Learning Rate", min_value=1e-6, max_value=1e-1, value=2e-5, format="%.6f", step=1e-5)
        weight_decay = st.number_input("Poids de Décroissance (Weight Decay)", min_value=0.0, max_value=0.1, value=0.01, step=0.005)
        batch_size = st.number_input("Taille du Batch", min_value=1, max_value=16, value=1)
        # NOTE(review): this slider's value is never forwarded to finetune_bart —
        # confirm whether finetune_bart should receive it.
        st.slider("Score initial des données valides", min_value=0.0, max_value=1.0, value=0.9, step=0.01)
        model_name = st.text_input("Nom du modèle", value='facebook/bart-large-mnli')
        num_warmup_steps = st.number_input("Nombre d'étapes de Warmup", min_value=0, max_value=100, value=0)
        if st.button("Lancer le fine-tuning"):
            with st.spinner("Fine-tuning en cours..."):
                finetune_bart(num_epochs=num_epochs, lr=lr, weight_decay=weight_decay,
                              batch_size=batch_size, model_name=model_name, output_model='./bart-large-ft',
                              num_warmup_steps=num_warmup_steps)
            st.success("Fine-tuning terminé et modèle sauvegardé.")

    def _render_embeddings_finetune_tab(self):
        """Render the embeddings fine-tuning form; the model is saved to ./embeddings-ft."""
        st.title(t("finetune_embeddings_title"))
        num_epochs_emb = st.number_input("Nombre d'époques (Embeddings)", min_value=1, max_value=100, value=10)
        lr_emb = st.number_input("Learning Rate (Embeddings)", min_value=1e-6, max_value=1e-1, value=2e-5, format="%.6f", step=5e-6)
        weight_decay_emb = st.number_input("Poids de Décroissance (Weight Decay) (Embeddings)", min_value=0.0, max_value=0.1, value=0.01, step=0.005)
        batch_size_emb = st.number_input("Taille du Batch (Embeddings)", min_value=1, max_value=32, value=16)
        # NOTE(review): the two sliders below ("score initial", "marge") are displayed
        # but never forwarded to finetune_embeddings — confirm intended wiring.
        st.slider("Score initial des données valides (Embeddings)", min_value=0.0, max_value=1.0, value=0.9, step=0.01)
        model_name_emb = st.selectbox("Modèle d'embeddings de base", ["nomic-ai/nomic-embed-text-v1", "all-MiniLM-L6-v2", "paraphrase-MiniLM-L6-v2"])
        st.slider("Marge (Embeddings)", min_value=0.0, max_value=1.0, value=0.5, step=0.01)
        if st.button("Lancer le fine-tuning des embeddings"):
            with st.spinner("Fine-tuning des embeddings en cours..."):
                finetune_embeddings(model_name=model_name_emb, output_model_name="./embeddings-ft",
                                    num_epochs=num_epochs_emb,
                                    learning_rate=lr_emb,
                                    weight_decay=weight_decay_emb,
                                    batch_size=batch_size_emb,
                                    )
            st.success("Fine-tuning des embeddings terminé et modèle sauvegardé.")