borjasoutoprego's picture
Update app.py
713c665 verified
from smolagents import CodeAgent, HfApiModel, load_tool, tool
import pandas as pd
import matplotlib.pyplot as plt
import io
import base64
import yaml
from tools.final_answer import FinalAnswerTool
from Gradio_UI import GradioUI
@tool
def load_and_summarize_data(file: bytes, filename: str) -> str:
    """Loads a CSV/Excel file from user input and generates a summary.

    The format is picked from the extension of the file name, compared
    case-insensitively so that names such as REPORT.CSV are accepted.
    Problems are reported in the returned text instead of being raised.

    Args:
        file: The uploaded file in bytes.
        filename: The original file name to determine format.

    Returns:
        The output of DataFrame.describe() prefixed with "Dataset Summary:",
        or a message explaining why the file could not be processed.
    """
    try:
        # Normalise once so ".CSV" / ".Xlsx" uploads are recognised too.
        normalized_name = filename.lower()
        if normalized_name.endswith(".csv"):
            df = pd.read_csv(io.BytesIO(file))
        elif normalized_name.endswith(".xlsx"):
            df = pd.read_excel(io.BytesIO(file))
        else:
            return "Unsupported format. Please upload a CSV or Excel file."

        summary = df.describe().to_string()
        return f"Dataset Summary:\n{summary}"
    except Exception as e:
        # Deliberately broad: the caller receives the failure as text rather
        # than an exception, whatever went wrong while parsing the upload.
        return f"Error processing the file: {str(e)}"
@tool
def generate_basic_charts(file: bytes, filename: str) -> str:
    """Generates basic charts from an uploaded CSV/Excel file.

    Draws a 20-bin histogram for each of the first three numeric columns and
    returns them as inline base64-encoded PNG image tags. The format is picked
    from the extension of the file name, compared case-insensitively.
    Problems are reported in the returned text instead of being raised.

    Args:
        file: The uploaded file in bytes.
        filename: The original file name to determine format.

    Returns:
        "Generated Charts:" followed by one img tag per chart, a notice when
        the data has no numeric columns, or a message explaining why the
        file could not be processed.
    """
    try:
        # Normalise once so ".CSV" / ".Xlsx" uploads are recognised too.
        normalized_name = filename.lower()
        if normalized_name.endswith(".csv"):
            df = pd.read_csv(io.BytesIO(file))
        elif normalized_name.endswith(".xlsx"):
            df = pd.read_excel(io.BytesIO(file))
        else:
            return "Unsupported format. Please upload a CSV or Excel file."

        numeric_columns = df.select_dtypes(include=["number"]).columns[:3]  # Limit to 3 charts
        if numeric_columns.empty:
            # Without this guard the caller would get a header with no charts.
            return "No numeric columns found to chart."

        img_tags = []
        for column in numeric_columns:
            # Use an explicit figure/axes pair instead of pyplot's implicit
            # "current figure", and always close it: figures created through
            # pyplot stay registered (and in memory) until closed.
            fig, ax = plt.subplots()
            try:
                df[column].hist(bins=20, ax=ax)
                ax.set_title(f"Distribution of {column}")
                ax.set_xlabel(column)
                ax.set_ylabel("Frequency")
                img_buf = io.BytesIO()
                fig.savefig(img_buf, format="png")
            finally:
                plt.close(fig)

            img_str = base64.b64encode(img_buf.getvalue()).decode("utf-8")
            img_tags.append(f'<img src="data:image/png;base64,{img_str}" />')

        return "Generated Charts:\n" + "\n".join(img_tags)
    except Exception as e:
        # Deliberately broad: the caller receives the failure as text rather
        # than an exception, whatever went wrong while parsing or plotting.
        return f"Error generating charts: {str(e)}"
# --- Application wiring: tools, model, prompts, agent and UI ---------------

# Tool instance passed to the agent alongside the two data tools defined above.
final_answer = FinalAnswerTool()

# Model configuration handed to the agent.
model = HfApiModel(
    max_tokens=2096,
    temperature=0.5,
    model_id='Qwen/Qwen2.5-Coder-32B-Instruct',
    custom_role_conversions=None,
)

# Load the prompt templates from the working directory. The encoding is
# pinned so the parsed text does not depend on the platform's locale default.
with open("prompts.yaml", 'r', encoding="utf-8") as stream:
    prompt_templates = yaml.safe_load(stream)

# NOTE(review): some smolagents releases appear to require `name` to be a
# valid Python identifier -- confirm "Data Report Assistant" is accepted by
# the version this app is pinned to.
agent = CodeAgent(
    model=model,
    tools=[final_answer, load_and_summarize_data, generate_basic_charts],  # Updated to support file upload
    max_steps=6,
    verbosity_level=1,
    grammar=None,
    planning_interval=None,
    name="Data Report Assistant",
    description="An agent that processes uploaded CSV/Excel files, generates summaries, and visualizations.",
    prompt_templates=prompt_templates
)

# Start the Gradio front end for the agent; this runs whenever the module is
# executed or imported.
GradioUI(agent).launch()