diff --git a/pyrb/controllers/api/main.py b/pyrb/controllers/api/main.py index e0278b5..7f3c82c 100644 --- a/pyrb/controllers/api/main.py +++ b/pyrb/controllers/api/main.py @@ -2,7 +2,7 @@ from uuid import UUID from zoneinfo import ZoneInfo -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Query from fastapi.middleware.cors import CORSMiddleware from pydantic import AwareDatetime, BaseModel from starlette.status import HTTP_201_CREATED @@ -12,6 +12,7 @@ from pyrb.exceptions import InitializationError from pyrb.models.account import Account, AccountFactory from pyrb.models.order import Order, OrderPlacementResult +from pyrb.models.portfolio import PortfolioReturn from pyrb.models.position import Position from pyrb.services.rebalance import Rebalancer from pyrb.services.strategy.asset_allocate import AssetAllocationStrategyFactory @@ -52,6 +53,12 @@ class PortfolioResponse(BaseModel): positions: list[Position] +class PortfolioReturnsResponse(BaseModel): + start_dt: AwareDatetime + end_dt: AwareDatetime + returns: list[PortfolioReturn] + + class OrdersPrepareResponse(BaseModel): orders: list[Order] @@ -98,6 +105,23 @@ async def get_portfolio(context: RebalanceContextDep) -> PortfolioResponse: ) +@app.get("/portfolio/returns", response_model=PortfolioReturnsResponse) +async def fetch_portfolio_returns( + context: RebalanceContextDep, + start_dt: AwareDatetime = Query(), + end_dt: AwareDatetime = Query( + default_factory=lambda: datetime.datetime.now(ZoneInfo("Asia/Seoul")) + ), +) -> PortfolioReturnsResponse: + returns = context.portfolio.fetch_returns(start_dt, end_dt) + + return PortfolioReturnsResponse( + start_dt=start_dt, + end_dt=end_dt, + returns=returns, + ) + + @app.get("/strategies/{strategy_type}/orders", response_model=OrdersPrepareResponse) async def prepare_orders( context: RebalanceContextDep, diff --git a/pyrb/models/portfolio.py b/pyrb/models/portfolio.py new file mode 100644 index 0000000..96736dc --- /dev/null +++ b/pyrb/models/portfolio.py @@ -0,0 +1,7 @@ +from pydantic import AwareDatetime, BaseModel + + +class PortfolioReturn(BaseModel): + dt: AwareDatetime + rtn: float + pnl: float diff --git a/pyrb/repositories/brokerages/base/portfolio.py b/pyrb/repositories/brokerages/base/portfolio.py index ab436f6..474ea26 100644 --- a/pyrb/repositories/brokerages/base/portfolio.py +++ b/pyrb/repositories/brokerages/base/portfolio.py @@ -1,7 +1,8 @@ import abc -from pydantic import NonNegativeFloat +from pydantic import AwareDatetime, NonNegativeFloat +from pyrb.models.portfolio import PortfolioReturn from pyrb.models.position import Position @@ -73,6 +74,21 @@ def get_position_amount(self, symbol: str) -> NonNegativeFloat: """ ... + @abc.abstractmethod + def fetch_returns( + self, start_dt: AwareDatetime, end_dt: AwareDatetime + ) -> list[PortfolioReturn]: + """Fetches the returns of the portfolio for the given period. + + Args: + start_dt (AwareDatetime): The start date of the period. + end_dt (AwareDatetime): The end date of the period. + + Returns: + list[PortfolioReturn]: A list of PortfolioReturn objects. + """ + ... + @abc.abstractmethod def refresh(self) -> None: """Refreshes the portfolio object.""" diff --git a/pyrb/repositories/brokerages/ebest/portfolio.py b/pyrb/repositories/brokerages/ebest/portfolio.py index 1e5b602..67f4200 100644 --- a/pyrb/repositories/brokerages/ebest/portfolio.py +++ b/pyrb/repositories/brokerages/ebest/portfolio.py @@ -1,7 +1,10 @@ +import datetime from typing import Any +from zoneinfo import ZoneInfo -from pydantic import NonNegativeFloat +from pydantic import AwareDatetime, NonNegativeFloat +from pyrb.models.portfolio import PortfolioReturn from pyrb.models.position import Asset, Position from pyrb.repositories.brokerages.base.portfolio import Portfolio from pyrb.repositories.brokerages.ebest.client import EbestAPIClient @@ -56,6 +59,35 @@ def get_position_amount(self, symbol: str) -> NonNegativeFloat: position = self.get_position(symbol) return position.total_amount if position else 0 + def fetch_returns( + self, start_date: AwareDatetime, end_date: AwareDatetime + ) -> list[PortfolioReturn]: + path = "stock/accno" + content_type = "application/json; charset=UTF-8" + + headers = {"content-type": content_type, "tr_cd": "FOCCQ33600", "tr_cont": "N"} + + body = { + "FOCCQ33600InBlock1": { + "QrySrtDt": start_date.strftime("%Y%m%d"), + "QryEndDt": end_date.strftime("%Y%m%d"), + "TermTp": "1", + } + } + + response = self._api_client.send_request("POST", path, headers=headers, json=body) + print(response.json()) + return [ + PortfolioReturn( + dt=datetime.datetime.strptime(each["BaseDt"], "%Y%m%d").replace( + tzinfo=ZoneInfo("Asia/Seoul") + ), + rtn=float(each["TermErnrat"]) / 100, + pnl=each["EvalPnlAmt"], + ) + for each in response.json()["FOCCQ33600OutBlock3"] + ] + def refresh(self) -> None: self._serialized_portfolio = self._fetch_portfolio() diff --git a/tests/conftest.py b/tests/conftest.py index 9618ec6..4299a9e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,12 @@ import tempfile from collections.abc import Generator +from datetime import datetime from pathlib import Path import pytest from pyrb.models.order import Order +from pyrb.models.portfolio import PortfolioReturn from pyrb.models.position import Asset, Position from pyrb.models.price import CurrentPrice from pyrb.repositories.account import AccountRepository, LocalConfigAccountRepository @@ -62,6 +64,9 @@ def get_position_amount(self, symbol: str) -> float: position = self.get_position(symbol) return position.total_amount if position else 0 + def fetch_returns(self, start_dt: datetime, end_dt: datetime) -> list[PortfolioReturn]: + return [PortfolioReturn(dt=start_dt, rtn=0.0, pnl=0.0)] + def refresh(self) -> None: ... diff --git a/tests/controllers/test_api.py b/tests/controllers/test_api.py index ec2f698..78f1ee3 100644 --- a/tests/controllers/test_api.py +++ b/tests/controllers/test_api.py @@ -259,6 +259,20 @@ def test_get_portfolio(fake_rebalance_context: RebalanceContext) -> None: app.dependency_overrides.clear() +def test_fetch_portfolio_returns(fake_rebalance_context: RebalanceContext) -> None: + # Given + create_account() + app.dependency_overrides[context_dep] = lambda: fake_rebalance_context + + # When + response = client.get("/portfolio/returns?start_dt=2024-01-01T00:00:00.009Z") + + # Then + assert response.status_code == 200 + data = response.json() + assert "returns" in data + + def test_get_portfolio_without_account() -> None: # When response = client.get("/portfolio")