rublon-ssh/PAM/ssh/include/rublon/websockets.hpp
2025-07-18 14:13:06 +02:00

305 lines
11 KiB
C++

#pragma once
#include <cassert>
#include <chrono>
#include <cstdio>
#include <cstring>
#include <functional>
#include <memory_resource>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <rublon/configuration.hpp>
#include <rublon/error.hpp>
#include <rublon/json.hpp>
#include <rublon/memory.hpp>
#include <rublon/pam_action.hpp>
#include <rublon/static_string.hpp>
#include <rublon/utils.hpp>
#include <libwebsockets.h>
#include <unistd.h>
namespace rublon {
/// Outcome of a transaction confirmation as reported over the Rublon websocket.
enum TransactionConfirmationStatus {
    transactionConfirmed = 0, ///< the transaction was approved
    transactionDenied = 1     ///< the transaction was rejected (also the fallback for unknown event names)
};
/// Result of a transaction confirmation event received over the websocket.
struct RublonEventData {
    /// Defaults to denied so an event that was never filled in (e.g. a listen()
    /// timeout) fails closed. The zero value of the enum is transactionConfirmed,
    /// so leaving this uninitialized or value-initializing it would be unsafe.
    TransactionConfirmationStatus status{transactionDenied};
    StaticString< 32 > transactionID; /// TODO tid verification?
    /// Filled by the websocket receive handler when the event data contains a token.
    std::optional< StaticString< 60 > > accessToken;
};
class WebSocket {
std::reference_wrapper< const Configuration > _config;
std::string_view urlv;
bool event_received = false;
bool con_ok = false;
lws_context * context{nullptr};
lws * wsi{nullptr};
lws_context_creation_info info{};
lws_client_connect_info ccinfo{};
RublonEventData * currentEvent{nullptr};
std::pmr::string proxyUrl{};
public:
WebSocket(const Configuration & config) : _config{config}, urlv{_config.get().apiServer}, proxyUrl{_config.get().apiServer.get_allocator()} {
const auto & cfg = _config.get(); // only a alias to not use _config.get() all the time
auto lws_log_emit = [](int level, const char * line) {
LogLevel rlevel{};
if(level == LLL_ERR)
rlevel = LogLevel::Error;
if(level == LLL_WARN)
rlevel = LogLevel::Warning;
if(level == LLL_INFO)
rlevel = LogLevel::Info;
if(level == LLL_NOTICE)
rlevel = LogLevel::Debug;
if(level == LLL_DEBUG)
rlevel = LogLevel::Debug;
log(rlevel, "libwesockets: %s", line);
};
if(_config.get().logging) {
lws_set_log_level(LLL_ERR | LLL_WARN | LLL_NOTICE | LLL_INFO | LLL_DEBUG, lws_log_emit);
} else {
lws_set_log_level(LLL_ERR | LLL_WARN, lws_log_emit);
}
memset(&info, 0, sizeof(info));
memset(&ccinfo, 0, sizeof(ccinfo));
info.port = CONTEXT_PORT_NO_LISTEN;
info.protocols = protocols;
info.options = LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
if(cfg.proxyEnabled && (cfg.proxyType == "http" || cfg.proxyType == "https")) {
assert(cfg.proxyType.has_value());
assert(cfg.proxyHost.has_value());
log(LogLevel::Debug, "WebSocket using proxy");
// "username:password\@server:port"
if(cfg.proxyAuthRequired) {
proxyUrl.reserve(conservative_estimate(cfg.proxyUsername, cfg.proxyPass, cfg.proxyHost) + 10);
proxyUrl += *cfg.proxyUsername;
proxyUrl += ":";
proxyUrl += *cfg.proxyPass;
proxyUrl += "@";
}
proxyUrl += *cfg.proxyHost;
log(LogLevel::Debug, "WebSocket proxy %s", proxyUrl.c_str());
info.http_proxy_address = proxyUrl.c_str();
info.http_proxy_port = config.proxyPort.value_or(8080);
}
context = lws_create_context(&info);
const std::string_view prefix = "https://";
if(urlv.substr(0, prefix.size()) == prefix)
urlv.remove_prefix(prefix.size());
ccinfo.context = context;
ccinfo.address = urlv.data();
ccinfo.host = ccinfo.address;
ccinfo.port = 443;
ccinfo.ssl_connection = LCCSCF_USE_SSL;
ccinfo.path = "/ws/socket.io/?EIO=4&transport=websocket";
ccinfo.protocol = protocols[0].name;
ccinfo.pwsi = &wsi;
ccinfo.userdata = this;
log(LogLevel::Debug, "WebSocket Created connection to %s", urlv.data());
}
WebSocket(WebSocket && rhs) noexcept = default;
WebSocket & operator=(WebSocket && rhs) noexcept = default;
WebSocket(const WebSocket &) noexcept = delete;
WebSocket & operator=(const WebSocket &) noexcept = delete;
~WebSocket() noexcept {
if(context)
lws_context_destroy(context);
}
bool attachToTransactionConfirmationChannel(std::string_view transaction_id) {
StaticString< 128 > subscribe_message{};
unsigned char buf[128 + LWS_PRE] = {};
subscribe_message += R"msg(42["subscribe",{"channel":"transactionConfirmation.)msg";
subscribe_message += transaction_id.data();
subscribe_message += R"("}])";
memcpy(buf + LWS_PRE, subscribe_message.data(), subscribe_message.size());
log(LogLevel::Debug, "WebSocket send message: %s", subscribe_message.c_str());
int bytes_sent = lws_write(wsi, buf + LWS_PRE, subscribe_message.size(), LWS_WRITE_TEXT);
log(LogLevel::Debug, "WebSocket send: %d bytes", bytes_sent);
if(bytes_sent < ( int ) subscribe_message.size()) {
log(LogLevel::Error, "WebSocket failed to send subscribe message");
return false;
}
return true;
}
bool AttachToCore(std::string_view tid) {
log(LogLevel::Debug, "WebSocket attaching to rublon api at %s", ccinfo.address);
lws_client_connect_via_info(&ccinfo);
const int seconds = 10;
auto endtime = std::chrono::steady_clock::now() + std::chrono::seconds{seconds};
// we need to wait for connect, 10sec should be ok
while(!con_ok && std::chrono::steady_clock::now() < endtime) {
lws_service(context, 1000);
}
if(wsi == nullptr || !con_ok) {
log(LogLevel::Error, "WebSocket client connection failed");
return false;
} else {
log(LogLevel::Debug, "WebSocket client connection created");
}
if(not attachToTransactionConfirmationChannel(tid)) {
log(LogLevel::Error, "WebSocket transaction subscribtion was not send properly");
return false;
}
log(LogLevel::Debug, "WebSocket client connection OK");
return true;
}
RublonEventData listen() {
RublonEventData event;
event_received = false;
currentEvent = &event; // for callback use
const int seconds = 60;
auto endtime = std::chrono::steady_clock::now() + std::chrono::seconds{seconds};
log(LogLevel::Debug, "WebSocket waiting for events for %d seconds", seconds);
while(!event_received && std::chrono::steady_clock::now() < endtime) {
lws_service(context, 1000);
}
currentEvent = nullptr;
event_received = false;
return event;
}
static int callback_ws(struct lws * wsi, enum lws_callback_reasons reason, void * user, void * in, size_t len) {
WebSocket * _this = ( WebSocket * ) user;
switch(reason) {
case LWS_CALLBACK_CLIENT_ESTABLISHED:
log(LogLevel::Debug, "WebSocket connection established");
lws_callback_on_writable(wsi); // Request writable event to send handshake
_this->con_ok = true;
break;
case LWS_CALLBACK_CLIENT_WRITEABLE: {
// Perform the Socket.IO 4.x handshake (send `40` message)
const std::string_view handshake = "40";
unsigned char buf[64] = {};
memcpy(&buf[LWS_PRE], handshake.data(), handshake.size());
lws_write(wsi, &buf[LWS_PRE], handshake.size(), LWS_WRITE_TEXT);
log(LogLevel::Debug, "WebSocket Sent Socket.IO handshake");
break;
}
case LWS_CALLBACK_CLIENT_RECEIVE: {
std::string_view input(( const char * ) in, len); // not null terminated
auto readStatus = [](std::string_view input) {
if(input == "transactionConfirmed") {
return transactionConfirmed;
} else if(input == "transactionDenied") {
return transactionDenied;
} else {
return transactionDenied; /// TODO przedawnione
};
};
if(input.substr(0, 2) == "42") {
/// TODO assert _this
/// TODO assert currentEvent
/// TODO refactor to separate class
if(_this->currentEvent == nullptr)
return -1;
log(LogLevel::Debug, "WebSocket got %s", input.data());
size_t startPos = input.find("[\"") + 2;
size_t endPos = input.find("\",", startPos);
auto & event = *_this->currentEvent;
event.status = readStatus(input.substr(startPos, endPos - startPos));
startPos = endPos + 3;
endPos = input.find("\",", startPos);
auto fulltid = input.substr(startPos, endPos - startPos);
event.transactionID = fulltid.substr(fulltid.find(".", 0) + 1, 32);
startPos = endPos + 2;
auto jsonString = input.substr(startPos, input.length() - startPos - 1);
memory::Monotonic_1k_Resource mr;
RapidJSONPMRAlloc alloc{&mr};
Document dataJson{&alloc};
dataJson.Parse(jsonString.data(), jsonString.size());
const auto * data = JSONPointer{"/data", &alloc}.Get(dataJson);
const auto * redirectUrl = JSONPointer{"/redirectUrl", &alloc}.Get(dataJson);
if(data != nullptr) {
Document tokenJson{&alloc};
tokenJson.Parse(data->GetString(), data->GetStringLength());
const auto * token = JSONPointer{"/token", &alloc}.Get(tokenJson);
if(token != nullptr) {
event.accessToken = token->GetString();
} else {
log(LogLevel::Error, "WebSocket response does not contain access token");
}
_this->event_received = true;
} else if(redirectUrl != nullptr) {
log(LogLevel::Info, "WebSocket received deny message");
_this->event_received = true;
} else {
log(LogLevel::Error, "WebSocket event data incorrect");
return -1;
}
} else {
log(LogLevel::Debug, "WebSocket Not an confirmation event");
}
break;
}
case LWS_CALLBACK_CLIENT_CLOSED:
log(LogLevel::Debug, "WebSocket connection closed");
break;
case LWS_CALLBACK_CLIENT_CONNECTION_ERROR:
log(LogLevel::Error, "WebSocket connection error");
break;
default:
break;
}
return 0;
}
private:
static const struct lws_protocols protocols[];
};
// Defined `inline` (C++17) because this is a header: a plain out-of-class
// definition would be emitted by every translation unit that includes it and
// violate the one-definition rule (multiple-definition link errors).
inline const struct lws_protocols WebSocket::protocols[] = { //
    // name, callback, per_session_data_size, rx_buffer_size, id, user, tx_packet_size
    {"wss", WebSocket::callback_ws, 1024, 1024, 0, nullptr, 0},
    {nullptr, nullptr, 0, 0, 0, nullptr, 0} // list terminator
};
} // namespace rublon