gaeunseo commited on
Commit
9cba7ea
·
verified ·
1 Parent(s): 7cc2d09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -44
app.py CHANGED
@@ -11,21 +11,22 @@ DATA_FILE = "global_data.csv"
11
  data_lock = threading.Lock()
12
 
13
  def initialize_global_data():
 
 
 
 
14
  if not os.path.exists(DATA_FILE):
15
  ds = load_dataset("gaeunseo/Taskmaster_sample_data", split="train")
16
- data = ds.to_pandas()
17
- if "used" not in data.columns:
18
- data["used"] = False
19
- if "overlapping" not in data.columns:
20
- data["overlapping"] = ""
21
- if "text" not in data.columns:
22
- data["text"] = ""
23
- data.to_csv(DATA_FILE, index=False)
24
  return data
25
  else:
26
  with data_lock:
27
- df = pd.read_csv(DATA_FILE)
28
- return df
 
29
 
30
  def load_global_data():
31
  with data_lock:
@@ -39,23 +40,37 @@ def save_global_data(df):
39
  global_data = initialize_global_data()
40
 
41
  def get_random_row_from_dataset():
 
 
 
 
 
 
 
 
42
  global global_data
43
- global_data = load_global_data() # 최신 데이터 로드
44
- groups = global_data.groupby('conversation_id')
45
- valid_groups = []
46
- for cid, group in groups:
47
- if group['used'].apply(lambda x: bool(x) == False).all() and (group['overlapping'] == "TT").any():
48
- valid_groups.append((cid, group))
 
 
 
 
 
49
  if not valid_groups:
50
  return None
51
- chosen_cid, chosen_group = random.choice(valid_groups)
52
- global_data.loc[global_data['conversation_id'] == chosen_cid, 'used'] = True
 
 
 
 
 
53
  save_global_data(global_data)
54
- chosen_rows = chosen_group[chosen_group['overlapping'] == "TT"]
55
- if chosen_rows.empty:
56
- return None
57
- chosen_row = chosen_rows.iloc[0]
58
- return chosen_row.to_dict()
59
 
60
  row = get_random_row_from_dataset()
61
  if row is None:
@@ -71,6 +86,9 @@ else:
71
  #############################################
72
 
73
  def get_initial_human_html():
 
 
 
74
  wrapper_start = (
75
  """<div class="human-wrapper" style="display: flex; align-items: flex-end; justify-content: flex-end; gap: 5px; width: 100%;">"""
76
  )
@@ -81,6 +99,10 @@ def get_initial_human_html():
81
  return wrapper_start + bubble_start + bubble_end + emoji_html + wrapper_end
82
 
83
  def stream_human_message():
 
 
 
 
84
  bubble_content = ""
85
  wrapper_start = (
86
  """<div class="human-wrapper" style="display: flex; align-items: flex-end; justify-content: flex-end; gap: 5px; width: 100%;">"""
@@ -90,10 +112,10 @@ def stream_human_message():
90
  emoji_html = "<div class='emoji'>🧑</div>"
91
  wrapper_end = "</div>"
92
 
93
- # 초기 상태
94
  yield wrapper_start + bubble_start + bubble_end + emoji_html + wrapper_end
95
 
96
- # 한 글자씩 타이핑 효과 적용
97
  for i, ch in enumerate(human_message):
98
  bubble_content += f"<span data-index='{i}'>{ch}</span>"
99
  current_html = wrapper_start + bubble_start + bubble_content + bubble_end + emoji_html + wrapper_end
@@ -101,19 +123,25 @@ def stream_human_message():
101
  time.sleep(0.05)
102
 
103
  def submit_edit(edited_text):
104
- global human_message, ai_message
105
- data = load_global_data()
 
 
 
 
 
 
106
  new_row = {
107
- "conversation_id": "edited_" + str(random.randint(1000, 9999)),
108
  "used": False,
109
  "overlapping": "",
110
  "text": edited_text,
111
  "human_message": edited_text,
112
  "ai_message": ""
113
  }
114
- new_df = pd.DataFrame([new_row])
115
- data = pd.concat([data, new_df], ignore_index=True)
116
- save_global_data(data)
117
 
118
  new_row_data = get_random_row_from_dataset()
119
  if new_row_data is None:
@@ -143,22 +171,17 @@ with gr.Blocks() as demo:
143
  """
144
  <script>
145
  document.addEventListener("click", function(event) {
146
- // human 메시지의 span을 클릭하면 처리
147
  if (event.target && event.target.matches("div.speech-bubble.human span[data-index]")) {
148
  var span = event.target;
149
  var container = span.closest("div.speech-bubble.human");
150
- // 이전에 추가된 ✂️ 아이콘 제거
151
  var oldScissors = container.querySelectorAll("span.scissor");
152
  oldScissors.forEach(function(s) { s.remove(); });
153
- // 모든 span의 색상 초기화
154
  var spans = container.querySelectorAll("span[data-index]");
155
  spans.forEach(function(s) { s.style.color = ''; });
156
- // ✂️ 아이콘 생성 후 span 바로 뒤에 삽입
157
  var scissor = document.createElement('span');
158
  scissor.textContent = '✂️';
159
  scissor.classList.add("scissor");
160
  container.insertBefore(scissor, span.nextSibling);
161
- // 클릭한 span의 data-index를 기준으로 그 뒤 텍스트 회색 처리
162
  var cutIndex = parseInt(span.getAttribute("data-index"));
163
  spans.forEach(function(s) {
164
  var idx = parseInt(s.getAttribute("data-index"));
@@ -166,20 +189,31 @@ with gr.Blocks() as demo:
166
  s.style.color = "grey";
167
  }
168
  });
169
- // 여기서 바로 human 메시지에서 ✂️ 이전 텍스트를 hidden input에 업데이트
170
- var edited_text = container.innerText.split("✂️")[0];
171
- var hiddenInput = document.getElementById("edited_text_input");
172
- if(hiddenInput) {
173
- hiddenInput.value = edited_text;
174
- }
175
  }
176
  });
177
  </script>
178
  """
179
  )
180
 
181
- # (B) 이상 submit 클릭 별도의 스크립트로 업데이트할 필요 없음
182
- # 원래 있던 DOMContentLoaded 스크립트는 제거합니다.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  # (C) CSS 스타일
185
  gr.HTML(
 
11
  data_lock = threading.Lock()
12
 
13
  def initialize_global_data():
14
+ """
15
+ DATA_FILE이 존재하지 않으면, Dataset을 로드하여 파일에 저장합니다.
16
+ 이미 파일이 있으면 파일에서 데이터를 읽어 반환합니다.
17
+ """
18
  if not os.path.exists(DATA_FILE):
19
  ds = load_dataset("gaeunseo/Taskmaster_sample_data", split="train")
20
+ data = list(ds)
21
+ with data_lock:
22
+ with open(DATA_FILE, "w", encoding="utf-8") as f:
23
+ json.dump(data, f, ensure_ascii=False, indent=2)
 
 
 
 
24
  return data
25
  else:
26
  with data_lock:
27
+ with open(DATA_FILE, "r", encoding="utf-8") as f:
28
+ data = json.load(f)
29
+ return data
30
 
31
  def load_global_data():
32
  with data_lock:
 
40
  global_data = initialize_global_data()
41
 
42
  def get_random_row_from_dataset():
43
+ """
44
+ DATA_FILE에 저장된 global_data에서,
45
+ conversation_id별로 그룹화한 후,
46
+ - 모든 행의 used 컬럼이 False인 그룹이고,
47
+ - 그룹 내에 overlapping 컬럼이 "TT"인 행이 존재하는 그룹들 중에서
48
+ 랜덤하게 하나의 그룹을 선택하고, 해당 그룹 내에서 overlapping 컬럼이 "TT"인 행을 선택하여 반환합니다.
49
+ 반환 전에 해당 행의 used 값을 True로 업데이트하고 파일에 저장합니다.
50
+ """
51
  global global_data
52
+ global_data = load_global_data()
53
+ # conversation_id별 그룹화
54
+ conversation_groups = {}
55
+ for row in global_data:
56
+ cid = row["conversation_id"]
57
+ conversation_groups.setdefault(cid, []).append(row)
58
+ # 조건에 맞는 그룹 필터링
59
+ valid_groups = [
60
+ group for group in conversation_groups.values()
61
+ if all(not r["used"] for r in group) and any(r["overlapping"] == "TT" for r in group)
62
+ ]
63
  if not valid_groups:
64
  return None
65
+ chosen_group = random.choice(valid_groups)
66
+ chosen_row = None
67
+ for row in chosen_group:
68
+ if row["overlapping"] == "TT":
69
+ row["used"] = True # 업데이트
70
+ chosen_row = row
71
+ break
72
  save_global_data(global_data)
73
+ return chosen_row
 
 
 
 
74
 
75
  row = get_random_row_from_dataset()
76
  if row is None:
 
86
  #############################################
87
 
88
  def get_initial_human_html():
89
+ """
90
+ 페이지 로드 시, 빈 Human 말풍선과 오른쪽 🧑 이모티콘을 포함한 초기 HTML 반환
91
+ """
92
  wrapper_start = (
93
  """<div class="human-wrapper" style="display: flex; align-items: flex-end; justify-content: flex-end; gap: 5px; width: 100%;">"""
94
  )
 
99
  return wrapper_start + bubble_start + bubble_end + emoji_html + wrapper_end
100
 
101
  def stream_human_message():
102
+ """
103
+ Start Typing 버튼 클릭 시, 전역 변수 human_message의 내용을 한 글자씩 타이핑 효과로 출력.
104
+ 이전 상태(✂️ 아이콘, 회색 처리 등)는 리셋됩니다.
105
+ """
106
  bubble_content = ""
107
  wrapper_start = (
108
  """<div class="human-wrapper" style="display: flex; align-items: flex-end; justify-content: flex-end; gap: 5px; width: 100%;">"""
 
112
  emoji_html = "<div class='emoji'>🧑</div>"
113
  wrapper_end = "</div>"
114
 
115
+ # 초기 상태: 빈 말풍선과 이모티콘
116
  yield wrapper_start + bubble_start + bubble_end + emoji_html + wrapper_end
117
 
118
+ # 한 글자씩 추가 (타이핑 효과)
119
  for i, ch in enumerate(human_message):
120
  bubble_content += f"<span data-index='{i}'>{ch}</span>"
121
  current_html = wrapper_start + bubble_start + bubble_content + bubble_end + emoji_html + wrapper_end
 
123
  time.sleep(0.05)
124
 
125
  def submit_edit(edited_text):
126
+ """
127
+ Submit 버튼 클릭 시 호출되는 함수.
128
+ 1. 편집된 human 메시지(✂️ 앞부분)를 새 행으로 global_data에 추가합니다.
129
+ 2. get_random_row_from_dataset()을 통해 새로운 대화를 가져오고, 전역 변수 human_message와 ai_message를 업데이트합니다.
130
+ 3. 초기 상태의 human 말풍선와 ai 말풍선 HTML을 반환하여 인터페이스를 리셋합니다.
131
+ """
132
+ global global_data, human_message, ai_message
133
+ # 새 행 생성 (새 conversation_id는 임의로 생성)
134
  new_row = {
135
+ "conversation_id": "edited_" + str(random.randint(1000,9999)),
136
  "used": False,
137
  "overlapping": "",
138
  "text": edited_text,
139
  "human_message": edited_text,
140
  "ai_message": ""
141
  }
142
+ global_data = load_global_data()
143
+ global_data.append(new_row)
144
+ save_global_data(global_data)
145
 
146
  new_row_data = get_random_row_from_dataset()
147
  if new_row_data is None:
 
171
  """
172
  <script>
173
  document.addEventListener("click", function(event) {
 
174
  if (event.target && event.target.matches("div.speech-bubble.human span[data-index]")) {
175
  var span = event.target;
176
  var container = span.closest("div.speech-bubble.human");
 
177
  var oldScissors = container.querySelectorAll("span.scissor");
178
  oldScissors.forEach(function(s) { s.remove(); });
 
179
  var spans = container.querySelectorAll("span[data-index]");
180
  spans.forEach(function(s) { s.style.color = ''; });
 
181
  var scissor = document.createElement('span');
182
  scissor.textContent = '✂️';
183
  scissor.classList.add("scissor");
184
  container.insertBefore(scissor, span.nextSibling);
 
185
  var cutIndex = parseInt(span.getAttribute("data-index"));
186
  spans.forEach(function(s) {
187
  var idx = parseInt(s.getAttribute("data-index"));
 
189
  s.style.color = "grey";
190
  }
191
  });
 
 
 
 
 
 
192
  }
193
  });
194
  </script>
195
  """
196
  )
197
 
198
+ # (B) 추가 스크립트: Submit 버튼 클릭 시, human_message div의 innerText에서 "✂️"를 기준으로 편집된 텍스트(앞부분)를 숨김 텍스트박스에 업데이트
199
+ gr.HTML(
200
+ """
201
+ <script>
202
+ document.addEventListener("DOMContentLoaded", function() {
203
+ var submitBtn = document.getElementById("submit_btn");
204
+ if(submitBtn){
205
+ submitBtn.addEventListener("click", function(){
206
+ var humanDiv = document.getElementById("human_message");
207
+ if(humanDiv){
208
+ var edited_text = humanDiv.innerText.split("✂️")[0];
209
+ document.getElementById("edited_text_input").value = edited_text;
210
+ }
211
+ });
212
+ }
213
+ });
214
+ </script>
215
+ """
216
+ )
217
 
218
  # (C) CSS 스타일
219
  gr.HTML(