|
import io
import tempfile
from datetime import datetime

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yfinance as yf
from gym import Env
from gym.spaces import Box
from matplotlib.figure import Figure
from stable_baselines3 import PPO
|
|
|
class PortfolioEnv(Env):
    """Gym environment that rebalances a long-only portfolio once per period.

    At step ``t`` the agent observes the per-asset returns of period ``t`` and
    chooses weights that are applied to the returns of period ``t + 1``, so the
    agent never observes the return it is rewarded on.

    Args:
        returns: DataFrame of per-period returns, one column per asset and one
            row per period; must contain at least two rows.
        initial_balance: Starting portfolio value.

    Raises:
        ValueError: If ``returns`` has fewer than two rows.
    """

    def __init__(self, returns, initial_balance=10000):
        super().__init__()
        if len(returns) < 2:
            raise ValueError("PortfolioEnv needs at least two rows of returns.")
        self.returns = returns
        self.n_assets = returns.shape[1]
        self.initial_balance = initial_balance
        self.current_balance = initial_balance

        # Actions are raw allocation scores in [0, 1]; step() normalises them.
        self.action_space = Box(low=0, high=1, shape=(self.n_assets,), dtype=np.float32)
        self.observation_space = Box(
            low=-np.inf, high=np.inf, shape=(self.n_assets,), dtype=np.float32
        )

        # Most recently applied weights; equal-weighted until the first step.
        self.weights = np.ones(self.n_assets) / self.n_assets
        self.current_step = 0
        self.state = self._observation(0)

    def _observation(self, index):
        """Return row ``index`` of ``returns`` as a float32 vector (matches observation_space)."""
        return np.asarray(self.returns.iloc[index].values, dtype=np.float32)

    def step(self, action):
        """Apply ``action`` to the next period's returns and advance one period.

        Args:
            action: Array of length ``n_assets``; values are clipped to [0, 1]
                and normalised to sum to (approximately) one.

        Returns:
            ``(state, reward, done, info)`` in the classic gym API. ``reward`` is
            the portfolio balance after this step (callers use it as the running
            portfolio value); ``info`` holds the applied weights and the period's
            portfolio return.
        """
        action = np.clip(action, 0, 1)
        # The epsilon guards an all-zero action against division by zero; such
        # an action yields ~0 weights, i.e. the portfolio earns ~0 that period.
        weights = action / (np.sum(action) + 1e-8)
        self.weights = weights

        # The current observation is period `current_step`; rewarding on that
        # same row would leak the outcome to the agent, so use the next row.
        next_returns = self.returns.iloc[self.current_step + 1].values
        portfolio_return = float(np.dot(weights, next_returns))
        self.current_balance *= 1 + portfolio_return

        self.current_step += 1
        # After this, the next step() would need row len(returns), which does not exist.
        done = self.current_step >= len(self.returns) - 1
        self.state = self._observation(self.current_step)

        # NOTE(review): the absolute balance is an unusual RL reward (per-step
        # returns are more common); kept because callers read it as the value.
        reward = self.current_balance
        info = {"portfolio_return": portfolio_return, "weights": weights}
        return self.state, reward, done, info

    def reset(self):
        """Restore the starting balance and equal weights; return the first observation."""
        self.current_balance = self.initial_balance
        self.current_step = 0
        self.weights = np.ones(self.n_assets) / self.n_assets
        self.state = self._observation(0)
        return self.state
|
|
|
def fetch_data(tickers, start_date, end_date):
    """Download adjusted close prices for ``tickers`` between two dates.

    Args:
        tickers: Ticker symbols in a form ``yf.download`` accepts, e.g. a
            comma-separated string such as ``"AAPL, MSFT"``.
        start_date: Start date string (``YYYY-MM-DD``).
        end_date: End date string (``YYYY-MM-DD``).

    Returns:
        DataFrame of adjusted close prices, one column per ticker.

    Raises:
        ValueError: If the download fails, returns nothing, or returns a
            ticker with no prices at all.
    """
    try:
        # auto_adjust=False keeps the "Adj Close" column; newer yfinance
        # releases default to auto_adjust=True and no longer return it.
        raw = yf.download(
            tickers, start=start_date, end=end_date, auto_adjust=False, progress=False
        )
        data = raw["Adj Close"]
    except Exception as e:
        raise ValueError(f"Failed to fetch data: {e}") from e

    # Some yfinance versions return a Series for a single ticker; callers
    # need one column per asset.
    if isinstance(data, pd.Series):
        label = tickers.replace(",", " ").strip() if isinstance(tickers, str) else data.name
        data = data.to_frame(name=label)

    if data.empty:
        raise ValueError("No price data returned for the requested tickers and dates.")

    # A column that is entirely NaN means that ticker failed to download;
    # left in, it would make every row of the returns table drop out later.
    missing = [str(col) for col, all_nan in data.isna().all().items() if all_nan]
    if missing:
        raise ValueError(f"No price data for: {', '.join(missing)}")

    return data
|
|
|
def _replay_policy(model, env, initial_balance):
    """Run ``model`` deterministically through one episode; return (final weights, values)."""
    state = env.reset()
    done = False
    final_weights = []
    portfolio_values = [initial_balance]
    while not done:
        # deterministic=True so the reported allocation is not a random sample.
        action, _ = model.predict(state, deterministic=True)
        state, reward, done, _ = env.step(action)
        # Same clip-and-normalise the environment applies to the action.
        clipped = np.clip(action, 0, 1)
        final_weights = clipped / (np.sum(clipped) + 1e-8)
        portfolio_values.append(reward)
    return final_weights, portfolio_values


def _save_value_chart(portfolio_values):
    """Plot ``portfolio_values`` into a fresh temporary PNG and return its path."""
    # Figure (rather than pyplot) keeps no global state, and a per-call temp
    # file stops concurrent requests overwriting one shared chart file.
    fig = Figure(figsize=(10, 6))
    ax = fig.subplots()
    ax.plot(portfolio_values, label="Portfolio Value")
    ax.set_title("Portfolio Value Over Time")
    ax.set_xlabel("Time Steps")
    ax.set_ylabel("Portfolio Value")
    ax.legend()
    ax.grid()
    with tempfile.NamedTemporaryFile(
        prefix="portfolio_chart_", suffix=".png", delete=False
    ) as handle:
        fig.savefig(handle, format="png")
        return handle.name


def optimize_portfolio(tickers, start_date, end_date, initial_balance, total_timesteps=5000):
    """Train a PPO agent on historical returns and report its final allocation.

    Args:
        tickers: Comma-separated ticker symbols, passed to ``fetch_data``.
        start_date: Start date string (``YYYY-MM-DD``).
        end_date: End date string (``YYYY-MM-DD``).
        initial_balance: Starting portfolio value; must be positive.
        total_timesteps: PPO training budget.

    Returns:
        ``(weights, chart_path)``. ``weights`` maps ``"Asset_<i> (<ticker>)"`` to
        the weight chosen on the last step of a deterministic replay.
        ``chart_path`` is a PNG plot of the replayed portfolio value.

    Raises:
        ValueError: On a missing/non-positive balance or too little price
            history, in addition to anything ``fetch_data`` raises.
    """
    if initial_balance is None or initial_balance <= 0:
        raise ValueError("Initial investment must be a positive number.")

    data = fetch_data(tickers, start_date, end_date)
    if isinstance(data, pd.Series):
        data = data.to_frame(name=tickers.replace(",", " ").strip())
    returns = data.pct_change().dropna()
    if len(returns) < 2:
        raise ValueError("Not enough price history in the selected date range.")

    env = PortfolioEnv(returns, initial_balance=initial_balance)
    model = PPO("MlpPolicy", env, verbose=0)
    model.learn(total_timesteps=total_timesteps)

    final_weights, portfolio_values = _replay_policy(model, env, initial_balance)
    chart_path = _save_value_chart(portfolio_values)

    # Label by the returns table's own column order: the downloaded columns are
    # not guaranteed to follow the order the user typed the tickers in.
    weights = {
        f"Asset_{i + 1} ({str(ticker).strip()})": float(weight)
        for i, (ticker, weight) in enumerate(zip(returns.columns, final_weights))
    }
    return weights, chart_path
|
|
|
|
|
def run_optimization(tickers, start_date, end_date, initial_balance):
    """Gradio callback: validate the form input, then run ``optimize_portfolio``.

    Returns:
        ``(weights, chart_path)`` on success. On invalid input or any failure,
        returns ``({"error": message}, None)`` so the UI shows a message
        instead of a traceback.
    """
    # Whitespace-only text counts as empty; a cleared number field may arrive as None.
    tickers = (tickers or "").strip()
    start_date = (start_date or "").strip()
    end_date = (end_date or "").strip()
    if not tickers or not start_date or not end_date or initial_balance is None:
        return {"error": "Please fill all the fields."}, None

    try:
        start = datetime.strptime(start_date, "%Y-%m-%d")
        end = datetime.strptime(end_date, "%Y-%m-%d")
    except ValueError:
        return {"error": "Dates must use the YYYY-MM-DD format."}, None
    if start >= end:
        return {"error": "Start date must be before end date."}, None
    if initial_balance <= 0:
        return {"error": "Initial investment must be a positive number."}, None

    try:
        weights, chart_path = optimize_portfolio(tickers, start_date, end_date, initial_balance)
    except Exception as e:  # UI boundary: report every failure as a message
        return {"error": str(e)}, None
    return weights, chart_path
|
|
|
_APP_DESCRIPTION = """Enter stock tickers (e.g., AAPL, MSFT, TSLA), a date range, and your initial investment amount.

The app fetches historical data, runs AI optimization, and returns the optimized portfolio weights

along with a performance chart."""


def _build_interface():
    """Assemble the Gradio UI that feeds the form fields into ``run_optimization``."""
    # Order must match run_optimization(tickers, start_date, end_date, initial_balance).
    form_fields = [
        gr.Textbox(label="Enter Stock Tickers (comma-separated)", placeholder="AAPL, MSFT, TSLA"),
        gr.Textbox(label="Start Date (YYYY-MM-DD)", placeholder="2023-01-01"),
        gr.Textbox(label="End Date (YYYY-MM-DD)", placeholder="2023-12-31"),
        gr.Number(label="Initial Investment Amount", value=10000),
    ]
    # One component per element of run_optimization's returned pair.
    result_views = [
        gr.JSON(label="Optimized Portfolio Weights"),
        gr.Image(label="Portfolio Value Chart"),
    ]
    return gr.Interface(
        fn=run_optimization,
        inputs=form_fields,
        outputs=result_views,
        title="AI-Powered Portfolio Optimization",
        description=_APP_DESCRIPTION,
    )


interface = _build_interface()

if __name__ == "__main__":
    interface.launch()