from __future__ import annotations

from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone, date
from typing import Dict, List, Tuple, Optional

import psycopg
from psycopg.rows import dict_row

from EnergyPrice import EnergyPriceBase
from utils.time import WARSAW_TZ

# One price interval: (start_utc, end_utc, price_pln_net), half-open [start, end).
_Interval = Tuple[datetime, datetime, float]


@dataclass
class RDNProviderPG(EnergyPriceBase):
    """
    Read RDN rates from a PostgreSQL table (default ``pricing.energy_prices``).

    - ``rate(ts)`` returns the net price in PLN/kWh valid at timestamp ``ts``.
    - Rows are selected by overlap with the ``ts_range`` column (tstzrange, ``'[)'``)
      and their bounds are read from ``ts_start``/``ts_end``, so any split of the
      day (hourly, 15-minute, irregular) is supported.
    - Results are cached per local calendar day (in ``self.tz``). Days with no data
      are not cached, so prices published later are picked up on the next call.

    NOTE(review): ``self.tz`` and ``self.to_local_dt`` are not defined in this class;
    presumably they are provided by ``EnergyPriceBase`` — confirm.
    """

    dsn: str
    provider: str
    kind: str
    side: str  # 'buy' | 'sell'
    table: str = "pricing.energy_prices"

    _conn: Optional[psycopg.Connection] = field(default=None, init=False, repr=False)
    _day_cache: Dict[date, List[_Interval]] = field(default_factory=dict, init=False, repr=False)

    # --- Connection ---------------------------------------------------------

    def _connect(self) -> psycopg.Connection:
        """Return the cached connection, (re)opening it when missing or closed."""
        if self._conn is None or self._conn.closed:
            # autocommit: this class only reads. Without it psycopg opens an implicit
            # transaction on the first SELECT that is never committed, leaving the
            # session "idle in transaction" and, after any failed query, stuck in an
            # aborted transaction that rejects every subsequent query.
            self._conn = psycopg.connect(self.dsn, autocommit=True, row_factory=dict_row)
        return self._conn

    def close(self) -> None:
        """Close the connection (if open) and drop all cached days."""
        if self._conn is not None and not self._conn.closed:
            self._conn.close()
        self._conn = None
        self._day_cache.clear()

    # --- Loading and caching one day ---------------------------------------

    def _day_bounds_local(self, d: date) -> Tuple[datetime, datetime]:
        """
        Return ``(start, end)`` of local calendar day ``d`` as aware datetimes in ``self.tz``.

        ``end`` is the next local midnight. In absolute time the span therefore
        follows DST (e.g. 23 h or 25 h on transition days).
        """
        midnight = datetime(d.year, d.month, d.day)
        next_midnight = midnight + timedelta(days=1)
        localize = getattr(self.tz, "localize", None)
        if localize is not None:
            # pytz-style zones must be attached via localize(); tzinfo=... would pick LMT.
            return localize(midnight), localize(next_midnight)
        return midnight.replace(tzinfo=self.tz), next_midnight.replace(tzinfo=self.tz)

    def _load_day(self, d: date) -> List[_Interval]:
        """
        Return price intervals ``[(start_utc, end_utc, price_net), ...]`` for local day ``d``.

        SQL selects the rows whose ``ts_range`` overlaps the day. Each row is then
        clipped to the day bounds in Python and normalized to UTC. The rows are sorted
        by start and overlapping fragments are reconciled:

        - same price: touching or overlapping fragments are merged;
        - different price: the later fragment starts where the earlier one ends.

        Non-empty results are cached. An empty result is not cached, so a day whose
        prices are published later is re-queried.
        """
        cached = self._day_cache.get(d)
        if cached is not None:
            return cached

        start_local, end_local = self._day_bounds_local(d)
        start_utc = start_local.astimezone(timezone.utc)
        end_utc = end_local.astimezone(timezone.utc)

        # NOTE(review): the table name is interpolated verbatim; it must come from trusted config.
        query = f"""
            SELECT ts_start, ts_end, price_pln_net
            FROM {self.table}
            WHERE provider = %s
              AND kind = %s
              AND side = %s
              AND ts_range && tstzrange(%s, %s, '[)')
            ORDER BY ts_start;
        """
        with self._connect().cursor() as cur:
            cur.execute(query, (self.provider, self.kind, self.side, start_utc, end_utc))
            rows = cur.fetchall()

        clipped: List[_Interval] = []
        for row in rows:
            # Normalize to UTC. psycopg returns timestamptz in the session time zone, and
            # aware datetimes sharing one DST-observing tzinfo compare by wall clock,
            # which misorders the repeated hour of the autumn DST change.
            s = max(row["ts_start"].astimezone(timezone.utc), start_utc)
            e = min(row["ts_end"].astimezone(timezone.utc), end_utc)
            if e > s:
                clipped.append((s, e, float(row["price_pln_net"])))

        clipped.sort(key=lambda t: t[0])
        merged: List[_Interval] = []
        for s, e, p in clipped:
            if merged:
                ps, pe, pp = merged[-1]
                if s <= pe and p == pp:
                    # Touching or overlapping fragment with the same price: extend it.
                    merged[-1] = (ps, max(pe, e), pp)
                    continue
                if s < pe:
                    # Overlap with a different price: the earlier fragment wins up to pe.
                    s = pe
                    if s >= e:
                        continue
            merged.append((s, e, p))

        if merged:
            self._day_cache[d] = merged
        return merged

    # --- Public API: single rate at a timestamp -----------------------------

    def rate(self, ts: datetime) -> float:
        """
        Return the net price in PLN/kWh valid at timestamp ``ts``.

        The local calendar day of ``ts`` is loaded (or taken from the cache) and
        searched for the interval with ``start <= ts < end``.

        Raises:
            KeyError: if the day has no price data at all, or ``ts`` falls into a gap
                between intervals.
        """
        ts_local = self.to_local_dt(ts)
        d = ts_local.date()
        intervals = self._load_day(d)
        if not intervals:
            raise KeyError(f"Brak danych cenowych dla {d.isoformat()} ({self.provider}/{self.kind}/{self.side})")

        ts_utc = ts_local.astimezone(timezone.utc)
        # Binary search over the sorted, non-overlapping intervals.
        lo, hi = 0, len(intervals) - 1
        while lo <= hi:
            mid = (lo + hi) // 2
            s, e, p = intervals[mid]
            if ts_utc < s:
                hi = mid - 1
            elif ts_utc >= e:
                lo = mid + 1
            else:
                return p
        raise KeyError(f"Brak ceny dla {ts_local.isoformat()} (luka w dobie {d.isoformat()})")

    # --- Extra: HH:MM -> price export (debugging / reports) -----------------

    def day_schedule_local(self, d: date) -> List[Tuple[str, str, float]]:
        """
        Return ``[(from_HHMM, to_HHMM, price), ...]`` for local day ``d``, in LOCAL time.

        Adjacent intervals with the same price are merged. The last ``to`` is always
        ``'00:00'`` (the next local midnight). Returns ``[]`` when the day has no data.

        All clipping, merging and gap checks are done in UTC. Times are converted to
        local only for formatting, so both occurrences of the repeated hour on the
        autumn DST change are kept (they share the same HH:MM labels).

        Raises:
            KeyError: if the intervals do not reach the start or the end of the day.
        """
        start_local, end_local = self._day_bounds_local(d)
        start_utc = start_local.astimezone(timezone.utc)
        end_utc = end_local.astimezone(timezone.utc)
        intervals = self._load_day(d)
        if not intervals:
            return []

        merged: List[_Interval] = []
        for s, e, p in intervals:
            # Defensive clip; _load_day already clips to the same bounds.
            s = max(s.astimezone(timezone.utc), start_utc)
            e = min(e.astimezone(timezone.utc), end_utc)
            if e <= s:
                continue
            if merged and merged[-1][1] == s and merged[-1][2] == p:
                merged[-1] = (merged[-1][0], e, p)
            else:
                merged.append((s, e, p))

        if not merged:
            return []
        # After clipping, merged[0] cannot start before the day, nor merged[-1] end after it.
        if merged[0][0] > start_utc:
            # gap at the start of the day
            raise KeyError(f"Luka na początku doby {d}")
        if merged[-1][1] < end_utc:
            # gap at the end of the day
            raise KeyError(f"Luka na końcu doby {d}")

        def hhmm(x: datetime) -> str:
            return x.astimezone(self.tz).strftime("%H:%M")

        last = len(merged) - 1
        return [
            (hhmm(s), "00:00" if i == last else hhmm(e), p)
            for i, (s, e, p) in enumerate(merged)
        ]