import time
from collections import defaultdict
from datetime import datetime

import gradio as gr
import plotly.graph_objects as go
import requests

HF_DATASET_API = '/static-proxy?url=https%3A%2F%2Fdatasets-server.huggingface.co%2Frows%26%23x27%3B%3C%2Fspan%3E
DATASET_NAME = 'shachardon/ShareLM'
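
# The datasets-server /rows endpoint returns at most 100 rows per request,
# so fetch_dataset_sample() pages through larger samples in batches.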

def fetch_dataset_sample(max_rows=500):
    """Fetch up to max_rows rows from the dataset via the datasets-server API."""
    MAX_BATCH_SIZE = 100
    # Ceil-divide max_rows into batches, capped at 10 requests (1,000 rows)
    batches = min((max_rows + MAX_BATCH_SIZE - 1) // MAX_BATCH_SIZE, 10)
    all_rows = []
    
    for i in range(batches):
        offset = i * MAX_BATCH_SIZE
        length = min(MAX_BATCH_SIZE, max_rows - offset)
        
        if length <= 0:
            break
        
        url = f"{HF_DATASET_API}?dataset={DATASET_NAME.replace('/', '%2F')}&config=default&split=train&offset={offset}&length={length}"
        
        try:
            response = requests.get(url, headers={'Accept': 'application/json'}, timeout=25)
            response.raise_for_status()
            data = response.json()
            
            if data.get('rows') and isinstance(data['rows'], list):
                all_rows.extend(data['rows'])
            
            # Small delay between batches to avoid rate limiting
            if i < batches - 1:
                time.sleep(0.1)
        except Exception as e:
            print(f"Error fetching batch {i}: {e}")
            if i == 0:
                raise
    
    return all_rows
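
# Each row returned by the datasets-server is wrapped as
# {"row_idx": ..., "row": {...}, "truncated_cells": [...]}; the field names
# read below ('source', 'timestamp') are assumed from the ShareLM schema and
# may need adjusting if the dataset's columns differ.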

def process_data():
    """Fetch a sample, aggregate it, and return (pie chart, line chart, info text)."""
    try:
        rows = fetch_dataset_sample(500)
        
        if not rows:
            return None, None, "No data fetched. Please try again."
        
        source_counts = defaultdict(int)
        time_series = defaultdict(int)
        
        for row in rows:
            row_data = row.get('row', row) if isinstance(row, dict) else row
            
            # Count by source
            source = row_data.get('source', 'unknown')
            source_counts[source] += 1
            
            # Count by date
            if 'timestamp' in row_data:
                try:
                    date = datetime.fromisoformat(str(row_data['timestamp']).replace('Z', '+00:00'))
                    date_key = date.strftime('%Y-%m-%d')
                    time_series[date_key] += 1
                except (ValueError, TypeError):
                    pass  # Skip rows with unparseable timestamps
        
        # Create source breakdown pie chart
        if source_counts:
            sources = list(source_counts.keys())
            values = list(source_counts.values())
            
            fig_pie = go.Figure(data=[go.Pie(
                labels=sources,
                values=values,
                hole=0.4,
                textinfo='label+percent',
                textposition='outside'
            )])
            fig_pie.update_layout(
                title="Source Breakdown",
                height=500,
                showlegend=True
            )
        else:
            fig_pie = None
        
        # Create time series chart
        if time_series:
            sorted_dates = sorted(time_series.keys())
            counts = [time_series[date] for date in sorted_dates]
            
            fig_line = go.Figure()
            fig_line.add_trace(go.Scatter(
                x=sorted_dates,
                y=counts,
                mode='lines+markers',
                name='Conversations',
                line=dict(width=2)
            ))
            fig_line.update_layout(
                title="Total Count Over Time",
                xaxis_title="Date",
                yaxis_title="Count",
                height=500,
                hovermode='x unified'
            )
        else:
            fig_line = None
        
        total = sum(source_counts.values())
        info = f"Processed {len(rows)} rows\nTotal conversations: {total:,}\nSources: {len(source_counts)}\nTime points: {len(time_series)}"
        
        return (fig_pie, fig_line, info)
        
    except Exception as e:
        return (None, None, f"Error: {str(e)}")

def create_interface():
    """Create the Gradio interface"""
    with gr.Blocks(title="ShareLM Dataset Analysis", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# ShareLM Dataset Analysis")
        gr.Markdown("Analyzing conversations from the ShareLM Hugging Face dataset")
        
        with gr.Row():
            btn = gr.Button("Load & Analyze Data", variant="primary")
        
        with gr.Row():
            with gr.Column():
                pie_chart = gr.Plot(label="Source Breakdown")
            with gr.Column():
                line_chart = gr.Plot(label="Time Series")
        
        info_text = gr.Textbox(label="Statistics", lines=4, interactive=False)
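
        # process_data returns (pie_fig, line_fig, info_str); Gradio maps the
        # returned tuple positionally onto the outputs list below.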
        
        btn.click(
            fn=process_data,
            outputs=[pie_chart, line_chart, info_text]
        )
        
        # Load data on startup
        demo.load(
            fn=process_data,
            outputs=[pie_chart, line_chart, info_text]
        )
    
    return demo

if __name__ == "__main__":
    demo = create_interface()
    demo.launch()