omniverse1 commited on
Commit
3d7d95e
·
verified ·
1 Parent(s): b19777d

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +32 -119
utils.py CHANGED
@@ -4,7 +4,6 @@ import numpy as np
4
  import torch
5
  from datetime import datetime, timedelta
6
  import plotly.graph_objects as go
7
- import plotly.express as px
8
  from plotly.subplots import make_subplots
9
  import spaces
10
 
@@ -41,10 +40,7 @@ def calculate_technical_indicators(data):
41
  rs = gain / loss
42
  rsi = 100 - (100 / (1 + rs))
43
  return rsi
44
- indicators['rsi'] = {
45
- 'current': calculate_rsi(data['Close']).iloc[-1],
46
- 'values': calculate_rsi(data['Close'])
47
- }
48
  def calculate_macd(prices, fast=12, slow=26, signal=9):
49
  exp1 = prices.ewm(span=fast).mean()
50
  exp2 = prices.ewm(span=slow).mean()
@@ -53,14 +49,7 @@ def calculate_technical_indicators(data):
53
  histogram = macd - signal_line
54
  return macd, signal_line, histogram
55
  macd, signal_line, histogram = calculate_macd(data['Close'])
56
- indicators['macd'] = {
57
- 'macd': macd.iloc[-1],
58
- 'signal': signal_line.iloc[-1],
59
- 'histogram': histogram.iloc[-1],
60
- 'signal_text': 'BUY' if histogram.iloc[-1] > 0 else 'SELL',
61
- 'macd_values': macd,
62
- 'signal_values': signal_line
63
- }
64
  def calculate_bollinger_bands(prices, period=20, std_dev=2):
65
  sma = prices.rolling(window=period).mean()
66
  std = prices.rolling(window=period).std()
@@ -70,28 +59,11 @@ def calculate_technical_indicators(data):
70
  upper, middle, lower = calculate_bollinger_bands(data['Close'])
71
  current_price = data['Close'].iloc[-1]
72
  bb_position = (current_price - lower.iloc[-1]) / (upper.iloc[-1] - lower.iloc[-1])
73
- indicators['bollinger'] = {
74
- 'upper': upper.iloc[-1],
75
- 'middle': middle.iloc[-1],
76
- 'lower': lower.iloc[-1],
77
- 'position': 'UPPER' if bb_position > 0.8 else 'LOWER' if bb_position < 0.2 else 'MIDDLE'
78
- }
79
  sma_20_series = data['Close'].rolling(20).mean()
80
  sma_50_series = data['Close'].rolling(50).mean()
81
- indicators['moving_averages'] = {
82
- 'sma_20': sma_20_series.iloc[-1],
83
- 'sma_50': sma_50_series.iloc[-1],
84
- 'sma_200': data['Close'].rolling(200).mean().iloc[-1],
85
- 'ema_12': data['Close'].ewm(span=12).mean().iloc[-1],
86
- 'ema_26': data['Close'].ewm(span=26).mean().iloc[-1],
87
- 'sma_20_values': sma_20_series,
88
- 'sma_50_values': sma_50_series
89
- }
90
- indicators['volume'] = {
91
- 'current': data['Volume'].iloc[-1],
92
- 'avg_20': data['Volume'].rolling(20).mean().iloc[-1],
93
- 'ratio': data['Volume'].iloc[-1] / data['Volume'].rolling(20).mean().iloc[-1]
94
- }
95
  return indicators
96
 
97
  def generate_trading_signals(data, indicators):
@@ -146,59 +118,21 @@ def generate_trading_signals(data, indicators):
146
  signal_details.append(f"⚪ Normal volume ({volume_ratio:.1f}x avg)")
147
  total_signals = buy_signals + sell_signals
148
  signal_strength = (buy_signals / max(total_signals, 1)) * 100
149
- if buy_signals > sell_signals:
150
- overall_signal = "BUY"
151
- elif sell_signals > buy_signals:
152
- overall_signal = "SELL"
153
- else:
154
- overall_signal = "HOLD"
155
  recent_high = data['High'].tail(20).max()
156
  recent_low = data['Low'].tail(20).min()
157
- signals = {
158
- 'overall': overall_signal,
159
- 'strength': signal_strength,
160
- 'details': '\n'.join(signal_details),
161
- 'support': recent_low,
162
- 'resistance': recent_high,
163
- 'stop_loss': recent_low * 0.95 if overall_signal == "BUY" else recent_high * 1.05
164
- }
165
  return signals
166
 
167
  def get_fundamental_data(stock):
168
  try:
169
  info = stock.info
170
  history = stock.history(period="1d")
171
- fundamental_info = {
172
- 'name': info.get('longName', 'N/A'),
173
- 'current_price': history['Close'].iloc[-1] if not history.empty else 0,
174
- 'market_cap': info.get('marketCap', 0),
175
- 'pe_ratio': info.get('forwardPE', 0),
176
- 'dividend_yield': info.get('dividendYield', 0) * 100 if info.get('dividendYield') else 0,
177
- 'volume': history['Volume'].iloc[-1] if not history.empty else 0,
178
- 'info': f"""
179
- Sector: {info.get('sector', 'N/A')}
180
- Industry: {info.get('industry', 'N/A')}
181
- Market Cap: {format_large_number(info.get('marketCap', 0))}
182
- 52 Week High: {info.get('fiftyTwoWeekHigh', 'N/A')}
183
- 52 Week Low: {info.get('fiftyTwoWeekLow', 'N/A')}
184
- Beta: {info.get('beta', 'N/A')}
185
- EPS: {info.get('forwardEps', 'N/A')}
186
- Book Value: {info.get('bookValue', 'N/A')}
187
- Price to Book: {info.get('priceToBook', 'N/A')}
188
- """.strip()
189
- }
190
  return fundamental_info
191
  except Exception as e:
192
  print(f"Error getting fundamental data: {e}")
193
- return {
194
- 'name': 'N/A',
195
- 'current_price': 0,
196
- 'market_cap': 0,
197
- 'pe_ratio': 0,
198
- 'dividend_yield': 0,
199
- 'volume': 0,
200
- 'info': 'Unable to fetch fundamental data'
201
- }
202
 
203
  def format_large_number(num):
204
  if num >= 1e12:
@@ -218,49 +152,37 @@ def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
218
  prices = data['Close'].values.astype(np.float32)
219
  try:
220
  from chronos import BaseChronosPipeline
221
- except Exception as ie:
222
- return {
223
- 'values': [],
224
- 'dates': [],
225
- 'high_30d': 0,
226
- 'low_30d': 0,
227
- 'mean_30d': 0,
228
- 'change_pct': 0,
229
- 'summary': 'chronos package not installed. install with: pip install chronos-forecasting'
230
- }
231
  pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-base", device_map="auto")
232
  with torch.no_grad():
233
  forecast = pipeline.predict(context=torch.tensor(prices), prediction_length=prediction_days)
234
- if hasattr(forecast, 'mean'):
235
- mean_forecast = forecast.mean(dim=1).squeeze().cpu().numpy()
 
 
 
 
 
 
 
 
 
 
236
  else:
237
- mean_forecast = forecast.mean(dim=1).squeeze().cpu().numpy()
238
  pred_len = len(mean_forecast)
 
 
239
  last_price = prices[-1]
240
- predicted_high = np.max(mean_forecast) if pred_len > 0 else 0
241
- predicted_low = np.min(mean_forecast) if pred_len > 0 else 0
242
- predicted_mean = np.mean(mean_forecast) if pred_len > 0 else 0
243
- change_pct = ((predicted_mean - last_price) / last_price) * 100 if last_price != 0 and pred_len > 0 else 0
244
- return {
245
- 'values': mean_forecast,
246
- 'dates': pd.date_range(start=data.index[-1] + timedelta(days=1), periods=pred_len, freq='D') if pred_len > 0 else [],
247
- 'high_30d': predicted_high,
248
- 'low_30d': predicted_low,
249
- 'mean_30d': predicted_mean,
250
- 'change_pct': change_pct,
251
- 'summary': f"AI Model: Amazon Chronos-Bolt (Base)\nPrediction Period: {pred_len} days\nExpected Change: {change_pct:.2f}%\nConfidence: Medium\nNote: AI predictions are for reference only and not financial advice"
252
- }
253
  except Exception as e:
254
  print(f"Error in prediction: {e}")
255
- return {
256
- 'values': [],
257
- 'dates': [],
258
- 'high_30d': 0,
259
- 'low_30d': 0,
260
- 'mean_30d': 0,
261
- 'change_pct': 0,
262
- 'summary': f'Prediction unavailable due to model error: {e}'
263
- }
264
 
265
  def create_price_chart(data, indicators):
266
  fig = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.05, subplot_titles=('Price & Moving Averages', 'RSI', 'MACD'), row_width=[0.2, 0.2, 0.7])
@@ -275,15 +197,6 @@ def create_price_chart(data, indicators):
275
  fig.update_layout(title='Technical Analysis Dashboard', height=900, showlegend=True, xaxis_rangeslider_visible=False)
276
  return fig
277
 
278
- def create_technical_chart(data, indicators):
279
- fig = make_subplots(rows=2, cols=2, subplot_titles=('Bollinger Bands', 'Volume', 'Price vs MA', 'RSI Analysis'), specs=[[{"secondary_y": False}, {"secondary_y": False}], [{"secondary_y": False}, {"secondary_y": False}]])
280
- fig.add_trace(go.Scatter(x=data.index, y=data['Close'], name='Price', line=dict(color='black')), row=1, col=1)
281
- fig.add_trace(go.Bar(x=data.index, y=data['Volume'], name='Volume', marker_color='lightblue'), row=1, col=2)
282
- fig.add_trace(go.Scatter(x=data.index, y=data['Close'], name='Price', line=dict(color='black')), row=2, col=1)
283
- fig.add_trace(go.Scatter(x=data.index, y=indicators['moving_averages']['sma_20_values'], name='SMA 20', line=dict(color='orange', dash='dash')), row=2, col=1)
284
- fig.update_layout(title='Technical Indicators Overview', height=600, showlegend=False)
285
- return fig
286
-
287
  def create_prediction_chart(data, predictions):
288
  if not len(predictions['values']):
289
  return go.Figure()
 
4
  import torch
5
  from datetime import datetime, timedelta
6
  import plotly.graph_objects as go
 
7
  from plotly.subplots import make_subplots
8
  import spaces
9
 
 
40
  rs = gain / loss
41
  rsi = 100 - (100 / (1 + rs))
42
  return rsi
43
+ indicators['rsi'] = {'current': calculate_rsi(data['Close']).iloc[-1], 'values': calculate_rsi(data['Close'])}
 
 
 
44
  def calculate_macd(prices, fast=12, slow=26, signal=9):
45
  exp1 = prices.ewm(span=fast).mean()
46
  exp2 = prices.ewm(span=slow).mean()
 
49
  histogram = macd - signal_line
50
  return macd, signal_line, histogram
51
  macd, signal_line, histogram = calculate_macd(data['Close'])
52
+ indicators['macd'] = {'macd': macd.iloc[-1], 'signal': signal_line.iloc[-1], 'histogram': histogram.iloc[-1], 'signal_text': 'BUY' if histogram.iloc[-1] > 0 else 'SELL', 'macd_values': macd, 'signal_values': signal_line}
 
 
 
 
 
 
 
53
  def calculate_bollinger_bands(prices, period=20, std_dev=2):
54
  sma = prices.rolling(window=period).mean()
55
  std = prices.rolling(window=period).std()
 
59
  upper, middle, lower = calculate_bollinger_bands(data['Close'])
60
  current_price = data['Close'].iloc[-1]
61
  bb_position = (current_price - lower.iloc[-1]) / (upper.iloc[-1] - lower.iloc[-1])
62
+ indicators['bollinger'] = {'upper': upper.iloc[-1], 'middle': middle.iloc[-1], 'lower': lower.iloc[-1], 'position': 'UPPER' if bb_position > 0.8 else 'LOWER' if bb_position < 0.2 else 'MIDDLE'}
 
 
 
 
 
63
  sma_20_series = data['Close'].rolling(20).mean()
64
  sma_50_series = data['Close'].rolling(50).mean()
65
+ indicators['moving_averages'] = {'sma_20': sma_20_series.iloc[-1], 'sma_50': sma_50_series.iloc[-1], 'sma_200': data['Close'].rolling(200).mean().iloc[-1], 'ema_12': data['Close'].ewm(span=12).mean().iloc[-1], 'ema_26': data['Close'].ewm(span=26).mean().iloc[-1], 'sma_20_values': sma_20_series, 'sma_50_values': sma_50_series}
66
+ indicators['volume'] = {'current': data['Volume'].iloc[-1], 'avg_20': data['Volume'].rolling(20).mean().iloc[-1], 'ratio': data['Volume'].iloc[-1] / data['Volume'].rolling(20).mean().iloc[-1]}
 
 
 
 
 
 
 
 
 
 
 
 
67
  return indicators
68
 
69
  def generate_trading_signals(data, indicators):
 
118
  signal_details.append(f"⚪ Normal volume ({volume_ratio:.1f}x avg)")
119
  total_signals = buy_signals + sell_signals
120
  signal_strength = (buy_signals / max(total_signals, 1)) * 100
121
+ overall_signal = "BUY" if buy_signals > sell_signals else "SELL" if sell_signals > buy_signals else "HOLD"
 
 
 
 
 
122
  recent_high = data['High'].tail(20).max()
123
  recent_low = data['Low'].tail(20).min()
124
+ signals = {'overall': overall_signal, 'strength': signal_strength, 'details': '\n'.join(signal_details), 'support': recent_low, 'resistance': recent_high, 'stop_loss': recent_low * 0.95 if overall_signal == "BUY" else recent_high * 1.05}
 
 
 
 
 
 
 
125
  return signals
126
 
127
  def get_fundamental_data(stock):
128
  try:
129
  info = stock.info
130
  history = stock.history(period="1d")
131
+ fundamental_info = {'name': info.get('longName', 'N/A'), 'current_price': history['Close'].iloc[-1] if not history.empty else 0, 'market_cap': info.get('marketCap', 0), 'pe_ratio': info.get('forwardPE', 0), 'dividend_yield': info.get('dividendYield', 0) * 100 if info.get('dividendYield') else 0, 'volume': history['Volume'].iloc[-1] if not history.empty else 0, 'info': f"Sector: {info.get('sector', 'N/A')}\nIndustry: {info.get('industry', 'N/A')}\nMarket Cap: {format_large_number(info.get('marketCap', 0))}\n52 Week High: {info.get('fiftyTwoWeekHigh', 'N/A')}\n52 Week Low: {info.get('fiftyTwoWeekLow', 'N/A')}\nBeta: {info.get('beta', 'N/A')}\nEPS: {info.get('forwardEps', 'N/A')}\nBook Value: {info.get('bookValue', 'N/A')}\nPrice to Book: {info.get('priceToBook', 'N/A')}"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  return fundamental_info
133
  except Exception as e:
134
  print(f"Error getting fundamental data: {e}")
135
+ return {'name': 'N/A', 'current_price': 0, 'market_cap': 0, 'pe_ratio': 0, 'dividend_yield': 0, 'volume': 0, 'info': 'Unable to fetch fundamental data'}
 
 
 
 
 
 
 
 
136
 
137
  def format_large_number(num):
138
  if num >= 1e12:
 
152
  prices = data['Close'].values.astype(np.float32)
153
  try:
154
  from chronos import BaseChronosPipeline
155
+ except Exception:
156
+ return {'values': [], 'dates': [], 'high_30d': 0, 'low_30d': 0, 'mean_30d': 0, 'change_pct': 0, 'summary': 'chronos package not installed. install with: pip install chronos-forecasting'}
 
 
 
 
 
 
 
 
157
  pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-base", device_map="auto")
158
  with torch.no_grad():
159
  forecast = pipeline.predict(context=torch.tensor(prices), prediction_length=prediction_days)
160
+ if isinstance(forecast, torch.Tensor):
161
+ forecast_np = forecast.squeeze().cpu().numpy()
162
+ elif hasattr(forecast, 'numpy'):
163
+ forecast_np = forecast.numpy()
164
+ else:
165
+ forecast_np = np.array(forecast)
166
+ if forecast_np.ndim == 2:
167
+ mean_forecast = forecast_np.mean(axis=0)
168
+ elif forecast_np.ndim == 3:
169
+ mean_forecast = forecast_np.mean(axis=(0, 1))
170
+ elif forecast_np.ndim == 1:
171
+ mean_forecast = forecast_np
172
  else:
173
+ mean_forecast = np.array([])
174
  pred_len = len(mean_forecast)
175
+ if pred_len == 0:
176
+ return {'values': [], 'dates': [], 'high_30d': 0, 'low_30d': 0, 'mean_30d': 0, 'change_pct': 0, 'summary': 'Model did not return valid prediction output.'}
177
  last_price = prices[-1]
178
+ predicted_high = float(np.max(mean_forecast))
179
+ predicted_low = float(np.min(mean_forecast))
180
+ predicted_mean = float(np.mean(mean_forecast))
181
+ change_pct = ((predicted_mean - last_price) / last_price) * 100 if last_price != 0 else 0
182
+ return {'values': mean_forecast, 'dates': pd.date_range(start=data.index[-1] + timedelta(days=1), periods=pred_len, freq='D'), 'high_30d': predicted_high, 'low_30d': predicted_low, 'mean_30d': predicted_mean, 'change_pct': change_pct, 'summary': f"AI Model: Amazon Chronos-Bolt (Base)\nPrediction Period: {pred_len} days\nPredicted High: {predicted_high:.2f}\nPredicted Low: {predicted_low:.2f}\nExpected Change: {change_pct:.2f}%\nConfidence: Medium\nNote: AI predictions are for reference only and not financial advice"}
 
 
 
 
 
 
 
 
183
  except Exception as e:
184
  print(f"Error in prediction: {e}")
185
+ return {'values': [], 'dates': [], 'high_30d': 0, 'low_30d': 0, 'mean_30d': 0, 'change_pct': 0, 'summary': f'Prediction unavailable due to model error: {e}'}
 
 
 
 
 
 
 
 
186
 
187
  def create_price_chart(data, indicators):
188
  fig = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.05, subplot_titles=('Price & Moving Averages', 'RSI', 'MACD'), row_width=[0.2, 0.2, 0.7])
 
197
  fig.update_layout(title='Technical Analysis Dashboard', height=900, showlegend=True, xaxis_rangeslider_visible=False)
198
  return fig
199
 
 
 
 
 
 
 
 
 
 
200
  def create_prediction_chart(data, predictions):
201
  if not len(predictions['values']):
202
  return go.Figure()