ranczo-energy-price-scrapers/EnergyPriceProvider/DynamicPricesProvider.py
2025-08-28 12:58:36 +02:00

113 lines
4.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Tuple, Optional
from zoneinfo import ZoneInfo

import psycopg

from EnergyPrice import EnergyPriceBase
# Time zone used as the default for DynamicPricesProvider.tz (local calendar days).
WAW = ZoneInfo("Europe/Warsaw")
# One price interval as kept in the provider cache.
Interval = Tuple[datetime, datetime, float] # (ts_start, ts_end, price_pln_net)
@dataclass
class DynamicPricesProvider(EnergyPriceBase):
    """Base class for dynamic prices stored in ``pricing.energy_prices`` (Timescale).

    - ``rate(ts)`` returns the price (PLN/kWh net) for the given instant.
      The first call for a given local day loads the WHOLE day into the cache.
    - ``has_rate(ts)`` is not defined in this class; per the original notes it
      defaults to True and may be overridden in a subclass to check whether
      data is really present (presumably inherited from ``EnergyPriceBase`` --
      confirm).
    - Subclasses define ``PROVIDER``, ``KIND`` and ``SIDE``
      (e.g. ``'TGE'`` / ``'fixing_I'``).

    Either ``dsn`` or ``conn`` must be supplied before the first database
    access, otherwise ``RuntimeError`` is raised.
    """

    dsn: Optional[str] = None  # e.g. "postgresql://user:pass@localhost:5432/postgres"
    conn: Optional[psycopg.Connection] = None  # or pass a ready psycopg3 connection
    table: str = "pricing.energy_prices"
    tz: ZoneInfo = WAW
    max_cached_days: int = 14  # how many distinct days to keep in the cache

    # Identifiers; override them in the subclass.
    SIDE: str = ""  # 'buy' or 'sell'
    PROVIDER: str = ""
    KIND: str = ""

    # Simple cache: key = start of the local day, value = list of intervals
    # (start, end, price). ``_cache_order`` records insertion order for eviction.
    _cache: Dict[datetime, List[Interval]] = field(default_factory=dict, init=False, repr=False)
    _cache_order: List[datetime] = field(default_factory=list, init=False, repr=False)

    # ---------- public API ----------
    def provider(self) -> str:
        """Return the provider identifier.

        Raises:
            NotImplementedError: if ``PROVIDER`` is empty.
        """
        return self._identifier("PROVIDER")

    def kind(self) -> str:
        """Return the price-kind identifier.

        Raises:
            NotImplementedError: if ``KIND`` is empty.
        """
        return self._identifier("KIND")

    def side(self) -> str:
        """Return the market side identifier ('buy' or 'sell').

        Raises:
            NotImplementedError: if ``SIDE`` is empty.
        """
        return self._identifier("SIDE")

    def rate(self, ts: datetime) -> float:
        """Return the net price in PLN/kWh for the instant ``ts``.

        The whole local day containing ``ts`` is loaded into the cache on the
        first call for that day.

        Args:
            ts: The instant to price. A naive value is interpreted as local
                time in ``self.tz``; an aware value is converted to ``self.tz``.

        Raises:
            KeyError: if no stored interval covers ``ts``.
        """
        local_ts = self._to_local(ts)
        # Use the returned list rather than re-reading self._cache, so the
        # lookup still works when max_cached_days <= 0 evicts the day at once.
        intervals = self._ensure_day_cached(self._day_key(local_ts))
        # Compare in UTC: two aware datetimes sharing the same tzinfo object
        # are compared by wall-clock time with ``fold`` ignored, which would
        # confuse the two 02:00-03:00 hours on the autumn DST change day.
        instant = local_ts.astimezone(timezone.utc)
        for start, end, price in intervals:
            if start <= instant < end:
                return price
        raise KeyError(
            f"No price for {local_ts.isoformat()} "
            f"(provider={self.provider()}, kind={self.kind()})"
        )

    def preload_day(self, day: datetime | None = None):
        """Optionally prefetch one day; defaults to the current day in ``self.tz``."""
        reference = day or datetime.now(tz=self.tz)
        self._ensure_day_cached(self._day_key(self._to_local(reference)))

    def clear_cache(self):
        """Drop every cached day."""
        self._cache.clear()
        self._cache_order.clear()

    # ---------- internals ----------
    def _identifier(self, name: str) -> str:
        """Return the non-empty identifier stored under attribute ``name``.

        Falls back to the class attribute when the instance value is empty: a
        subclass that overrides the identifier as a plain class attribute
        still gets the base default ``""`` assigned on the instance by the
        generated dataclass ``__init__``.
        """
        value = getattr(self, name) or getattr(type(self), name, "")
        if not value:
            raise NotImplementedError(f"Subclass must define {name}")
        return value

    def _to_local(self, ts: datetime) -> datetime:
        """Return ``ts`` as an aware datetime in ``self.tz``.

        Naive values get ``self.tz`` attached without shifting the clock time;
        aware values are converted.
        """
        if ts.tzinfo is None:
            return ts.replace(tzinfo=self.tz)
        return ts.astimezone(self.tz)

    def _day_key(self, ts_local: datetime) -> datetime:
        """Return midnight of the day of ``ts_local`` (used as the cache key)."""
        return ts_local.replace(hour=0, minute=0, second=0, microsecond=0)

    def _ensure_conn(self) -> psycopg.Connection:
        """Return the connection, opening one from ``dsn`` on first use.

        Raises:
            RuntimeError: if neither ``conn`` nor ``dsn`` was provided.
        """
        if self.conn is not None:
            return self.conn
        if not self.dsn:
            raise RuntimeError("Provide dsn= or conn= to DynamicPricesProvider")
        # This class only issues SELECTs on the connection it opens itself.
        # Autocommit keeps that connection from sitting "idle in transaction"
        # and from being stuck in a failed transaction after one bad query.
        # A caller-supplied ``conn`` is deliberately left as configured.
        self.conn = psycopg.connect(self.dsn, autocommit=True)
        return self.conn

    def _ensure_day_cached(self, day_start_local: datetime) -> List[Interval]:
        """Make sure the given day is cached and return its intervals.

        NOTE: a day that returned no rows is cached as an empty list as well,
        so data inserted later is not seen until eviction or ``clear_cache()``.
        """
        cached = self._cache.get(day_start_local)
        if cached is not None:
            return cached
        intervals = self._fetch_day(day_start_local)
        self._cache[day_start_local] = intervals
        self._cache_order.append(day_start_local)
        # Simple cache limit by number of days: evict oldest-inserted first.
        while len(self._cache_order) > self.max_cached_days:
            oldest = self._cache_order.pop(0)
            self._cache.pop(oldest, None)
        return intervals

    def _fetch_day(self, day_start_local: datetime) -> List[Interval]:
        """Fetch from the DB all records overlapping [day_start, day_start + 1d)."""
        day_end_local = day_start_local + timedelta(days=1)
        # ``self.table`` is interpolated into the statement text, so it must
        # come from trusted configuration; all values are bound as parameters.
        sql = f"""
            SELECT ts_start, ts_end, price_pln_net
            FROM {self.table}
            WHERE provider = %s
              AND kind = %s
              AND side = %s
              AND tstzrange(ts_start, ts_end, '[)') && tstzrange(%s::timestamptz, %s::timestamptz, '[)')
            ORDER BY ts_start
        """
        params = (self.provider(), self.kind(), self.side(), day_start_local, day_end_local)
        with self._ensure_conn().cursor() as cur:
            cur.execute(sql, params)
            rows = cur.fetchall()
        # rows: List[Tuple[datetime, datetime, Decimal]]
        return [(row[0], row[1], float(row[2])) for row in rows]