Skip to content

Commit

Permalink
feat: draf forecast logic
Browse files Browse the repository at this point in the history
  • Loading branch information
WLM1ke committed Dec 8, 2024
1 parent 06b3986 commit d10663e
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 23 deletions.
5 changes: 4 additions & 1 deletion poptimizer/domain/domain.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import itertools
from datetime import date, datetime
from enum import StrEnum, auto, unique
from typing import Annotated, NewType, Protocol
from typing import Annotated, Final, NewType, Protocol

from pydantic import BaseModel, ConfigDict, PlainSerializer

Expand Down Expand Up @@ -31,6 +31,9 @@ class Revision(BaseModel):
]

Ticker = NewType("Ticker", str)
CashTicker: Final = Ticker("CASH")
PortfolioTicker: Final = Ticker("PORTFOLIO")

AccName = NewType("AccName", str)


Expand Down
42 changes: 38 additions & 4 deletions poptimizer/domain/portfolio/forecasts.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,43 @@
import itertools
from typing import Annotated

from pydantic import BaseModel, Field, NonNegativeFloat, NonNegativeInt, PlainSerializer, field_validator

from poptimizer.domain import domain


class Position(BaseModel):
ticker: domain.Ticker
mean: float = 0
std: NonNegativeFloat = 0
beta: float = 0
grad: float = 0


class Forecast(domain.Entity):
def init_day(
self,
day: domain.Day,
) -> None:
models: Annotated[
set[domain.UID],
PlainSerializer(
list,
return_type=list,
),
] = Field(default_factory=set)
forecasts: NonNegativeInt = 0
portfolio_ver: domain.Version = domain.Version(0)
portfolio: Position = Field(default_factory=lambda: Position(ticker=domain.PortfolioTicker))
cash: Position = Field(default_factory=lambda: Position(ticker=domain.CashTicker))
positions: list[Position] = Field(default_factory=list)

@field_validator("positions")
def _sorted_by_tickers(cls, positions: list[Position]) -> list[Position]:
positions_pairs = itertools.pairwise(positions)

if not all(cur.ticker < next_.ticker for cur, next_ in positions_pairs):
raise ValueError("tickers are not sorted")

return positions

def init_day(self, day: domain.Day) -> None:
self.models.clear()
self.forecasts = 0
self.day = day
19 changes: 3 additions & 16 deletions poptimizer/domain/portfolio/portfolio.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
from typing import Final, Self
from typing import Self

from pydantic import BaseModel, Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, model_validator

from poptimizer import errors
from poptimizer.domain import domain

_CashTicker: Final = domain.Ticker("CASH")


class Account(BaseModel):
cash: NonNegativeInt = 0
Expand Down Expand Up @@ -47,17 +45,6 @@ def _account_has_known_tickers(self) -> Self:

return self

@property
def value(self) -> float:
value = 0
for account in self.accounts.values():
value += account.cash

for ticker, shares in account.positions.items():
value += shares * self.securities[ticker].price

return value

def create_acount(self, name: domain.AccName) -> None:
if name in self.accounts:
raise errors.DomainError(f"account {name} already exists")
Expand Down Expand Up @@ -90,10 +77,10 @@ def update_position(self, name: domain.AccName, ticker: domain.Ticker, amount: N
if (account := self.accounts.get(name)) is None:
raise errors.DomainError(f"account {name} doesn't exist")

if ticker != _CashTicker and ticker not in self.securities:
if ticker != domain.CashTicker and ticker not in self.securities:
raise errors.DomainError(f"ticker {ticker} doesn't exist")

if ticker == _CashTicker:
if ticker == domain.CashTicker:
account.cash = amount

return
Expand Down
23 changes: 21 additions & 2 deletions poptimizer/use_cases/portfolio/forecasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,28 @@

class ForecastHandler:
async def __call__(
self, ctx: handler.Ctx, msg: handler.ModelDeleted | handler.ModelEvaluated
self,
ctx: handler.Ctx,
msg: handler.ModelDeleted | handler.ModelEvaluated,
) -> handler.ForecastsAnalyzed:
forecast = await ctx.get_for_update(forecasts.Forecast)
forecast.init_day(msg.day)
match msg:
case handler.ModelDeleted():
forecast.models -= {msg.uid}
case handler.ModelEvaluated():
if forecast.day != msg.day:
forecast.init_day(msg.day)

forecast.models.add(msg.uid)

# if len(forecast.models) ** 0.5 - forecast.forecasts**0.5 >= 1:
# await self._update_forecast(ctx, forecast)

return handler.ForecastsAnalyzed(day=msg.day)

# async def _update_forecast(self, ctx: handler.Ctx, forecast: forecasts.Forecast) -> None:
# async with asyncio.TaskGroup() as tg:
# model_tasks = [tg.create_task(ctx.get(evolve.Model, uid)) for uid in forecast.models]
# port = await tg.create_task(ctx.get(portfolio.Portfolio))

# tickers = tuple(sorted(port.securities))

0 comments on commit d10663e

Please sign in to comment.