File size: 2,843 Bytes
64a69e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import hmac
import os
import secrets
import string

import torch
import torch.nn as nn
import torch.optim as optim
from flask import Flask, request, jsonify, render_template

# Shared secret expected in the ``x-api-key`` header of POST /generate.
# A hard-coded fallback would be known to anyone who can read this source, so
# when API_KEY is unset (or empty) an unguessable per-process key is used
# instead. This keeps /generate closed until the environment variable is set.
API_KEY = os.getenv("API_KEY") or secrets.token_urlsafe(32)

# Built-in corpus: the model trains on it and generate_sentence searches it.
stories = [
    "The sun rose over the quiet village.",
    "A cat chased a butterfly through the garden.",
    "Rain tapped gently on the windowpane.",
    "Children laughed and played in the park.",
    "The old man told stories by the fire."
]

# Word-level tokenisation: lowercase, then split on whitespace. Punctuation
# stays attached to its word, so "village." is a token of its own.
text = " ".join(stories).lower()
tokens = text.split()
vocab = sorted(set(tokens))
word2idx = dict(zip(vocab, range(len(vocab))))
encoded = list(map(word2idx.__getitem__, tokens))

# Sliding windows of seq_len token ids; every target window is the matching
# input window shifted one position to the right.
seq_len = 4
inputs = torch.tensor(
    [encoded[start:start + seq_len] for start in range(len(encoded) - seq_len)]
)
targets = torch.tensor(
    [encoded[start + 1:start + seq_len + 1] for start in range(len(encoded) - seq_len)]
)

class SLM(nn.Module):
    """Minimal fixed-context next-word model.

    Embeds a window of token ids, concatenates the per-token embeddings and
    projects them with one linear layer to a logit per vocabulary entry.

    Args:
        vocab_size: Number of distinct tokens (rows of the embedding table and
            width of the output layer).
        embed_dim: Width of each token embedding.
        context_len: Tokens per input window. Defaults to the module-level
            ``seq_len`` so existing ``SLM(vocab_size, embed_dim)`` calls keep
            working; pass it explicitly to avoid relying on that global.
    """

    def __init__(self, vocab_size, embed_dim, context_len=None):
        super().__init__()
        # Only fall back to the global when no explicit window length is given.
        if context_len is None:
            context_len = seq_len
        self.context_len = context_len
        self.embed = nn.Embedding(vocab_size, embed_dim)
        # The linear layer sees the whole window's embeddings concatenated.
        self.fc = nn.Linear(embed_dim * context_len, vocab_size)

    def forward(self, x):
        """Return next-word logits of shape ``(batch, vocab_size)``.

        ``x`` holds integer token ids shaped ``(batch, context_len)``; any other
        window length will not match ``self.fc``'s input width.
        """
        # (batch, context_len, embed_dim) -> (batch, context_len * embed_dim)
        x = self.embed(x).flatten(1)
        return self.fc(x)

# Full-batch training: cross-entropy over the vocabulary, Adam optimiser.
loss_fn = nn.CrossEntropyLoss()
model = SLM(len(vocab), embed_dim=10)
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Nothing in this loop leaves training mode, so entering it once up front is
# the same as re-entering it every epoch.
model.train()
for epoch in range(50):  # reduced for speed
    optimizer.zero_grad()
    output = model(inputs)
    # The model emits one next-word distribution per window, so only the last
    # column of each target window serves as the label.
    loss = loss_fn(output, targets[:, -1])
    loss.backward()
    optimizer.step()

def generate_sentence(phrase, corpus=None):
    """Return the first story that contains every word of *phrase*.

    Despite the name, nothing is generated and the ``SLM`` model is not
    consulted: this is a whole-word lookup over a list of sentences.

    Matching is case-insensitive and ignores punctuation at the edges of words,
    so ``"quiet village"`` matches "...the quiet village.", while a fragment
    such as ``"in"`` does not match inside "rain".

    Args:
        phrase: User-supplied text. ``None`` is treated as an empty phrase.
        corpus: Sentences to search, in priority order. Defaults to the
            module-level ``stories`` list.

    Returns:
        The first matching sentence with its original casing, or a hint string
        when fewer than two words are given or no sentence matches.
    """
    if corpus is None:
        corpus = stories
    phrase = "" if phrase is None else str(phrase)
    phrase = phrase.strip().lower()
    # Strip edge punctuation and drop punctuation-only tokens, so "village."
    # equals "village" and a stray "," is not counted as a word.
    phrase_words = [w.strip(string.punctuation) for w in phrase.split()]
    phrase_words = [w for w in phrase_words if w]
    if len(phrase_words) < 2:
        return "Please enter at least 2 words from a sentence."
    for story in corpus:
        story_words = {w.strip(string.punctuation) for w in story.lower().split()}
        # Whole-word membership; a substring test would let "in" match "rain".
        if story_words.issuperset(phrase_words):
            return story
    return f'No match found for "{phrase}". Try phrases like: quiet village, chased butterfly, laughed played, told stories'

# WSGI application; page templates are loaded from the "templates" directory.
app = Flask(import_name=__name__, template_folder="templates")


@app.route("/", methods=["GET"])
def home():
    """Serve the HTML frontend page (``templates/index.html``)."""
    page = "index.html"
    return render_template(page)

def _api_key_matches(key):
    """Return True if *key* equals the configured ``API_KEY``.

    Uses a constant-time comparison so response timing does not reveal how
    much of a guessed key was correct.
    """
    # Fail closed: a missing header never authenticates, and neither does
    # anything when the configured key is empty.
    if not API_KEY or key is None:
        return False
    # Compare UTF-8 bytes: compare_digest rejects non-ASCII str arguments.
    return hmac.compare_digest(key.encode("utf-8"), API_KEY.encode("utf-8"))


def _read_phrase():
    """Extract ``phrase`` from the current request's JSON body.

    Returns:
        ``(phrase, None)`` on success, where a missing ``phrase`` key yields
        ``""``; otherwise ``(None, (response, 400))`` when the body is not a
        JSON object or ``phrase`` is not a string.
    """
    # silent=True: a missing or malformed JSON body gives None instead of
    # raising, so it is reported as a 400 rather than crashing with a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return None, (jsonify({"error": "Request body must be a JSON object"}), 400)
    phrase = data.get("phrase", "")
    if not isinstance(phrase, str):
        return None, (jsonify({"error": "'phrase' must be a string"}), 400)
    return phrase, None


@app.route("/generate", methods=["POST"])
def generate():
    """Authenticated lookup endpoint.

    Requires the ``x-api-key`` header to match ``API_KEY`` (401 otherwise) and a
    JSON body ``{"phrase": "<text>"}`` (400 if malformed). Responds with
    ``{"input": phrase, "output": generate_sentence(phrase)}``.
    """
    key = request.headers.get("x-api-key")
    if not _api_key_matches(key):
        return jsonify({"error": "Unauthorized"}), 401

    phrase, error = _read_phrase()
    if error is not None:
        return error
    result = generate_sentence(phrase)
    return jsonify({"input": phrase, "output": result})


# Route called by the frontend. NOTE(review): no API key is injected or
# checked here: it calls generate_sentence directly without authentication,
# so any client that can reach this URL bypasses the /generate key check.
# Confirm that is intended before relying on API_KEY for access control.
@app.route("/frontend_generate", methods=["POST"])
def frontend_generate():
    """Unauthenticated twin of /generate with the same request/response shape."""
    phrase, error = _read_phrase()
    if error is not None:
        return error
    result = generate_sentence(phrase)
    return jsonify({"input": phrase, "output": result})

if __name__ == "__main__":
    # Flask's built-in server, bound to every interface (not just localhost)
    # on port 7860.
    app.run(port=7860, host="0.0.0.0")