mahesh1209 commited on
Commit
64a69e3
·
verified ·
1 Parent(s): 32939ef

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -0
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from flask import Flask, request, jsonify, render_template
6
+
7
+ API_KEY = os.getenv("API_KEY", "mysecretkey")
8
+
9
+ stories = [
10
+ "The sun rose over the quiet village.",
11
+ "A cat chased a butterfly through the garden.",
12
+ "Rain tapped gently on the windowpane.",
13
+ "Children laughed and played in the park.",
14
+ "The old man told stories by the fire."
15
+ ]
16
+ text = " ".join(stories).lower()
17
+ tokens = text.split()
18
+ vocab = sorted(set(tokens))
19
+ word2idx = {w: i for i, w in enumerate(vocab)}
20
+ encoded = [word2idx[w] for w in tokens]
21
+
22
+ seq_len = 4
23
+ inputs, targets = [], []
24
+ for i in range(len(encoded) - seq_len):
25
+ inputs.append(encoded[i:i+seq_len])
26
+ targets.append(encoded[i+1:i+seq_len+1])
27
+ inputs = torch.tensor(inputs)
28
+ targets = torch.tensor(targets)
29
+
30
+ class SLM(nn.Module):
31
+ def __init__(self, vocab_size, embed_dim):
32
+ super().__init__()
33
+ self.embed = nn.Embedding(vocab_size, embed_dim)
34
+ self.fc = nn.Linear(embed_dim * seq_len, vocab_size)
35
+ def forward(self, x):
36
+ x = self.embed(x).view(x.size(0), -1)
37
+ return self.fc(x)
38
+
39
+ loss_fn = nn.CrossEntropyLoss()
40
+ model = SLM(len(vocab), embed_dim=10)
41
+ optimizer = optim.Adam(model.parameters(), lr=0.01)
42
+
43
+ for epoch in range(50): # reduced for speed
44
+ model.train()
45
+ optimizer.zero_grad()
46
+ output = model(inputs)
47
+ loss = loss_fn(output, targets[:, -1])
48
+ loss.backward()
49
+ optimizer.step()
50
+
51
+ def generate_sentence(phrase):
52
+ phrase = phrase.strip().lower()
53
+ phrase_words = phrase.split()
54
+ if len(phrase_words) < 2:
55
+ return "Please enter at least 2 words from a sentence."
56
+ for story in stories:
57
+ story_text = story.lower()
58
+ if all(word in story_text for word in phrase_words):
59
+ return story
60
+ return f'No match found for "{phrase}". Try phrases like: quiet village, chased butterfly, laughed played, told stories'
61
+
62
+ app = Flask(__name__, template_folder="templates")
63
+
64
+ @app.route("/", methods=["GET"])
65
+ def home():
66
+ return render_template("index.html")
67
+
68
+ @app.route("/generate", methods=["POST"])
69
+ def generate():
70
+ key = request.headers.get("x-api-key")
71
+ if key != API_KEY:
72
+ return jsonify({"error": "Unauthorized"}), 401
73
+
74
+ data = request.get_json()
75
+ phrase = data.get("phrase", "")
76
+ result = generate_sentence(phrase)
77
+ return jsonify({"input": phrase, "output": result})
78
+
79
+ # Proxy route: frontend calls this, backend injects API key
80
+ @app.route("/frontend_generate", methods=["POST"])
81
+ def frontend_generate():
82
+ data = request.get_json()
83
+ phrase = data.get("phrase", "")
84
+ # Internally call the same logic as /generate, but skip key check
85
+ result = generate_sentence(phrase)
86
+ return jsonify({"input": phrase, "output": result})
87
+
88
+ if __name__ == "__main__":
89
+ app.run(host="0.0.0.0", port=7860)