Update app.py
Browse files
app.py
CHANGED
@@ -6,7 +6,6 @@ from stable_baselines3 import PPO
|
|
6 |
from gym import Env
|
7 |
from gym.spaces import Box
|
8 |
import matplotlib.pyplot as plt
|
9 |
-
import io
|
10 |
|
11 |
class PortfolioEnv(Env):
|
12 |
def __init__(self, returns, initial_balance=10000):
|
@@ -46,15 +45,16 @@ class PortfolioEnv(Env):
|
|
46 |
return self.state
|
47 |
|
48 |
def fetch_data(tickers, start_date, end_date):
|
49 |
-
|
50 |
-
|
|
|
|
|
|
|
51 |
|
52 |
def optimize_portfolio(tickers, start_date, end_date, initial_balance):
|
53 |
-
# Fetch real-time data
|
54 |
data = fetch_data(tickers, start_date, end_date)
|
55 |
returns = data.pct_change().dropna()
|
56 |
|
57 |
-
# Define the environment
|
58 |
env = PortfolioEnv(returns, initial_balance=initial_balance)
|
59 |
model = PPO("MlpPolicy", env, verbose=0)
|
60 |
model.learn(total_timesteps=5000)
|
@@ -76,21 +76,23 @@ def optimize_portfolio(tickers, start_date, end_date, initial_balance):
|
|
76 |
plt.ylabel("Portfolio Value")
|
77 |
plt.legend()
|
78 |
plt.grid()
|
79 |
-
|
|
|
|
|
|
|
80 |
plt.close()
|
81 |
|
82 |
-
# Prepare the output
|
83 |
weights = {f"Asset_{i + 1} ({tickers.split(',')[i].strip()})": float(weight) for i, weight in enumerate(portfolio_weights)}
|
84 |
-
return weights,
|
85 |
|
86 |
def run_optimization(tickers, start_date, end_date, initial_balance):
|
87 |
if not tickers or not start_date or not end_date:
|
88 |
-
return "
|
89 |
try:
|
90 |
-
weights,
|
91 |
-
return weights,
|
92 |
except Exception as e:
|
93 |
-
return
|
94 |
|
95 |
interface = gr.Interface(
|
96 |
fn=run_optimization,
|
@@ -105,11 +107,9 @@ interface = gr.Interface(
|
|
105 |
gr.Image(label="Portfolio Value Chart"),
|
106 |
],
|
107 |
title="AI-Powered Portfolio Optimization",
|
108 |
-
description="""
|
109 |
-
|
110 |
-
|
111 |
-
along with a performance chart.
|
112 |
-
"""
|
113 |
)
|
114 |
|
115 |
if __name__ == "__main__":
|
|
|
6 |
from gym import Env
|
7 |
from gym.spaces import Box
|
8 |
import matplotlib.pyplot as plt
|
|
|
9 |
|
10 |
class PortfolioEnv(Env):
|
11 |
def __init__(self, returns, initial_balance=10000):
|
|
|
45 |
return self.state
|
46 |
|
47 |
def fetch_data(tickers, start_date, end_date):
|
48 |
+
try:
|
49 |
+
data = yf.download(tickers, start=start_date, end=end_date)['Adj Close']
|
50 |
+
return data
|
51 |
+
except Exception as e:
|
52 |
+
raise ValueError(f"Failed to fetch data: {e}")
|
53 |
|
54 |
def optimize_portfolio(tickers, start_date, end_date, initial_balance):
|
|
|
55 |
data = fetch_data(tickers, start_date, end_date)
|
56 |
returns = data.pct_change().dropna()
|
57 |
|
|
|
58 |
env = PortfolioEnv(returns, initial_balance=initial_balance)
|
59 |
model = PPO("MlpPolicy", env, verbose=0)
|
60 |
model.learn(total_timesteps=5000)
|
|
|
76 |
plt.ylabel("Portfolio Value")
|
77 |
plt.legend()
|
78 |
plt.grid()
|
79 |
+
|
80 |
+
buffer = io.BytesIO()
|
81 |
+
plt.savefig(buffer, format="png")
|
82 |
+
buffer.seek(0)
|
83 |
plt.close()
|
84 |
|
|
|
85 |
weights = {f"Asset_{i + 1} ({tickers.split(',')[i].strip()})": float(weight) for i, weight in enumerate(portfolio_weights)}
|
86 |
+
return weights, buffer
|
87 |
|
88 |
def run_optimization(tickers, start_date, end_date, initial_balance):
|
89 |
if not tickers or not start_date or not end_date:
|
90 |
+
return {"error": "Please fill all the fields."}, None
|
91 |
try:
|
92 |
+
weights, chart_buffer = optimize_portfolio(tickers, start_date, end_date, initial_balance)
|
93 |
+
return weights, chart_buffer
|
94 |
except Exception as e:
|
95 |
+
return {"error": str(e)}, None
|
96 |
|
97 |
interface = gr.Interface(
|
98 |
fn=run_optimization,
|
|
|
107 |
gr.Image(label="Portfolio Value Chart"),
|
108 |
],
|
109 |
title="AI-Powered Portfolio Optimization",
|
110 |
+
description="""Enter stock tickers (e.g., AAPL, MSFT, TSLA), a date range, and your initial investment amount.
|
111 |
+
The app fetches historical data, runs AI optimization, and returns the optimized portfolio weights
|
112 |
+
along with a performance chart."""
|
|
|
|
|
113 |
)
|
114 |
|
115 |
if __name__ == "__main__":
|