# 113 lines, 4.9 KiB, Python (file-viewer header residue)
from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Tuple, Optional
from zoneinfo import ZoneInfo

import psycopg

from EnergyPrice import EnergyPriceBase

# Default time zone of DynamicPricesProvider (its ``tz`` field); local "days" are
# computed in this zone.
WAW: ZoneInfo = ZoneInfo("Europe/Warsaw")

# One priced interval as cached by DynamicPricesProvider:
# (ts_start, ts_end, price_pln_net), looked up as the half-open range [ts_start, ts_end).
Interval = Tuple[datetime, datetime, float]

@dataclass
class DynamicPricesProvider(EnergyPriceBase):
    """
    Base class for dynamic prices stored in ``pricing.energy_prices`` (TimescaleDB).

    - ``rate(ts)``: returns the net price (PLN/kWh) valid at instant ``ts``.
      The first call for a given local day loads that WHOLE day into the cache.
    - ``has_rate(ts)`` is not defined in this class. NOTE(review): the original
      notes say it defaults to True (presumably inherited from EnergyPriceBase)
      and may be overridden in a subclass to check that data really exists -
      confirm against EnergyPriceBase.
    - Subclasses define ``SIDE``, ``PROVIDER`` and ``KIND``
      (e.g. ``'TGE'`` / ``'fixing_I'``).

    Either ``dsn`` or a ready psycopg3 ``conn`` must be supplied.
    """

    dsn: Optional[str] = None  # e.g. "postgresql://user:pass@localhost:5432/postgres"
    conn: Optional[psycopg.Connection] = None  # or pass a ready psycopg3 connection
    # Interpolated into the SQL text verbatim, so it must come from trusted config.
    table: str = "pricing.energy_prices"
    tz: ZoneInfo = WAW  # zone whose midnights define the cached "days"
    max_cached_days: int = 14  # how many distinct days to keep in the cache

    # Identifiers: override in a subclass. They are dataclass fields, so the
    # generated __init__ stores their "" default on the instance. The accessor
    # methods fall back to the class attribute, so a plain subclass override
    # such as ``PROVIDER = "TGE"`` is not shadowed by that instance-level "".
    SIDE: str = ""  # 'buy' or 'sell'
    PROVIDER: str = ""
    KIND: str = ""

    # Simple cache: key = local start of day, value = list of (start, end, price).
    _cache: Dict[datetime, List[Interval]] = field(default_factory=dict, init=False, repr=False)
    # Days in insertion order, oldest first; drives FIFO eviction.
    _cache_order: List[datetime] = field(default_factory=list, init=False, repr=False)

    # ---------- public API ----------

    def provider(self) -> str:
        """Return the provider identifier (``PROVIDER``, e.g. ``'TGE'``).

        Raises:
            NotImplementedError: if PROVIDER is empty on both instance and class.
        """
        return self._identifier("PROVIDER")

    def kind(self) -> str:
        """Return the price kind identifier (``KIND``, e.g. ``'fixing_I'``).

        Raises:
            NotImplementedError: if KIND is empty on both instance and class.
        """
        return self._identifier("KIND")

    def side(self) -> str:
        """Return the market side identifier (``SIDE``: 'buy' or 'sell').

        Raises:
            NotImplementedError: if SIDE is empty on both instance and class.
        """
        return self._identifier("SIDE")

    def rate(self, ts: datetime) -> float:
        """Return the net price (PLN/kWh) valid at instant ``ts``.

        A naive ``ts`` is interpreted as wall-clock time in ``self.tz``. The first
        call for a given local day loads the whole day into the cache.

        Raises:
            KeyError: if no loaded interval [start, end) contains ``ts``.
        """
        dt = self._to_local(ts)
        day_key = self._day_key(dt)
        self._ensure_day_cached(day_key)
        # Compare in UTC. When the DB returns datetimes carrying the same ZoneInfo
        # as ``dt``, Python compares them by wall-clock time and ignores ``fold``.
        # That picks the wrong interval in the repeated hour of the autumn DST change.
        dt_utc = dt.astimezone(timezone.utc)
        for start, end, price in self._cache.get(day_key, []):
            if start <= dt_utc < end:
                return price
        raise KeyError(f"No price for {dt.isoformat()} (provider={self.provider()}, kind={self.kind()})")

    def preload_day(self, day: datetime | None = None):
        """Optionally prefetch one local day into the cache.

        Args:
            day: any instant within the day to load; defaults to "now" in ``self.tz``.
        """
        when = day if day is not None else datetime.now(tz=self.tz)
        self._ensure_day_cached(self._day_key(self._to_local(when)))

    def clear_cache(self):
        """Drop all cached days."""
        self._cache.clear()
        self._cache_order.clear()

    # ---------- internals ----------

    def _identifier(self, name: str) -> str:
        """Return identifier attribute ``name``, falling back to the class attribute.

        Raises:
            NotImplementedError: if the value is empty on both instance and class.
        """
        value = getattr(self, name) or getattr(type(self), name, "")
        if not value:
            raise NotImplementedError(f"Subclass must define {name}")
        return value

    def _to_local(self, ts: datetime) -> datetime:
        """Convert ``ts`` to ``self.tz``; a naive ``ts`` is assumed to already be local."""
        return ts.replace(tzinfo=self.tz) if ts.tzinfo is None else ts.astimezone(self.tz)

    def _day_key(self, ts_local: datetime) -> datetime:
        """Return local midnight of the day containing ``ts_local`` (the cache key)."""
        return ts_local.replace(hour=0, minute=0, second=0, microsecond=0)

    def _ensure_conn(self) -> psycopg.Connection:
        """Return the DB connection, opening one from ``dsn`` if needed.

        A caller-supplied ``conn`` is used as-is. A closed connection is replaced
        by a fresh one only when ``dsn`` is available to reconnect with.

        Raises:
            RuntimeError: if neither ``dsn`` nor ``conn`` was provided.
        """
        if self.conn is not None and not (self.conn.closed and self.dsn):
            return self.conn
        if not self.dsn:
            raise RuntimeError("Provide dsn= or conn= to DynamicPricesProvider")
        # autocommit: only SELECTs are run here. Without it, psycopg opens an
        # implicit transaction that stays open on this long-lived connection,
        # and a single failed query would make every later query fail until rollback.
        self.conn = psycopg.connect(self.dsn, autocommit=True)
        return self.conn

    def _ensure_day_cached(self, day_start_local: datetime):
        """Load the day starting at ``day_start_local`` into the cache if absent.

        Evicts the oldest days (FIFO) beyond ``max_cached_days``, always keeping
        at least the day just loaded.
        """
        if day_start_local in self._cache:
            return
        intervals = self._fetch_day(day_start_local)
        if not intervals:
            # Don't cache a miss: the day's prices may simply not be in the table
            # yet, and a cached empty day would hide them until it is evicted or
            # clear_cache() is called. The cost is one query per lookup for such a day.
            return
        self._cache[day_start_local] = intervals
        self._cache_order.append(day_start_local)
        # Simple size limit by number of days. Never evict the day just added;
        # otherwise max_cached_days <= 0 would make every rate() call fail.
        while len(self._cache_order) > max(1, self.max_cached_days):
            oldest = self._cache_order.pop(0)
            self._cache.pop(oldest, None)

    def _fetch_day(self, day_start_local: datetime) -> List[Interval]:
        """Fetch all rows overlapping [day_start, day_start + 1 day) for this
        provider/kind/side, ordered by ts_start.

        Returns:
            List of (ts_start, ts_end, price) with the price converted to float.
        """
        # Aware-datetime arithmetic in a ZoneInfo is wall-clock based, so this
        # is the next local midnight even on 23h/25h DST days.
        day_end_local = day_start_local + timedelta(days=1)
        sql = f"""
            SELECT ts_start, ts_end, price_pln_net
            FROM {self.table}
            WHERE provider = %s
              AND kind = %s
              AND side = %s
              AND tstzrange(ts_start, ts_end, '[)') && tstzrange(%s::timestamptz, %s::timestamptz, '[)')
            ORDER BY ts_start
        """
        with self._ensure_conn().cursor() as cur:
            cur.execute(sql, (self.provider(), self.kind(), self.side(), day_start_local, day_end_local))
            rows = cur.fetchall()
        # price_pln_net presumably arrives as Decimal (numeric column) - confirm schema.
        return [(r[0], r[1], float(r[2])) for r in rows]