diff --git a/pyrb/controllers/api/main.py b/pyrb/controllers/api/main.py index 88a2099..8188e48 100644 --- a/pyrb/controllers/api/main.py +++ b/pyrb/controllers/api/main.py @@ -1,9 +1,13 @@ +from uuid import UUID + from fastapi import FastAPI, HTTPException from pydantic import BaseModel +from starlette.status import HTTP_201_CREATED from pyrb.controllers.api.deps import AccountServiceDep +from pyrb.enums import BrokerageType from pyrb.exceptions import InitializationError -from pyrb.models.account import Account +from pyrb.models.account import Account, AccountFactory app = FastAPI() @@ -12,6 +16,16 @@ class AccountResponse(BaseModel): account: Account +class AccountCreateRequest(BaseModel): + brokerage: BrokerageType + app_key: str + secret_key: str + + +class AccountCreateResponse(BaseModel): + account_id: UUID + + @app.get("/accounts/default", response_model=AccountResponse) async def get_default_account(account_service: AccountServiceDep) -> AccountResponse: try: @@ -20,3 +34,15 @@ async def get_default_account(account_service: AccountServiceDep) -> AccountResp raise HTTPException(status_code=404, detail="No accounts registered") from e return AccountResponse(account=account) + + +@app.post("/accounts", response_model=AccountCreateResponse, status_code=HTTP_201_CREATED) +async def create_account( + account_service: AccountServiceDep, body: AccountCreateRequest +) -> AccountCreateResponse: + account = AccountFactory.create( + brokerage=body.brokerage, app_key=body.app_key, app_secret=body.secret_key + ) + account_service.set(account) + + return AccountCreateResponse(account_id=account.id) diff --git a/pyrb/models/account.py b/pyrb/models/account.py index e69234c..ccce933 100644 --- a/pyrb/models/account.py +++ b/pyrb/models/account.py @@ -1,3 +1,4 @@ +import uuid from abc import ABC from typing import Annotated, Any @@ -9,6 +10,7 @@ class Account(BaseModel, ABC): brokerage: Annotated[BrokerageType, Field(...)] + id: uuid.UUID = Field(default_factory=uuid.uuid4) def to_toml(self) -> str: model_dict = self.model_dump()