rahul7star committed
Commit 8143e5c · verified · 1 Parent(s): 6c0c98e

Update app_flash1.py

Files changed (1)
  1. app_flash1.py +11 -25
app_flash1.py CHANGED
@@ -142,41 +142,27 @@ def train_flashpack_model(dataset_name="rahul7star/prompt-enhancer-dataset",
 def get_flashpack_model(hf_repo="rahul7star/FlashPack"):
     print(f"🔍 Checking for model in repo: {hf_repo}")
     local_model_path = "model.flashpack"
-    local_mapping_path = "text_mapping.pkl"
 
-    if os.path.exists(local_model_path) and os.path.exists(local_mapping_path):
-        print("✅ Loading local model and mapping")
+    if os.path.exists(local_model_path):
+        print("✅ Loading local model")
     else:
-        files = list_repo_files(hf_repo)
-        if "model.flashpack" in files:
-            print("✅ Downloading model from HF")
-            local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
-        if "text_mapping.pkl" in files:
-            print("✅ Downloading text mapping from HF")
-            local_mapping_path = hf_hub_download(repo_id=hf_repo, filename="text_mapping.pkl")
-
-    # Load model
+        print("✅ Downloading model from HF")
+        local_model_path = hf_hub_download(repo_id=hf_repo, filename="model.flashpack")
+
     model = GemmaTrainer().from_flashpack(local_model_path)
     model.eval()
     tokenizer, embed_model, encode_fn = build_encoder("gpt2")
-    # Load mapping
-    with open(local_mapping_path, "rb") as f:
-        mapping = pickle.load(f)
-    short_texts, long_texts = mapping["short"], mapping["long"]
-    short_embs = torch.vstack([encode_fn(s) for s in short_texts])
 
-    # Enhance function
     @torch.no_grad()
     def enhance_fn(prompt, chat):
         chat = chat or []
-        query_emb = encode_fn(prompt)
-        mapped = model(query_emb.to(device)).cpu()
-        # Compute cosine similarity to all stored long embeddings
-        sims = torch.nn.functional.cosine_similarity(mapped, short_embs)
-        best_idx = int(sims.argmax())
-        best_long_prompt = long_texts[best_idx]
+        short_emb = encode_fn(prompt)
+        mapped = model(short_emb.to(device)).cpu()
+        # convert mapped tensor into a string (this can be learned in training)
+        # For demonstration, we just return a placeholder
+        long_prompt = f"Enhanced long prompt for: {prompt}"  # replace with your model's actual decoding if available
         chat.append({"role": "user", "content": prompt})
-        chat.append({"role": "assistant", "content": best_long_prompt})
+        chat.append({"role": "assistant", "content": long_prompt})
         return chat
 
     return model, tokenizer, embed_model, enhance_fn
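The placeholder in the new `enhance_fn` (per its own comment, "replace with your model's actual decoding if available") could be filled by restoring the nearest-neighbour lookup this commit removes. A minimal sketch in that direction, assuming the previous revision's `text_mapping.pkl` format (a dict with parallel "short" and "long" prompt lists) and the `encode_fn` returned by `build_encoder("gpt2")`; `decode_to_long` is a hypothetical helper name, not part of the file:

```python
import pickle
import torch

# Assumption: text_mapping.pkl from the previous revision is available locally,
# holding {"short": [...], "long": [...]} parallel prompt lists.
with open("text_mapping.pkl", "rb") as f:
    mapping = pickle.load(f)
short_texts, long_texts = mapping["short"], mapping["long"]

# Pre-compute embeddings of every stored short prompt once at load time
# (encode_fn is assumed from build_encoder("gpt2") in app_flash1.py).
short_embs = torch.vstack([encode_fn(s) for s in short_texts])

def decode_to_long(mapped: torch.Tensor) -> str:
    # Cosine similarity of the mapped embedding against all stored
    # short-prompt embeddings; return the long prompt paired with the
    # closest match.
    sims = torch.nn.functional.cosine_similarity(mapped, short_embs)
    return long_texts[int(sims.argmax())]
```

Computing `short_embs` once up front keeps each `enhance_fn` call to a single batch of cosine similarities instead of re-encoding the corpus per query, which is the trade-off the removed code made as well.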
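For completeness, a quick usage sketch of the updated return values, assuming the rest of app_flash1.py is loaded (`GemmaTrainer`, `build_encoder`, `hf_hub_download`, and the global `device`):

```python
# Load (or download) the FlashPack model and run one enhancement turn.
# All names are assumed from app_flash1.py's module scope.
model, tokenizer, embed_model, enhance_fn = get_flashpack_model()

chat = enhance_fn("a cat sitting on a chair", chat=None)
print(chat[-1]["content"])
# With the placeholder decoding in this revision, this prints:
# Enhanced long prompt for: a cat sitting on a chair
```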