Skip to content

Commit

Permalink
✨ Add API endpoint to fetch portfolio returns
Browse files Browse the repository at this point in the history
  • Loading branch information
mingi3314 committed Mar 31, 2024
1 parent 4d96e46 commit 4bb2bc9
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 3 deletions.
26 changes: 25 additions & 1 deletion pyrb/controllers/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from uuid import UUID
from zoneinfo import ZoneInfo

from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import AwareDatetime, BaseModel
from starlette.status import HTTP_201_CREATED
Expand All @@ -12,6 +12,7 @@
from pyrb.exceptions import InitializationError
from pyrb.models.account import Account, AccountFactory
from pyrb.models.order import Order, OrderPlacementResult
from pyrb.models.portfolio import PortfolioReturn
from pyrb.models.position import Position
from pyrb.services.rebalance import Rebalancer
from pyrb.services.strategy.asset_allocate import AssetAllocationStrategyFactory
Expand Down Expand Up @@ -52,6 +53,12 @@ class PortfolioResponse(BaseModel):
positions: list[Position]


class PortfolioReturnsResponse(BaseModel):
start_dt: AwareDatetime
end_dt: AwareDatetime
returns: list[PortfolioReturn]


class OrdersPrepareResponse(BaseModel):
orders: list[Order]

Expand Down Expand Up @@ -98,6 +105,23 @@ async def get_portfolio(context: RebalanceContextDep) -> PortfolioResponse:
)


@app.get("/portfolio/returns", response_model=PortfolioReturnsResponse)
async def fetch_portfolio_returns(
context: RebalanceContextDep,
start_dt: AwareDatetime = Query(),
end_dt: AwareDatetime = Query(
default_factory=lambda: datetime.datetime.now(ZoneInfo("Asia/Seoul"))
),
) -> PortfolioReturnsResponse:
returns = context.portfolio.fetch_returns(start_dt, end_dt)

return PortfolioReturnsResponse(
start_dt=start_dt,
end_dt=end_dt,
returns=returns,
)


@app.get("/strategies/{strategy_type}/orders", response_model=OrdersPrepareResponse)
async def prepare_orders(
context: RebalanceContextDep,
Expand Down
7 changes: 7 additions & 0 deletions pyrb/models/portfolio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from pydantic import AwareDatetime, BaseModel


class PortfolioReturn(BaseModel):
dt: AwareDatetime
rtn: float
pnl: float
18 changes: 17 additions & 1 deletion pyrb/repositories/brokerages/base/portfolio.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import abc

from pydantic import NonNegativeFloat
from pydantic import AwareDatetime, NonNegativeFloat

from pyrb.models.portfolio import PortfolioReturn
from pyrb.models.position import Position


Expand Down Expand Up @@ -73,6 +74,21 @@ def get_position_amount(self, symbol: str) -> NonNegativeFloat:
"""
...

@abc.abstractmethod
def fetch_returns(
self, start_dt: AwareDatetime, end_dt: AwareDatetime
) -> list[PortfolioReturn]:
"""Fetches the returns of the portfolio for the given period.
Args:
start_dt (AwareDatetime): The start date of the period.
end_dt (AwareDatetime): The end date of the period.
Returns:
list[PortfolioReturn]: A list of PortfolioReturn objects.
"""
...

@abc.abstractmethod
def refresh(self) -> None:
"""Refreshes the portfolio object."""
Expand Down
34 changes: 33 additions & 1 deletion pyrb/repositories/brokerages/ebest/portfolio.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import datetime
from typing import Any
from zoneinfo import ZoneInfo

from pydantic import NonNegativeFloat
from pydantic import AwareDatetime, NonNegativeFloat

from pyrb.models.portfolio import PortfolioReturn
from pyrb.models.position import Asset, Position
from pyrb.repositories.brokerages.base.portfolio import Portfolio
from pyrb.repositories.brokerages.ebest.client import EbestAPIClient
Expand Down Expand Up @@ -56,6 +59,35 @@ def get_position_amount(self, symbol: str) -> NonNegativeFloat:
position = self.get_position(symbol)
return position.total_amount if position else 0

def fetch_returns(
self, start_date: AwareDatetime, end_date: AwareDatetime
) -> list[PortfolioReturn]:
path = "stock/accno"
content_type = "application/json; charset=UTF-8"

headers = {"content-type": content_type, "tr_cd": "FOCCQ33600", "tr_cont": "N"}

body = {
"FOCCQ33600InBlock1": {
"QrySrtDt": start_date.strftime("%Y%m%d"),
"QryEndDt": end_date.strftime("%Y%m%d"),
"TermTp": "1",
}
}

response = self._api_client.send_request("POST", path, headers=headers, json=body)
print(response.json())
return [
PortfolioReturn(
dt=datetime.datetime.strptime(each["BaseDt"], "%Y%m%d").replace(
tzinfo=ZoneInfo("Asia/Seoul")
),
rtn=float(each["TermErnrat"]) / 100,
pnl=each["EvalPnlAmt"],
)
for each in response.json()["FOCCQ33600OutBlock3"]
]

def refresh(self) -> None:
self._serialized_portfolio = self._fetch_portfolio()

Expand Down
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import tempfile
from collections.abc import Generator
from datetime import datetime
from pathlib import Path

import pytest

from pyrb.models.order import Order
from pyrb.models.portfolio import PortfolioReturn
from pyrb.models.position import Asset, Position
from pyrb.models.price import CurrentPrice
from pyrb.repositories.account import AccountRepository, LocalConfigAccountRepository
Expand Down Expand Up @@ -62,6 +64,9 @@ def get_position_amount(self, symbol: str) -> float:
position = self.get_position(symbol)
return position.total_amount if position else 0

def fetch_returns(self, start_dt: datetime, end_dt: datetime) -> list[PortfolioReturn]:
return [PortfolioReturn(dt=start_dt, rtn=0.0, pnl=0.0)]

def refresh(self) -> None: ...


Expand Down
14 changes: 14 additions & 0 deletions tests/controllers/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,20 @@ def test_get_portfolio(fake_rebalance_context: RebalanceContext) -> None:
app.dependency_overrides.clear()


def test_fetch_portfolio_returns(fake_rebalance_context: RebalanceContext) -> None:
# Given
create_account()
app.dependency_overrides[context_dep] = lambda: fake_rebalance_context

# When
response = client.get("/portfolio/returns?start_dt=2024-01-01T00:00:00.009Z")

# Then
assert response.status_code == 200
data = response.json()
assert "returns" in data


def test_get_portfolio_without_account() -> None:
# When
response = client.get("/portfolio")
Expand Down

0 comments on commit 4bb2bc9

Please sign in to comment.