From 51b14c57d244e4159601452fa39eb6b5a4a02827 Mon Sep 17 00:00:00 2001 From: rublon-bwi <134260122+rublon-bwi@users.noreply.github.com> Date: Thu, 21 Sep 2023 16:52:20 +0200 Subject: [PATCH] Bwi/memory management (#2) Improve memory management --- .../rublon/authentication_step_interface.hpp | 41 ++- PAM/ssh/include/rublon/configuration.hpp | 21 +- PAM/ssh/include/rublon/core_handler.hpp | 82 ++++-- PAM/ssh/include/rublon/curl.hpp | 78 ++++-- PAM/ssh/include/rublon/error.hpp | 129 +++++++--- PAM/ssh/include/rublon/init.hpp | 61 +++-- PAM/ssh/include/rublon/json.hpp | 11 +- PAM/ssh/include/rublon/method/OTP.hpp | 6 +- PAM/ssh/include/rublon/method/SMS.hpp | 3 +- .../include/rublon/method/method_select.hpp | 41 ++- .../rublon/method/passcode_based_auth.hpp | 66 +++-- PAM/ssh/include/rublon/pam.hpp | 3 +- PAM/ssh/include/rublon/rublon.hpp | 4 +- PAM/ssh/include/rublon/sign.hpp | 15 +- PAM/ssh/include/rublon/utils.hpp | 129 +++++++--- PAM/ssh/lib/pam.cpp | 26 +- PAM/ssh/tests/CMakeLists.txt | 1 - PAM/ssh/tests/core_handler_tests.cpp | 55 ++-- PAM/ssh/tests/core_response_generator.hpp | 78 +----- PAM/ssh/tests/gtest_matchers.hpp | 39 +++ PAM/ssh/tests/http_mock.hpp | 238 +++++++++++++++--- PAM/ssh/tests/init_test.cpp | 108 +++----- PAM/ssh/tests/pam_info_mock.hpp | 16 +- PAM/ssh/tests/passcode_auth_tests.cpp | 38 ++- PAM/ssh/tests/rublon_tests.cpp | 7 - PAM/ssh/tests/sign_tests.cpp | 14 +- PAM/ssh/tests/utils_tests.cpp | 119 +++++---- 27 files changed, 919 insertions(+), 510 deletions(-) create mode 100644 PAM/ssh/tests/gtest_matchers.hpp delete mode 100644 PAM/ssh/tests/rublon_tests.cpp diff --git a/PAM/ssh/include/rublon/authentication_step_interface.hpp b/PAM/ssh/include/rublon/authentication_step_interface.hpp index d3e056c..6be7da0 100644 --- a/PAM/ssh/include/rublon/authentication_step_interface.hpp +++ b/PAM/ssh/include/rublon/authentication_step_interface.hpp @@ -1,8 +1,8 @@ #pragma once +#include #include #include -#include #include namespace rublon { @@ -14,7 +14,7 @@ class AuthenticationStep { std::string _tid; public: - AuthenticationStep() {} + AuthenticationStep() = default; AuthenticationStep(std::string systemToken, std::string tid) : _systemToken{std::move(systemToken)}, _tid{std::move(tid)} {} template < typename Handler_t > @@ -30,33 +30,28 @@ class AuthenticationStep { } protected: - void addSystemToken(Document & body, RapidJSONPMRAlloc & alloc) const { + void addSystemToken(Document & body) const { + auto & alloc = body.GetAllocator(); body.AddMember("systemToken", Value{this->_systemToken.c_str(), alloc}, alloc); } - void addTid(Document & body, RapidJSONPMRAlloc & alloc) const { + void addTid(Document & body) const { + auto & alloc = body.GetAllocator(); body.AddMember("tid", Value{this->_tid.c_str(), alloc}, alloc); } - template < typename HandlerReturn_t > - Error coreErrorHandler(const HandlerReturn_t & /*coreResponse*/) const { -// switch(coreResponse.error().errorClass) { -// case CoreHandlerError::ErrorClass::BadSigature: -// log(LogLevel::Error, "ErrorClass::BadSigature"); -// return PamAction::decline; -// case CoreHandlerError::ErrorClass::CoreException: /// TODO exception handling -// log(LogLevel::Error, "ErrorClass::CoreException"); -// return PamAction::decline; /// TODO accept? -// case CoreHandlerError::ErrorClass::ConnectionError: -// log(LogLevel::Error, "ErrorClass::ConnectionError"); -// return PamAction::decline; /// TODO decline? -// case CoreHandlerError::ErrorClass::BrokenData: -// log(LogLevel::Error, "ErrorClass::BrokenData"); -// return PamAction::decline; -// } -// return PamAction::decline; - - return Error{Critical{}}; + Error coreErrorHandler(const Error & coreResponse) const { + auto category = coreResponse.category(); + switch(category) { + case Error::k_SockerError: + log(LogLevel::Error, "Socker error, forcing override"); + return Error{PamBaypass{}}; + case Error::k_CoreHandlerError: + return Error{PamBaypass{}}; + default: + Critical{}; + } + return Critical{}; } }; diff --git a/PAM/ssh/include/rublon/configuration.hpp b/PAM/ssh/include/rublon/configuration.hpp index 8cb50c2..6ac5fef 100644 --- a/PAM/ssh/include/rublon/configuration.hpp +++ b/PAM/ssh/include/rublon/configuration.hpp @@ -26,31 +26,28 @@ class Configuration { bool enablePasswdEmail; bool logging; bool autopushPrompt; - }; - - Parameters parameters; + bool bypass; + } parameters; }; class ConfigurationFactory { public: - ConfigurationFactory(){}; + ConfigurationFactory() = default; std::optional< Configuration > systemConfig() { - std::array< char, 8 * 1024 > configBuffer; - std::pmr::monotonic_buffer_resource mr{configBuffer.data(), configBuffer.size()}; - + memory::MonotonicStackResource< 8 * 1024 > stackResource; using Params = Configuration::Parameters; Params configValues; std::ifstream file(std::filesystem::path{"/etc/rublon.config"}); - std::pmr::string line{&mr}; + std::pmr::string line{&stackResource}; line.reserve(100); - std::pmr::map< std::pmr::string, std::pmr::string > parameters{&mr}; + std::pmr::map< std::pmr::string, std::pmr::string > parameters{&stackResource}; while(std::getline(file, line)) { - std::pmr::string key{&mr}; - std::pmr::string value{&mr}; + std::pmr::string key{&stackResource}; + std::pmr::string value{&stackResource}; if(!line.length()) continue; @@ -62,7 +59,7 @@ class ConfigurationFactory { key = line.substr(0, posEqual); value = line.substr(posEqual + 1); - parameters.emplace(std::move(key), std::move(value)); + parameters[std::move(key)] = std::move(value); } auto saveStr = [&](auto member) { // diff --git a/PAM/ssh/include/rublon/core_handler.hpp b/PAM/ssh/include/rublon/core_handler.hpp index e35c0fb..a9d1e88 100644 --- a/PAM/ssh/include/rublon/core_handler.hpp +++ b/PAM/ssh/include/rublon/core_handler.hpp @@ -18,13 +18,10 @@ template < typename HttpHandler = CURL > class CoreHandler : public CoreHandlerInterface< CoreHandler< HttpHandler > > { std::string secretKey; std::string url; + bool bypass; - std::pmr::string xRublonSignature(std::pmr::memory_resource & mr, std::string_view body) const { - return {signData(body, secretKey.c_str()).data(), &mr}; - } - - void signRequest(std::pmr::monotonic_buffer_resource & mr, Request & request) const { - request.headers["X-Rublon-Signature"] = xRublonSignature(mr, request.body); + void signRequest(Request & request) const { + request.headers["X-Rublon-Signature"] = std::pmr::string{signData(request.body, secretKey).data(), request.headers.get_allocator()}; } bool responseSigned(const Response & response) const { @@ -41,19 +38,20 @@ class CoreHandler : public CoreHandlerInterface< CoreHandler< HttpHandler > > { public: CoreHandler(const Configuration & config) - : secretKey{config.parameters.secretKey}, url{config.parameters.apiServer}, http{[]() { return Response{}; }} {} + : secretKey{config.parameters.secretKey}, url{config.parameters.apiServer}, bypass{config.parameters.bypass}, http{} {} - tl::expected< Response, Error > validateSignature(const Response & response) const { + tl::expected< std::reference_wrapper< const Response >, Error > validateSignature(const Response & response) const { if(not responseSigned(response)) { log(LogLevel::Error, "CoreHandlerError::BadSigature"); return tl::unexpected{CoreHandlerError{CoreHandlerError::BadSigature}}; } + return response; } tl::expected< Document, Error > validateResponse(const Response & response) const { - RapidJSONPMRStackAlloc< 4 * 1024 > alloc{}; - Document resp{&alloc}; + RapidJSONPMRAlloc alloc{memory::default_resource()}; + Document resp{}; resp.Parse(response.body.c_str()); if(resp.HasParseError() or not resp.HasMember("result")) { @@ -64,37 +62,71 @@ class CoreHandler : public CoreHandlerInterface< CoreHandler< HttpHandler > > { if(resp["result"].HasMember("exception")) { const auto & exception = resp["result"]["exception"].GetString(); log(LogLevel::Error, "rublon Core exception %s", exception); - return tl::unexpected{CoreHandlerError{CoreHandlerError::CoreException, exception}}; + return handleCoreException(exception); } return resp; } + + tl::unexpected< Error > handleCoreException(std::string_view exceptionString) const { + if(exceptionString == "UserBypassedException" or exceptionString == "UserNotFoundException") + return tl::unexpected{Error{PamBaypass{}}}; + else + return tl::unexpected{ + Error{CoreHandlerError{CoreHandlerError::CoreException, std::string{exceptionString.data(), exceptionString.size()}}}}; + } - tl::expected< Document, Error > request(std::string_view path, const Document & body) const { - const auto validateSignature = [this](const auto & arg) { return this->validateSignature(arg); }; - const auto validateResponse = [this](const auto & arg) { return this->validateResponse(arg); }; + tl::unexpected< Error > handleHttpError() const { + if(bypass) { + log(LogLevel::Warning, "User login bypass"); + return tl::unexpected{Error{PamBaypass{}}}; + } else { + log(LogLevel::Warning, "User login deny due to HTTP error"); + return tl::unexpected{Error{PamDeny{}}}; + } + } - RapidJSONPMRStackAlloc< 1 * 1024 > alloc{}; + tl::expected< Document, Error > handleError(const Error & error) const { + if(error.is< HttpError >() and error.hasClass(HttpError::Error)) { + return handleHttpError(); + } + + return tl::unexpected{Error{error}}; + } + + template < typename T > + static void stringifyTo(const Document & body, T & to) { + memory::Monotonic_2k_HeapResource tmpResource; + RapidJSONPMRAlloc alloc{&tmpResource}; StringBuffer jsonStr{&alloc}; Writer writer{jsonStr, &alloc}; body.Accept(writer); + to = jsonStr.GetString(); + } - std::byte _buffer[2 * 1024]; - std::pmr::monotonic_buffer_resource mr{_buffer, sizeof(_buffer)}; + tl::expected< Document, Error > request(std::string_view path, const Document & body) const { + memory::StrictMonotonic_2k_HeapResource memoryResource; - Request request{mr}; - request.headers["Content-Type"] = "application/json"; - request.headers["Accept"] = "application/json"; + const auto validateSignature = [this](const auto & arg) { return this->validateSignature(arg); }; + const auto validateResponse = [this](const auto & arg) { return this->validateResponse(arg); }; + const auto handleError = [this](const auto & error) { return this->handleError(error); }; + const auto pmrs = [&](auto txt) { return std::pmr::string{txt, &memoryResource}; }; - request.body = jsonStr.GetString(); + Request request{&memoryResource}; - signRequest(mr, request); - std::pmr::string uri{url + path.data(), &mr}; + stringifyTo(body, request.body); - return http.request(uri, request) + request.headers["Content-Type"] = pmrs("application/json"); + request.headers["Accept"] = pmrs("application/json"); + signRequest(request); + + std::pmr::string uri{url + path.data(), &memoryResource}; + + return http + .request(uri, request) // .and_then(validateSignature) .and_then(validateResponse) - .or_else([](const Error & e) -> tl::expected< Document, Error > { return tl::unexpected{e}; }); + .or_else(handleError); } }; diff --git a/PAM/ssh/include/rublon/curl.hpp b/PAM/ssh/include/rublon/curl.hpp index 2e303de..dfec7d0 100644 --- a/PAM/ssh/include/rublon/curl.hpp +++ b/PAM/ssh/include/rublon/curl.hpp @@ -13,8 +13,8 @@ #include -#include #include +#include #include @@ -29,32 +29,56 @@ namespace { } // namespace struct Request { + std::pmr::memory_resource * _mr; + std::pmr::map< std::pmr::string, std::pmr::string > headers; std::pmr::string body; - Request(std::pmr::memory_resource & mr) : headers{&mr}, body{&mr} {} + public: + Request(std::pmr::memory_resource * mr) : _mr{mr}, headers{_mr}, body{_mr} {}; + + Request(const Request & res) = delete; + Request & operator=(const Request & res) = delete; + + Request(Request && res) = delete; + Request & operator=(Request &&) = delete; + + std::pmr::memory_resource * get_allocator() const noexcept { + return _mr; + } }; struct Response { - std::map< std::string, std::string > headers; - std::string body; + std::pmr::memory_resource * _mr; + + std::pmr::map< std::pmr::string, std::pmr::string > headers; + std::pmr::string body; + + public: + Response(std::pmr::memory_resource * mr) : _mr{mr}, headers{_mr}, body{_mr} {}; + + Response(const Response & res) = delete; + Response & operator=(const Response & res) = delete; + + Response(Response && res) noexcept = default; + Response & operator=(Response && res) noexcept = default; + + std::pmr::memory_resource * get_allocator() const noexcept { + return _mr; + } }; class CURL { std::unique_ptr< ::CURL, void (*)(::CURL *) > curl; - std::function< Response() > _responseFactory; public: - CURL(std::function< Response() > responseFactory) - : curl{std::unique_ptr< ::CURL, void (*)(::CURL *) >(curl_easy_init(), curl_easy_cleanup)}, - _responseFactory{std::move(responseFactory)} {} - - tl::expected< Response, Error > request(std::string_view uri, const Request & request) const { - std::array< char, 16 * 1024 > buffer = {}; - std::pmr::monotonic_buffer_resource mr{buffer.data(), buffer.size()}; - - std::pmr::string response_data{&mr}; - response_data.reserve(15000); + CURL() : curl{std::unique_ptr< ::CURL, void (*)(::CURL *) >(curl_easy_init(), curl_easy_cleanup)} {} + + tl::expected< Response, HttpError > request(std::string_view uri, const Request & request) const { + memory::MonotonicStackResource< 8 * 1024 > stackResource; + + std::pmr::string response_data{&stackResource}; + response_data.reserve(6000); /// TODO this can be done on stack using pmr auto curl_headers = std::unique_ptr< curl_slist, void (*)(curl_slist *) >(nullptr, curl_slist_free_all); @@ -74,20 +98,30 @@ class CURL { curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_data); log(LogLevel::Debug, "%s Request send, uri:%s body:\n%s\n", "CURL", uri.data(), request.body.c_str()); - auto res = curl_easy_perform(curl.get()); + const auto res = curl_easy_perform(curl.get()); + if(res != CURLE_OK) { log(LogLevel::Error, "%s No response from Rublon server err:{%s}", "CURL", curl_easy_strerror(res)); - return tl::unexpected{Error{SocketError{SocketError::Timeout}}}; + return tl::unexpected{HttpError{HttpError::Timeout, 0}}; } - + + long http_code = 0; + curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, http_code); + + if(http_code >= 400) { + log(LogLevel::Error, "%s response with code %d ", "CURL", http_code); + return tl::unexpected{HttpError{HttpError::Error, http_code}}; + } + log(LogLevel::Debug, "Response:\n%s\n", response_data.c_str()); - Response response; - + long size; curl_easy_getinfo(curl.get(), CURLINFO_HEADER_SIZE, &size); - response.headers = details::headers({response_data.data(), static_cast< std::size_t >(size)}); - response.body = response_data.substr(size); + Response response{memory::default_resource()}; + + details::headers(response_data, response.headers); + response.body = response_data.substr(size); return response; } diff --git a/PAM/ssh/include/rublon/error.hpp b/PAM/ssh/include/rublon/error.hpp index 4d573fd..599d815 100644 --- a/PAM/ssh/include/rublon/error.hpp +++ b/PAM/ssh/include/rublon/error.hpp @@ -5,76 +5,94 @@ #include namespace rublon { - -class SocketError { +class NoError { public: - enum ErrorClass { Timeout }; + static constexpr int errorClass = 0; +}; - SocketError() : errorClass{Timeout} {} - SocketError(ErrorClass e) : errorClass{e} {} - SocketError(ErrorClass e, std::string r) : errorClass{e}, reson{std::move(r)} {} +class HttpError { + public: + enum ErrorClass { Timeout, Error }; + + constexpr HttpError() : errorClass{Timeout}, httpCode(200) {} + constexpr HttpError(ErrorClass e, long httpCode) : errorClass{e}, httpCode(httpCode) {} ErrorClass errorClass; - std::string reson; - - // error_category interface - public: - const char * name() const noexcept { - return "SockerError"; - } - std::string message(int) const { - return ""; - } + long httpCode; }; class CoreHandlerError { public: - enum ErrorClass { BadSigature, CoreException, ConnectionError, BrokenData }; + enum ErrorClass { BadSigature, CoreException, BrokenData }; CoreHandlerError(ErrorClass e) : errorClass{e} {} CoreHandlerError(ErrorClass e, std::string r) : errorClass{e}, reson{std::move(r)} {} ErrorClass errorClass; std::string reson; - const char * name() const noexcept { - return "CoreHandlerError"; - } - std::string message(int) const { - return ""; - } }; class MethodError { public: enum ErrorClass { BadMethod }; - MethodError(ErrorClass e) : errorClass{e} {} - MethodError(ErrorClass e, std::string r) : errorClass{e}, reson{std::move(r)} {} + constexpr MethodError(ErrorClass e) : errorClass{e} {} ErrorClass errorClass; - std::string reson; - const char * name() const noexcept { - return "MethodError"; - } - std::string message(int) const { - return ""; - } }; -class Critical {}; +class WerificationError { + public: + enum ErrorClass { WrongCode }; + + constexpr WerificationError(ErrorClass e) : errorClass{e} {} + + ErrorClass errorClass; +}; + +class Critical { + public: + enum ErrorClass { Nok }; + + Critical(ErrorClass e = Nok) : errorClass{e} {} + + ErrorClass errorClass; +}; + +class PamBaypass { + public: + enum ErrorClass { Nok }; + + constexpr PamBaypass(ErrorClass e = Nok) : errorClass{e} {} + + ErrorClass errorClass; +}; + +class PamDeny { + public: + enum ErrorClass { Nok }; + + constexpr PamDeny(ErrorClass e = Nok) : errorClass{e} {} + + ErrorClass errorClass; +}; class Error { - std::variant< std::monostate, CoreHandlerError, SocketError, Critical, MethodError > _error; - - enum class Category { None, CoreHandlerError, SockerError, Criticat, MethodError }; + using Error_t = std::variant< NoError, CoreHandlerError, HttpError, WerificationError, Critical, MethodError, PamBaypass, PamDeny >; + Error_t _error; public: + enum Category { k_None, k_CoreHandlerError, k_SockerError, k_WerificationError, k_Critical, k_MethodError, k_PamBaypass, k_PamDeny }; + Error() = default; Error(CoreHandlerError error) : _error{error} {} - Error(SocketError error) : _error{error} {} - Error(Critical error) : _error{error} {} + Error(HttpError error) : _error{error} {} Error(MethodError error) : _error{error} {} + Error(WerificationError error) : _error{error} {} + Error(Critical error) : _error{error} {} + Error(PamBaypass error) : _error{error} {} + Error(PamDeny error) : _error{error} {} Error(const Error &) = default; Error(Error &&) = default; @@ -90,9 +108,42 @@ class Error { return _error.index() == 2; } - Category category() const noexcept { + constexpr Category category() const noexcept { return static_cast< Category >(_error.index()); } + + constexpr int errorClass() const noexcept { + return std::visit([](const auto & e) { return static_cast< int >(e.errorClass); }, _error); + } + + template < typename E > + constexpr bool is() const { + return category() == Error{E{}}.category(); + } + + template < typename E > + constexpr bool isSameCategoryAs(const E & e) const { + return category() == Error{e}.category(); + } + + constexpr bool hasClass(int _class) const { + return errorClass() == _class; + } + + template < typename E > + constexpr bool hasSameErrorClassAs(E e) const { + assert(isSameCategoryAs(e)); + return errorClass() == Error{e}.errorClass(); + } }; +constexpr bool operator==(const Error & e, const HttpError & socket) { + return e.sockerError() && e.errorClass() == socket.errorClass; +} + +template < typename T > +constexpr bool operator==(const Error & lhs, const T & rhs) { + return lhs.isSameCategoryAs(rhs) && lhs.hasSameErrorClassAs(rhs); +} + } // namespace rublon diff --git a/PAM/ssh/include/rublon/init.hpp b/PAM/ssh/include/rublon/init.hpp index 341e37c..aeabf7a 100644 --- a/PAM/ssh/include/rublon/init.hpp +++ b/PAM/ssh/include/rublon/init.hpp @@ -19,6 +19,36 @@ class Init : public AuthenticationStep< Init< MethodSelect_t > > { const char * apiPath = "/api/transaction/init"; + tl::expected< MethodSelect_t, Error > createMethod(const Document & coreResponse) const { + const auto & rublonResponse = coreResponse["result"]; + std::string tid = rublonResponse["tid"].GetString(); + return MethodSelect_t{this->_systemToken, tid, rublonResponse["methods"]}; + } + + tl::expected< MethodSelect_t, Error > handleInitError(const Error & error) const { + return tl::unexpected{this->coreErrorHandler(error)}; + } + + template < typename PamInfo_t > + void addPamInfo(Document & body, const PamInfo_t & pam) const { + auto & alloc = body.GetAllocator(); + body.AddMember("username", Value{pam.username().get(), alloc}, alloc); + body.AddMember("userEmail", "bwi@rublon.com", alloc); /// TODO proper useremail + } + + template < typename PamInfo_t > + void addParams(Document & body, const PamInfo_t & pam) const { + auto & alloc = body.GetAllocator(); + + Value params{rapidjson::kObjectType}; + + params.AddMember("userIP", Value{pam.ip().get(), alloc}, alloc); + params.AddMember("appVer", "v.1.6", alloc); /// TODO add version to cmake + params.AddMember("os", "Ubuntu 23.04", alloc); /// TODO add version to cmake + + body.AddMember("params", std::move(params), alloc); + } + public: const char * name = "Initialization"; @@ -27,31 +57,20 @@ class Init : public AuthenticationStep< Init< MethodSelect_t > > { /// TODO add core handler interface template < typename Hander_t, typename PamInfo_t = LinuxPam > tl::expected< MethodSelect_t, Error > handle(const CoreHandlerInterface< Hander_t > & coreHandler, const PamInfo_t & pam) const { + const auto createMethod = [&](const auto & coreResponse) { return this->createMethod(coreResponse); }; + const auto handleInitError = [&](const auto & error) { return this->handleInitError(error); }; + RapidJSONPMRStackAlloc< 1024 > alloc{}; Document body{rapidjson::kObjectType, &alloc}; - this->addSystemToken(body, alloc); + this->addSystemToken(body); + this->addPamInfo(body, pam); + this->addParams(body, pam); - body.AddMember("username", Value{pam.username().get(), alloc}, alloc); - body.AddMember("userEmail", "bwi@rublon.com", alloc); /// TODO proper username - - Value params{rapidjson::kObjectType}; - params.AddMember("userIP", Value{pam.ip().get(), alloc}, alloc); - params.AddMember("appVer", "v.1.6", alloc); /// TODO add version to cmake - params.AddMember("os", "Ubuntu 23.04", alloc); /// TODO add version to cmake - - body.AddMember("params", std::move(params), alloc); - - auto coreResponse = coreHandler.request(apiPath, body); - - if(coreResponse.has_value()) { - log(LogLevel::Info, "[TMP] has response, processing", __PRETTY_FUNCTION__); - const auto & rublonResponse = coreResponse.value()["result"]; - std::string tid = rublonResponse["tid"].GetString(); - return MethodSelect_t{this->_systemToken, tid, rublonResponse["methods"]}; - } else { - return tl::unexpected{this->coreErrorHandler(coreResponse)}; - } + return coreHandler + .request(apiPath, body) // + .and_then(createMethod) + .or_else(handleInitError); } }; } // namespace rublon diff --git a/PAM/ssh/include/rublon/json.hpp b/PAM/ssh/include/rublon/json.hpp index 8509d4c..a7b1dcf 100644 --- a/PAM/ssh/include/rublon/json.hpp +++ b/PAM/ssh/include/rublon/json.hpp @@ -1,6 +1,8 @@ #pragma once #include "rapidjson/document.h" +#include "rapidjson/pointer.h" #include "rapidjson/writer.h" +#include "rublon/utils.hpp" #include #include @@ -47,8 +49,7 @@ struct RapidJSONPMRAlloc { } void * newPtr = Malloc(newSize); - if(originalSize) - { + if(originalSize) { std::memcpy(newPtr, origPtr, originalSize); freePtr(origPtr, originalSize); } @@ -81,6 +82,10 @@ struct RapidJSONPMRStackAlloc : public RapidJSONPMRAlloc { public: RapidJSONPMRStackAlloc() : RapidJSONPMRAlloc(&mr) {} + RapidJSONPMRStackAlloc(const RapidJSONPMRStackAlloc &) = delete; + RapidJSONPMRStackAlloc(RapidJSONPMRStackAlloc &&) = delete; + RapidJSONPMRStackAlloc & operator=(const RapidJSONPMRStackAlloc &) = delete; + RapidJSONPMRStackAlloc & operator=(RapidJSONPMRStackAlloc &&) = delete; }; using Document = rapidjson::GenericDocument< rapidjson::UTF8<>, RapidJSONPMRAlloc >; @@ -88,4 +93,6 @@ using Value = rapidjson::GenericValue< rapidjson::UTF8<>, RapidJSONPMRAll using StringBuffer = rapidjson::GenericStringBuffer< rapidjson::UTF8<>, RapidJSONPMRAlloc >; using Writer = rapidjson::Writer< StringBuffer, rapidjson::UTF8<>, rapidjson::UTF8<>, RapidJSONPMRAlloc >; +using JSONPointer = rapidjson::GenericPointer< Value, RapidJSONPMRAlloc >; + } // namespace rublon diff --git a/PAM/ssh/include/rublon/method/OTP.hpp b/PAM/ssh/include/rublon/method/OTP.hpp index f62219b..f916fd1 100644 --- a/PAM/ssh/include/rublon/method/OTP.hpp +++ b/PAM/ssh/include/rublon/method/OTP.hpp @@ -10,10 +10,10 @@ namespace rublon::method { -class OTP:public PasscodeBasedAuth{ +class OTP : public PasscodeBasedAuth { public: - OTP(std::string systemToken, std::string tid) : PasscodeBasedAuth(std::move(systemToken), std::move(tid), "OTP", "Mobile TOTP from Rublon Authenticator:") {} - + OTP(std::string systemToken, std::string tid) + : PasscodeBasedAuth(std::move(systemToken), std::move(tid), "OTP", "Mobile TOTP from Rublon Authenticator:") {} }; } // namespace rublon::method diff --git a/PAM/ssh/include/rublon/method/SMS.hpp b/PAM/ssh/include/rublon/method/SMS.hpp index 9fb5521..9233765 100644 --- a/PAM/ssh/include/rublon/method/SMS.hpp +++ b/PAM/ssh/include/rublon/method/SMS.hpp @@ -10,10 +10,9 @@ namespace rublon::method { -class SMS:public PasscodeBasedAuth{ +class SMS : public PasscodeBasedAuth { public: SMS(std::string systemToken, std::string tid) : PasscodeBasedAuth(std::move(systemToken), std::move(tid), "SMS", "SMS passcode:") {} - }; } // namespace rublon::method diff --git a/PAM/ssh/include/rublon/method/method_select.hpp b/PAM/ssh/include/rublon/method/method_select.hpp index 182931b..276dfaf 100644 --- a/PAM/ssh/include/rublon/method/method_select.hpp +++ b/PAM/ssh/include/rublon/method/method_select.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -63,7 +64,7 @@ class PostMethod : public rublon::AuthenticationStep< PostMethod > { const char * uri = "/api/transaction/methodSSH"; std::string _method; - tl::expected< MethodProxy, Error > doCreateMethod(const Document & coreResponse) const { + tl::expected< MethodProxy, Error > createMethod(const Document & coreResponse) const { const auto & rublonResponse = coreResponse["result"]; std::string tid = rublonResponse["tid"].GetString(); @@ -77,6 +78,14 @@ class PostMethod : public rublon::AuthenticationStep< PostMethod > { return tl::unexpected{MethodError{MethodError::BadMethod}}; } + void addParams(Document & body) const { + auto & alloc = body.GetAllocator(); + + body.AddMember("method", Value{_method.c_str(), alloc}, alloc); + body.AddMember("GDPRAccepted", "true", alloc); + body.AddMember("tosAccepted", "true", alloc); + } + public: const char * name = "Confirm Method"; @@ -85,17 +94,14 @@ class PostMethod : public rublon::AuthenticationStep< PostMethod > { template < typename Hander_t, typename PamInfo_t = LinuxPam > tl::expected< MethodProxy, Error > handle(const CoreHandlerInterface< Hander_t > & coreHandler) const { - auto createMethod = [&](const auto & coreResponse) { return doCreateMethod(coreResponse); }; + auto createMethod = [&](const auto & coreResponse) { return this->createMethod(coreResponse); }; RapidJSONPMRStackAlloc< 1024 > alloc{}; Document body{rapidjson::kObjectType, &alloc}; - this->addSystemToken(body, alloc); - this->addTid(body, alloc); - - body.AddMember("method", Value{_method.c_str(), alloc}, alloc); - body.AddMember("GDPRAccepted", "true", alloc); - body.AddMember("tosAccepted", "true", alloc); + this->addSystemToken(body); + this->addTid(body); + this->addParams(body); return coreHandler .request(uri, body) // @@ -122,30 +128,23 @@ class MethodSelect { template < typename Pam_t > tl::expected< PostMethod, Error > create(Pam_t & pam) const { - std::pmr::map< int, std::string > methods_id; + memory::StrictMonotonic_1k_HeapResource memoryResource; + std::pmr::map< int, std::string > methods_id{&memoryResource}; pam.print("%s", ""); int i{}; + for(const auto & method : _methods) { rublon::log(LogLevel::Debug, "method %s found at pos %d", method.c_str(), i); - if(method == "email") { - pam.print("%d: Email Link", i + 1); - methods_id[++i] = "email"; - } - if(method == "qrcode") { - pam.print("%d: QR Code", i + 1); - methods_id[++i] = "qrcode"; - } if(method == "totp") { pam.print("%d: Mobile TOTP", i + 1); methods_id[++i] = "totp"; + continue; } - if(method == "push") { - pam.print("%d: Mobile Push", i + 1); - methods_id[++i] = "push"; - } + if(method == "sms") { pam.print("%d: SMS code", i + 1); methods_id[++i] = "sms"; + continue; } } diff --git a/PAM/ssh/include/rublon/method/passcode_based_auth.hpp b/PAM/ssh/include/rublon/method/passcode_based_auth.hpp index b8b1546..033f708 100644 --- a/PAM/ssh/include/rublon/method/passcode_based_auth.hpp +++ b/PAM/ssh/include/rublon/method/passcode_based_auth.hpp @@ -13,10 +13,49 @@ class PasscodeBasedAuth : public AuthenticationStep< PasscodeBasedAuth > { const char * uri = "/api/transaction/confirmCode"; const char * userMessage{nullptr}; + constexpr static bool isdigit(char ch) { + return std::isdigit(static_cast< unsigned char >(ch)); + } + + static bool hasDigitsOnly(const std::string & userinput) { + return std::all_of(userinput.cbegin(), userinput.cend(), isdigit); + } + + static bool isProperLength(const std::string & userInput) { + return userInput.size() == 6; + } + template < typename PamInfo_t = LinuxPam > - std::string readPasscode(const PamInfo_t & pam) const { - /// TODO handle bad user input / wrong code etc - return pam.scan([](const char * userInput) { return std::string{userInput}; }, userMessage).value(); + tl::expected< std::reference_wrapper< Document >, Error > readPasscode(Document & body, const PamInfo_t & pam) const { + auto & alloc = body.GetAllocator(); + auto code = pam.scan([](const char * userInput) { return std::string{userInput}; }, userMessage); + + if(code.has_value()) { + const auto & vericode = code.value(); + if(isProperLength(vericode) and hasDigitsOnly(vericode)) { + body.AddMember("vericode", Value{vericode.c_str(), alloc}, alloc); + return body; + } + } + + return tl::unexpected{Error{WerificationError{WerificationError::WrongCode}}}; + } + + template < typename PamInfo_t = LinuxPam > + tl::expected< std::reference_wrapper< Document >, Error > askForPasscodeAgain(Document & body, const PamInfo_t & pam) const { + pam.print("passcode has wrong number of digits or illegal characters, please correct"); + return readPasscode(body, pam); + } + + tl::expected< AuthenticationStatus, Error > checkAuthenticationStatus(const Document & coreResponse) const { + RapidJSONPMRStackAlloc< 1024 > alloc; + auto error = JSONPointer{"/result/error", &alloc}.Get(coreResponse); + + if(error) { + return tl::unexpected{Error{WerificationError{WerificationError::WrongCode}}}; + } + + return AuthenticationStatus{AuthenticationStatus::Action::Confirmed}; } public: @@ -30,20 +69,17 @@ class PasscodeBasedAuth : public AuthenticationStep< PasscodeBasedAuth > { RapidJSONPMRStackAlloc< 1024 > alloc{}; Document body{rapidjson::kObjectType, &alloc}; - this->addSystemToken(body, alloc); - this->addTid(body, alloc); - - body.AddMember("vericode", Value{readPasscode(pam).c_str(), alloc}, alloc); /// TODO proper username + const auto checkAuthenticationStatus = [this](const auto & coreResponse) { return this->checkAuthenticationStatus(coreResponse); }; + const auto requestAuthorization = [&](const auto & body) { return coreHandler.request(uri, body); }; + const auto askForPasscodeAgain = [&](const auto & /*error*/) { return this->askForPasscodeAgain(body, pam); }; - auto coreResponse = coreHandler.request(uri, body); + this->addSystemToken(body); + this->addTid(body); - if(coreResponse.has_value()) { - // {"status":"OK","result":{"error":"Hmm, that's not the right code. Try again."}} - // const auto & rublonResponse = coreResponse.value()["result"]; - return AuthenticationStatus{AuthenticationStatus::Action::Confirmed}; - } else { - return tl::unexpected{this->coreErrorHandler(coreResponse)}; - } + return readPasscode(body, pam) // + .or_else(askForPasscodeAgain) + .and_then(requestAuthorization) + .and_then(checkAuthenticationStatus); } }; diff --git a/PAM/ssh/include/rublon/pam.hpp b/PAM/ssh/include/rublon/pam.hpp index b9dd305..c6f5964 100644 --- a/PAM/ssh/include/rublon/pam.hpp +++ b/PAM/ssh/include/rublon/pam.hpp @@ -49,7 +49,7 @@ class LinuxPam { void print(const char * fmt, Ti... ti) const noexcept { pam_prompt(pamh, PAM_TEXT_INFO, nullptr, fmt, std::forward< Ti >(ti)...); } - + template < typename Fun, typename... Ti > auto scan(Fun && f, const char * fmt, Ti... ti) const noexcept { char * responseBuffer = nullptr; @@ -61,6 +61,5 @@ class LinuxPam { } return std::optional< std::result_of_t< Fun(char *) > >(); } - }; } // namespace rublon diff --git a/PAM/ssh/include/rublon/rublon.hpp b/PAM/ssh/include/rublon/rublon.hpp index 55795d3..53f687f 100644 --- a/PAM/ssh/include/rublon/rublon.hpp +++ b/PAM/ssh/include/rublon/rublon.hpp @@ -1,5 +1,3 @@ #pragma once -namespace rublon { - -} // namespace rublon +namespace rublon {} // namespace rublon diff --git a/PAM/ssh/include/rublon/sign.hpp b/PAM/ssh/include/rublon/sign.hpp index 76e4701..2f711f1 100644 --- a/PAM/ssh/include/rublon/sign.hpp +++ b/PAM/ssh/include/rublon/sign.hpp @@ -8,16 +8,17 @@ namespace rublon { -inline std::array< char, 64 > signData(std::string_view data, std::string_view secretKey) { - std::array< char, 64 > xRublon; - unsigned char md[EVP_MAX_MD_SIZE] = {0}; - unsigned int md_len; +// +1 for \0 +inline std::array< char, 64 + 1 > signData(std::string_view data, std::string_view secretKey) { + std::array< char, 64 + 1 > xRublon; + std::array< unsigned char, EVP_MAX_MD_SIZE > md; + unsigned int md_len{}; - HMAC(EVP_sha256(), secretKey.data(), secretKey.size(), ( unsigned const char * ) data.data(), data.size(), md, &md_len); + HMAC(EVP_sha256(), secretKey.data(), secretKey.size(), ( unsigned const char * ) data.data(), data.size(), md.data(), &md_len); - int i; - for(i = 0; i < 32; i++) + for(unsigned int i = 0; i < md_len; i++) sprintf(&xRublon[i * 2], "%02x", ( unsigned int ) md[i]); + return xRublon; } diff --git a/PAM/ssh/include/rublon/utils.hpp b/PAM/ssh/include/rublon/utils.hpp index 69234a2..addb752 100644 --- a/PAM/ssh/include/rublon/utils.hpp +++ b/PAM/ssh/include/rublon/utils.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -17,8 +18,78 @@ #include #include #include +#include namespace rublon { +namespace memory { + struct holder { + static inline std::pmr::memory_resource * _mr = std::pmr::get_default_resource(); + }; + + inline void set_default_resource(std::pmr::memory_resource * memory_resource) { + holder{}._mr = memory_resource; + } + + inline std::pmr::memory_resource * default_resource() { + return holder{}._mr; + } + + template < std::size_t N > + class MonotonicStackResource : public std::pmr::monotonic_buffer_resource { + char _buffer[N]; + + public: + MonotonicStackResource() : std::pmr::monotonic_buffer_resource{_buffer, N, std::pmr::null_memory_resource()} {} + }; + + template < std::size_t N > + class UnsynchronizedStackResource : public std::pmr::unsynchronized_pool_resource { + MonotonicStackResource< N > _upstream; + + public: + UnsynchronizedStackResource() : std::pmr::unsynchronized_pool_resource{&_upstream} {} + }; + + class MonotonicHeapResourceBase { + public: + std::pmr::memory_resource * _upstream{}; + std::size_t _size{}; + void * _buffer{nullptr}; + + MonotonicHeapResourceBase(std::size_t size) : _upstream{default_resource()}, _size{size}, _buffer{_upstream->allocate(size)} {} + + ~MonotonicHeapResourceBase() { + if(_buffer) + _upstream->deallocate(_buffer, _size); + } + }; + + template < std::size_t N > + class MonotonicHeapResource : MonotonicHeapResourceBase, public std::pmr::monotonic_buffer_resource { + public: + MonotonicHeapResource() + : MonotonicHeapResourceBase{N}, std::pmr::monotonic_buffer_resource{this->_buffer, this->_size, default_resource()} {} + }; + + template < std::size_t N > + class StrictMonotonicHeapResource : MonotonicHeapResourceBase, public std::pmr::monotonic_buffer_resource { + public: + StrictMonotonicHeapResource() + : MonotonicHeapResourceBase{N}, + std::pmr::monotonic_buffer_resource{this->_buffer, this->_size, std::pmr::null_memory_resource()} {} + }; + + using StrictMonotonic_1k_HeapResource = StrictMonotonicHeapResource< 1 * 1024 >; + using StrictMonotonic_2k_HeapResource = StrictMonotonicHeapResource< 2 * 1024 >; + using StrictMonotonic_4k_HeapResource = StrictMonotonicHeapResource< 4 * 1024 >; + using StrictMonotonic_8k_HeapResource = StrictMonotonicHeapResource< 8 * 1024 >; + + using Monotonic_1k_HeapResource = MonotonicHeapResource< 1 * 1024 >; + using Monotonic_2k_HeapResource = MonotonicHeapResource< 2 * 1024 >; + using Monotonic_4k_HeapResource = MonotonicHeapResource< 4 * 1024 >; + using Monotonic_8k_HeapResource = MonotonicHeapResource< 8 * 1024 >; +} // namespace memory + enum class LogLevel { Debug, Info, Warning, Error }; inline auto dateStr() { @@ -29,22 +100,23 @@ inline auto dateStr() { return date; } -constexpr const char* LogLevelNames[] {"Debug", "Info", "Warning", "Error"}; - +constexpr const char * LogLevelNames[]{"Debug", "Info", "Warning", "Error"}; constexpr LogLevel g_level = LogLevel::Debug; +constexpr bool syncLogFile = true; -namespace details{ +namespace details { static void doLog(LogLevel level, const char * line) noexcept { constexpr auto file_name = "/var/log/rublon-ssh.log"; - + auto fp = std::unique_ptr< FILE, int (*)(FILE *) >(fopen(file_name, "a"), fclose); - + if(fp) { - fprintf(fp.get(), "%s [%s] %s\n", dateStr().data(), LogLevelNames[(int)level], line); - sync(); + fprintf(fp.get(), "%s [%s] %s\n", dateStr().data(), LogLevelNames[( int ) level], line); + if(syncLogFile) + sync(); } } -} +} // namespace details inline void log(LogLevel level, const char * line) noexcept { if(level < g_level) @@ -92,9 +164,9 @@ class NonOwningPtr { }; namespace details { - inline bool to_bool(std::string_view value) { - auto * buf = ( char * ) alloca(value.size()); + /// TODO change to global allocator + auto * buf = ( char * ) alloca(value.size() + 1); buf[value.size()] = '\0'; auto asciitolower = [](char in) { return in - ((in <= 'Z' && in >= 'A') ? ('Z' - 'z') : 0); }; @@ -118,46 +190,25 @@ namespace details { return ltrim(rtrim(s)); } - inline std::map< std::string, std::string > headers(std::string_view data) { - std::map< std::string, std::string > headers{}; + template < typename Headers > + inline void headers(std::string_view data, Headers & headers) { + memory::StrictMonotonic_4k_HeapResource stackResource; + std::pmr::string tmp{&stackResource}; - std::string tmp{}; std::istringstream resp{}; resp.rdbuf()->pubsetbuf(const_cast< char * >(data.data()), data.size()); - while(std::getline(resp, tmp)) { + while(std::getline(resp, tmp) && !(trim(tmp).empty())) { auto line = std::string_view(tmp); auto index = tmp.find(':', 0); if(index != std::string::npos) { - headers.insert({std::string{trim(line.substr(0, index))}, std::string{trim(line.substr(index + 1))}}); + headers.insert({// + typename Headers::key_type{trim(line.substr(0, index)), headers.get_allocator()}, + typename Headers::mapped_type{trim(line.substr(index + 1)), headers.get_allocator()}}); } } - - return headers; } - namespace pmr { - inline std::pmr::map< std::pmr::string, std::pmr::string > headers(std::pmr::memory_resource * mr, std::string_view data) { - char _buf[1024]; - std::pmr::monotonic_buffer_resource tmr{_buf, sizeof(_buf)}; - std::pmr::map< std::pmr::string, std::pmr::string > headers{mr}; - - std::pmr::string tmp{&tmr}; - std::istringstream resp{}; - resp.rdbuf()->pubsetbuf(const_cast< char * >(data.data()), data.size()); - - while(std::getline(resp, tmp) && !(trim(tmp).empty())) { - auto line = std::string_view(tmp); - auto index = tmp.find(':', 0); - if(index != std::string::npos) { - headers.insert({std::pmr::string{trim(line.substr(0, index)), mr}, std::pmr::string{trim(line.substr(index + 1)), mr}}); - } - } - - return headers; - } - } // namespace pmr - } // namespace details } // namespace rublon diff --git a/PAM/ssh/lib/pam.cpp b/PAM/ssh/lib/pam.cpp index f7f7222..6a49013 100644 --- a/PAM/ssh/lib/pam.cpp +++ b/PAM/ssh/lib/pam.cpp @@ -4,9 +4,8 @@ #include #include -#include - #include +#include #include #include #include @@ -15,13 +14,6 @@ using namespace std; -namespace { - -template < typename T > -using Expected = tl::expected< T, rublon::Error >; - -} // namespace - DLL_PUBLIC int pam_sm_setcred([[maybe_unused]] pam_handle_t * pamh, [[maybe_unused]] int flags, [[maybe_unused]] int argc, @@ -42,6 +34,11 @@ pam_sm_authenticate(pam_handle_t * pamh, [[maybe_unused]] int flags, [[maybe_unu auto rublonConfig = ConfigurationFactory{}.systemConfig(); + std::byte sharedMemory[32 * 1024] = {}; + std::pmr::monotonic_buffer_resource mr{sharedMemory, std::size(sharedMemory)}; + std::pmr::unsynchronized_pool_resource rublonPoolResource{&mr}; + std::pmr::set_default_resource(&rublonPoolResource); + CoreHandler CH{rublonConfig.value()}; LinuxPam pam{pamh}; @@ -54,10 +51,17 @@ pam_sm_authenticate(pam_handle_t * pamh, [[maybe_unused]] int flags, [[maybe_unu .and_then(selectMethod) .and_then(confirmMethod) .and_then(verifi); - + if(authStatus.has_value()) { rublon::log(rublon::LogLevel::Info, "Auth OK"); + return PAM_SUCCESS; + } else { + const auto & error = authStatus.error(); + + rublon::log( + LogLevel::Error, "auth failed due to %d class and %d category", error.errorClass(), static_cast< int >(error.category())); } - return PAM_SUCCESS; + rublon::log(LogLevel::Warning, "User login failed"); + return PAM_PERM_DENIED; } diff --git a/PAM/ssh/tests/CMakeLists.txt b/PAM/ssh/tests/CMakeLists.txt index f547894..28cbf38 100644 --- a/PAM/ssh/tests/CMakeLists.txt +++ b/PAM/ssh/tests/CMakeLists.txt @@ -25,7 +25,6 @@ add_executable(rublon-tests ./init_test.cpp ./method_select_tests.cpp ./passcode_auth_tests.cpp - ./rublon_tests.cpp ./sign_tests.cpp ./utils_tests.cpp diff --git a/PAM/ssh/tests/core_handler_tests.cpp b/PAM/ssh/tests/core_handler_tests.cpp index 97a462c..358877b 100644 --- a/PAM/ssh/tests/core_handler_tests.cpp +++ b/PAM/ssh/tests/core_handler_tests.cpp @@ -4,13 +4,14 @@ #include #include +#include "gtest_matchers.hpp" #include "http_mock.hpp" #include using namespace rublon; class CoreHandlerTestable : public CoreHandler< HttpHandlerMock > { public: - CoreHandlerTestable() : CoreHandler< HttpHandlerMock >{conf} {} + CoreHandlerTestable(rublon::Configuration _conf = conf) : CoreHandler< HttpHandlerMock >{_conf} {} auto & _http() { return http; } @@ -30,43 +31,47 @@ class CoreHandlerTests : public testing::Test { using namespace testing; -MATCHER(HasValue, "") { - return arg.has_value(); -} - -bool operator==(const Error & lhs, const Error & rhs) { - return lhs.category() == rhs.category(); -} - -MATCHER_P(Unexpected, error, "") { - return arg.error() == error; -} - -auto ReturnError(SocketError::ErrorClass value) { - return Return(tl::unexpected{Error{SocketError{value}}}); -} - TEST_F(CoreHandlerTests, coreShouldSetConnectionErrorOnBrokenConnection) { - EXPECT_CALL(http, request(_, _)).WillOnce(ReturnError(SocketError::Timeout)); + EXPECT_CALL(http, request(_, _)).WillOnce(Return(RawHttpResponse{}.withTimeoutError())); EXPECT_THAT(sut.request("", doc), // - AllOf(Not(HasValue()), Unexpected(CoreHandlerError{CoreHandlerError{CoreHandlerError::ConnectionError}}))); + AllOf(Not(HasValue()), Unexpected(HttpError{}))); } TEST_F(CoreHandlerTests, coreShouldCheckSignatureAndReturnBadSignatureBeforeAnythingElse) { - EXPECT_CALL(http, request(_, _)).WillOnce(Return(( Response ) http.brokenSignature().brokenBody())); + EXPECT_CALL(http, request(_, _)).WillOnce(Return(RawHttpResponse{}.initMessage().withBrokenBody().withBrokenSignature())); EXPECT_THAT(sut.request("", doc), // AllOf(Not(HasValue()), Unexpected(CoreHandlerError{CoreHandlerError::BadSigature}))); } TEST_F(CoreHandlerTests, coreShouldSetBrokenDataWhenResultIsNotAvailable) { - EXPECT_CALL(http, request(_, _)).WillOnce(Return(( Response ) http.brokenBody())); + EXPECT_CALL(http, request(_, _)).WillOnce(Return(RawHttpResponse{}.initMessage().withBrokenBody())); EXPECT_THAT(sut.request("", doc), // AllOf(Not(HasValue()), Unexpected(CoreHandlerError{CoreHandlerError::BrokenData}))); } TEST_F(CoreHandlerTests, coreSignatureIsBeingChecked) { - EXPECT_CALL(http, request(Eq(conf.parameters.apiServer + "/path"), _)).WillOnce(Return(( Response ) http.statusPending())); - auto val = sut.request("/path", doc); - - EXPECT_TRUE(val.value().IsObject()); + EXPECT_CALL(http, request(Eq(conf.parameters.apiServer + "/path"), _)).WillOnce(Return(RawHttpResponse{}.initMessage())); + EXPECT_THAT(sut.request("/path", doc), // + AllOf(HasValue(), IsObject(), HasKey("/result/tid"))); +} + +TEST_F(CoreHandlerTests, onHttpProblemsAccessShouldBeDeniedByDefault) { + EXPECT_CALL(http, request(_, _)).WillOnce(Return(RawHttpResponse{}.initMessage().withServiceUnavailableError())); + EXPECT_THAT(sut.request("/path", doc), // + AllOf(Not(HasValue()), Unexpected(PamDeny{}))); +} + +class CoreHandlerWithBypassTests : public testing::Test { + public: + CoreHandlerWithBypassTests() : sut{confBypass}, http{sut._http()} { + } + + CoreHandlerTestable sut; + HttpHandlerMock & http; +}; + +TEST_F(CoreHandlerWithBypassTests, onHttpProblemsAccessShouldBeBypassedWhenEnabled) { + EXPECT_CALL(http, request(_, _)).WillOnce(Return(RawHttpResponse{}.initMessage().withServiceUnavailableError())); + EXPECT_THAT(sut.request("/path", Document{}), // + AllOf(Not(HasValue()), Unexpected(PamBaypass{}))); } diff --git a/PAM/ssh/tests/core_response_generator.hpp b/PAM/ssh/tests/core_response_generator.hpp index 024b0b8..3049411 100644 --- a/PAM/ssh/tests/core_response_generator.hpp +++ b/PAM/ssh/tests/core_response_generator.hpp @@ -1,9 +1,14 @@ #pragma once +#include + +#include +#include + #include -#include -#include #include +#include +#include namespace io { template < class X > @@ -18,7 +23,7 @@ constexpr T forward_or_transform(T t) { } template < class T, class = is_string< T > > -constexpr const char * forward_or_transform(const T &t) { +constexpr const char * forward_or_transform(const T & t) { return t.c_str(); } @@ -28,72 +33,15 @@ int sprintf(char * _s, const std::string & format, Ti... t) { } } // namespace io - namespace { -std::string gen_random(const int len) { - static const char alphanum[] = - "0123456789" - "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; - std::string tmp_s; - tmp_s.resize(len); - - std::for_each(tmp_s.begin(), tmp_s.end(), [](auto & chr) { chr = alphanum[rand() % (sizeof(alphanum) - 1)]; }); - - return tmp_s; -} static constexpr const char * result_ok_template = R"json({"status":"OK","result":{"tid":"%s","status":"%s","companyName":"%s","applicationName":"%s","methods":[%s]}})json"; -static constexpr const char * result_broken_template = R"json({"some":"random","json":"notrublon"})json"; +static constexpr const char * result_broken_template = // + R"json({"some":"random","json":"notrublon"})json"; + +static constexpr const char * wrongPasscode = // + R"json({"status":"OK","result":{"error":"Hmm, that's not the right code. Try again."}})json"; } // namespace - -class CoreResponseGenerator { - public: - std::string generateBody() { - std::array< char, 2048 > _buf; - io::sprintf(_buf.data(), - generateBrokenData ? result_broken_template : result_ok_template, - _tid, - status, - companyName, - applicationName, - print_methods()); - return _buf.data(); - } - - CoreResponseGenerator & withMethods(std::initializer_list< std::string > newMethods) { - _methods.clear(); - std::copy(newMethods.begin(), newMethods.end(), std::inserter(_methods, _methods.begin())); - return *this; - } - - CoreResponseGenerator & withTid(std::string tid) { - _tid = tid; - return *this; - } - - std::string _tid{generateTid()}; - std::string status; - std::string companyName{"rublon"}; - std::string applicationName{"test_app"}; - std::set< std::string > _methods{"email", "totp", "qrcode", "push"}; - - bool skipSignatureGeneration{false}; - bool generateBrokenData{false}; - private: - std::string print_methods() { - std::string ret; - for(const auto & m : _methods) - ret += "\"" + m + "\","; - ret.pop_back(); - return ret; - } - - static std::string generateTid() { - return gen_random(32); - } - - -}; diff --git a/PAM/ssh/tests/gtest_matchers.hpp b/PAM/ssh/tests/gtest_matchers.hpp new file mode 100644 index 0000000..a8c2c30 --- /dev/null +++ b/PAM/ssh/tests/gtest_matchers.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include + +#include +#include + +using namespace testing; + + +inline bool operator==(const rublon::Error & lhs, const rublon::Error & rhs) { + return lhs.category() == rhs.category(); +} + +MATCHER(HasValue, "") { + return arg.has_value(); +} + +MATCHER(IsObject, "") { + return arg->IsObject(); +} + +MATCHER(IsPAMBaypass,""){ + return arg.error().category() == rublon::Error::k_PamBaypass; +} + +MATCHER(IsPAMDeny,""){ + return arg.error().category() == rublon::Error::k_PamDeny; +} + +MATCHER_P(HasKey, key, "") { + return rublon::JSONPointer{key}.Get(arg.value()) != nullptr; +} + +MATCHER_P(Unexpected, error, "") { + rublon::Error e = arg.error(); + return e == error; +} + diff --git a/PAM/ssh/tests/http_mock.hpp b/PAM/ssh/tests/http_mock.hpp index 3335c27..12049b0 100644 --- a/PAM/ssh/tests/http_mock.hpp +++ b/PAM/ssh/tests/http_mock.hpp @@ -5,8 +5,8 @@ #include #include -#include #include +#include #include "core_response_generator.hpp" #include "rublon/sign.hpp" @@ -19,43 +19,219 @@ rublon::Configuration conf{rublon::Configuration::Parameters{// 1, true, true, + false, false}}; +rublon::Configuration confBypass{rublon::Configuration::Parameters{// + "320BAB778C4D4262B54CD243CDEFFAFD", + "39e8d771d83a2ed3cc728811911c25", + "https://staging-core.rublon.net", + 1, + true, + true, + false, + true}}; + +inline std::string randomTID() { + static const char alphanum[] = + "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + std::string tmp_s; + const auto len = 32; + tmp_s.resize(len); + + std::for_each(tmp_s.begin(), tmp_s.end(), [](auto & chr) { chr = alphanum[rand() % (sizeof(alphanum) - 1)]; }); + + return tmp_s; +} + +std::string join_methods(std::set< std::string > _methods) { + std::string ret; + for(const auto & m : _methods) + ret += "\"" + m + "\","; + ret.pop_back(); + return ret; +} + } // namespace -class HttpHandlerMock : public CoreResponseGenerator { - auto signResponse(rublon::Response & res) { +struct ResponseMockOptions { + bool generateBrokenSigneture{false}; + bool generateBrokenBody{false}; + bool generateBrokenCode{false}; + + std::string coreException{}; + std::set< std::string > methods{}; + + rublon::Error error; +}; + +class InitCoreResponse { + static constexpr const char * result_ok_template = + R"json({"status":"OK","result":{"tid":"%s","status":"%s","companyName":"%s","applicationName":"%s","methods":[%s]}})json"; + + static constexpr const char * result_broken_template = // + R"json({"some":"random","json":"notrublon"})json"; + + public: + static std::string generate(const ResponseMockOptions & options) { + if(options.generateBrokenBody) { + return result_broken_template; + } + + std::array< char, 2048 > _buf; + + auto tid = randomTID(); + auto status = "pending"; + std::string companyName{"rublon"}; + std::string applicationName{"test_app"}; + std::set< std::string > methods{"email", "totp", "qrcode", "push"}; + + if(options.methods.size()) + methods = options.methods; + + io::sprintf(_buf.data(), result_ok_template, tid, status, companyName, applicationName, join_methods(methods)); + return _buf.data(); + } +}; + +class MethodSelectCoreResponse { + static constexpr const char * result_ok_template = + R"json({"status":"OK","result":{"tid":"%s","status":"%s","companyName":"%s","applicationName":"%s","methods":[%s]}})json"; + + static constexpr const char * result_broken_template = // + R"json({"some":"random","json":"notrublon"})json"; + + public: + static std::string generate(const ResponseMockOptions & options) { + if(options.generateBrokenBody) { + return result_broken_template; + } + + std::array< char, 2048 > _buf; + + auto tid = randomTID(); + auto status = "pending"; + std::string companyName{"rublon"}; + std::string applicationName{"test_app"}; + std::set< std::string > methods{"email", "totp", "qrcode", "push"}; + + io::sprintf(_buf.data(), result_ok_template, tid, status, companyName, applicationName, join_methods(methods)); + return _buf.data(); + } +}; + +class CodeVerificationResponse { + static constexpr const char * wrongPasscode = + R"json({"status":"OK","result":{"error":"Hmm, that's not the right code. Try again."}})json"; + static constexpr const char * goodPasscode = R"json({"status":"OK","result":true})json"; + + public: + static std::string generate(const ResponseMockOptions & options) { + return options.generateBrokenCode ? wrongPasscode : goodPasscode; + } +}; + +template < typename Generator > +class ResponseBase { + public: + ResponseBase & initMessage() { + _coreGenerator = InitCoreResponse{}; + return *this; + } + + ResponseBase & selectMethodMessage() { + _coreGenerator = MethodSelectCoreResponse{}; + return *this; + } + + ResponseBase & codeConfirmationMessage() { + _coreGenerator = CodeVerificationResponse{}; + return *this; + } + + ResponseBase & withBrokenBody() { + options.generateBrokenBody = true; + return *this; + } + + ResponseBase & withCoreException() { + options.coreException = "some exception"; + return *this; + } + + ResponseBase & withBrokenSignature() { + options.generateBrokenSigneture = true; + return *this; + } + + ResponseBase & withWrongCodeResponse() { + options.generateBrokenCode = true; + return *this; + } + + ResponseBase & withTimeoutError() { + options.error = rublon::Error{rublon::HttpError{}}; + return *this; + } + + ResponseBase & withServiceUnavailableError() { + options.error = rublon::Error{rublon::HttpError{rublon::HttpError::Error, 405}}; + return *this; + } + + template < typename... T > + ResponseBase & withMethods(T... methods) { + (options.methods.insert(methods), ...); + return *this; + } + + ResponseBase & withConnectionError() { + // options.error = rublon::Error{rublon::CoreHandlerError{rublon::CoreHandlerError::ConnectionError}}; + return *this; + } + + operator auto() { + return options.error.category() == rublon::Error::k_None ? static_cast< Generator * >(this)->generate() : tl::unexpected{error()}; + } + + rublon::Error error() { + return rublon::Error{options.error}; + } + + std::variant< InitCoreResponse, MethodSelectCoreResponse, CodeVerificationResponse > _coreGenerator; + ResponseMockOptions options; +}; + +class RawCoreResponse : public ResponseBase< RawCoreResponse > { + public: + tl::expected< rublon::Document, rublon::Error > generate() { + auto jsonString = std::visit([&](const auto generator) { return generator.generate(options); }, _coreGenerator); + rublon::Document doc; + doc.Parse(jsonString.c_str()); + return doc; + } +}; + +class RawHttpResponse : public ResponseBase< RawHttpResponse > { + static auto signResponse(rublon::Response & res, ResponseMockOptions opts) { const auto & sign = - skipSignatureGeneration ? std::array< char, 64 >{} : rublon::signData(res.body, conf.parameters.secretKey.c_str()); + opts.generateBrokenSigneture ? std::array< char, 65 >{} : rublon::signData(res.body, conf.parameters.secretKey.c_str()); res.headers["x-rublon-signature"] = sign.data(); } public: - template < typename... Args > - HttpHandlerMock(const Args &...) {} - - MOCK_METHOD((tl::expected< rublon::Response, rublon::Error >), request, ( std::string_view, const rublon::Request & ), (const)); - - HttpHandlerMock & statusPending() { - status = "pending"; - return *this; - } - - HttpHandlerMock & brokenBody() { - generateBrokenData = true; - return *this; - } - - HttpHandlerMock & brokenSignature() { - skipSignatureGeneration = true; - return *this; - } - - operator rublon::Response() { - rublon::Response res; - res.body = generateBody(); - - signResponse(res); - - return res; + tl::expected< rublon::Response, rublon::Error > generate() { + rublon::Response response{std::pmr::get_default_resource()}; + response.body = std::visit([&](const auto generator) { return generator.generate(options); }, _coreGenerator); + signResponse(response, options); + return response; } }; + +class HttpHandlerMock { + public: + template < typename... Args > + HttpHandlerMock(const Args &...) {} + + MOCK_METHOD(( tl::expected< rublon::Response, rublon::Error > ), request, ( std::string_view, const rublon::Request & ), (const)); +}; diff --git a/PAM/ssh/tests/init_test.cpp b/PAM/ssh/tests/init_test.cpp index 8c136be..8874321 100644 --- a/PAM/ssh/tests/init_test.cpp +++ b/PAM/ssh/tests/init_test.cpp @@ -3,65 +3,21 @@ #include -#include "core_response_generator.hpp" +#include "core_handler_mock.hpp" +#include "gtest_matchers.hpp" +#include "http_mock.hpp" #include "pam_info_mock.hpp" using namespace rublon; using namespace testing; -namespace { -Configuration conf; -} - -class CoreHandlerMock : public CoreHandlerInterface< CoreHandlerMock > { - public: - CoreHandlerMock() {} - - MOCK_METHOD(( tl::expected< Document, Error > ), request, ( std::string_view, const Document & ), (const)); - - CoreHandlerMock & statusPending() { - gen.status = "pending"; - return *this; - } - - CoreHandlerMock & brokenBody() { - gen.generateBrokenData = true; - return *this; - } - - CoreHandlerMock & methods(std::initializer_list< std::string > methods) { - gen.withMethods(methods); - return *this; - } - - CoreHandlerMock & tid(std::string tid) { - gen.withTid(tid); - return *this; - } - - operator tl::expected< Document, Error >() { - auto body = gen.generateBody(); - - rublon::Document doc; - doc.Parse(body.c_str()); - - return doc; - } - - auto create() { - return static_cast< tl::expected< Document, Error > >(*this); - } - - CoreResponseGenerator gen; -}; - class MethodFactoryMock { public: using MethodFactoryCreate_t = tl::expected< MethodProxy, Error >; template < typename... Args > MethodFactoryMock(Args &&...) {} - + MOCK_METHOD(MethodFactoryCreate_t, create, (const mocks::PamInfoMock &, std::string tid), (const)); }; @@ -76,48 +32,42 @@ class RublonHttpInitTest : public testing::Test { expectDefaultPamInfo(); } - CoreHandlerMock coreHandler; + mocks::CoreHandlerMock coreHandler; mocks::PamInfoMock pam; Init< MethodFactoryMock > sut; - // MethodFactoryMock &methodFactoryMock; }; using CoreReturn = tl::expected< Document, CoreHandlerError >; TEST_F(RublonHttpInitTest, initializationSendsRequestOnGoodPath) { - EXPECT_CALL(coreHandler, request("/api/transaction/init", _)) - .WillOnce(Return(tl::unexpected{CoreHandlerError{CoreHandlerError::BadSigature}})); + EXPECT_CALL(coreHandler, request("/api/transaction/init", _)).WillOnce(Return(RawCoreResponse{}.initMessage())); sut.handle(coreHandler, pam); } -/// TODO fix -//MATCHER_P(HoldsPamAction, action, "") { -// return not arg.has_value() && arg.error() == action; -//} +TEST_F(RublonHttpInitTest, rublon_Accept_pamLoginWhenThereIsNoConnection) { + EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(RawCoreResponse{}.withTimeoutError())); + EXPECT_THAT(sut.handle(coreHandler, pam), // + AllOf(Not(HasValue()), IsPAMBaypass())); +} -//TEST_F(RublonHttpInitTest, rublon_Accept_pamLoginWhenThereIsNoConnection) { -// EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(tl::unexpected{CoreHandlerError{CoreHandlerError::ConnectionError}})); -// EXPECT_THAT(sut.handle(coreHandler, pam), HoldsPamAction(PamAction::decline)); -//} +// TEST_F(RublonHttpInitTest, rublon_Decline_pamLoginWhenServerHasBadSignature) { +// EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(tl::unexpected{CoreHandlerError{CoreHandlerError::BadSigature}})); +// EXPECT_THAT(sut.handle(coreHandler, pam), HoldsPamAction(PamAction::decline)); +// } -//TEST_F(RublonHttpInitTest, rublon_Decline_pamLoginWhenServerHasBadSignature) { -// EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(tl::unexpected{CoreHandlerError{CoreHandlerError::BadSigature}})); -// EXPECT_THAT(sut.handle(coreHandler, pam), HoldsPamAction(PamAction::decline)); -//} +// TEST_F(RublonHttpInitTest, rublon_Decline_pamLoginWhenServerReturnsBrokenData) { +// EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(tl::unexpected{CoreHandlerError{CoreHandlerError::BrokenData}})); +// EXPECT_THAT(sut.handle(coreHandler, pam), HoldsPamAction(PamAction::decline)); +// } -//TEST_F(RublonHttpInitTest, rublon_Decline_pamLoginWhenServerReturnsBrokenData) { -// EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(tl::unexpected{CoreHandlerError{CoreHandlerError::BrokenData}})); -// EXPECT_THAT(sut.handle(coreHandler, pam), HoldsPamAction(PamAction::decline)); -//} +// TEST_F(RublonHttpInitTest, rublon_Decline_pamLoginWhenServerReturnsCoreException) { +// EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(tl::unexpected{CoreHandlerError{CoreHandlerError::CoreException}})); +// EXPECT_THAT(sut.handle(coreHandler, pam), HoldsPamAction(PamAction::decline)); +// } -//TEST_F(RublonHttpInitTest, rublon_Decline_pamLoginWhenServerReturnsCoreException) { -// EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(tl::unexpected{CoreHandlerError{CoreHandlerError::CoreException}})); -// EXPECT_THAT(sut.handle(coreHandler, pam), HoldsPamAction(PamAction::decline)); -//} - -//TEST_F(RublonHttpInitTest, AllNeededInformationNeedsToBePassedToMethodFactory) { -// EXPECT_CALL(coreHandler, request(_, _)) -// .WillOnce(Return(coreHandler.statusPending().methods({"sms", "otp"}).tid("transaction ID").create())); -// // EXPECT_CALL(methodFactoryMock, create_mocked(_,"transaction ID") ); -// sut.handle(coreHandler, pam); -//} + TEST_F(RublonHttpInitTest, AllNeededInformationNeedsToBePassedToMethodFactory) { + EXPECT_CALL(coreHandler, request(_, _)) + .WillOnce(Return( RawCoreResponse{}.initMessage().withMethods("sms", "otp"))); + // EXPECT_CALL(methodFactoryMock, create_mocked(_,"transaction ID") ); + sut.handle(coreHandler, pam); + } diff --git a/PAM/ssh/tests/pam_info_mock.hpp b/PAM/ssh/tests/pam_info_mock.hpp index 9a0e982..38d894d 100644 --- a/PAM/ssh/tests/pam_info_mock.hpp +++ b/PAM/ssh/tests/pam_info_mock.hpp @@ -1,6 +1,6 @@ #pragma once -#include > +#include #include @@ -9,5 +9,19 @@ class PamInfoMock { public: MOCK_METHOD(rublon::NonOwningPtr< const char >, ip, (), (const)); MOCK_METHOD(rublon::NonOwningPtr< const char >, username, (), (const)); + + MOCK_METHOD(std::string, scan_mock, (const char * fmt), (const)); + MOCK_METHOD(void, print_mock, (const char * fmt), (const)); + + template < typename Fun, typename... Ti > + auto scan(Fun && f, const char * fmt, Ti...) const noexcept -> std::optional< std::result_of_t< Fun(char *) > > { + const auto & responseBuffer = scan_mock(fmt); + return responseBuffer.empty() ? std::nullopt : std::optional{f(responseBuffer.c_str())}; + } + + template < typename... Ti > + void print(const char * fmt, Ti...) const noexcept { + print_mock(fmt); + } }; } // namespace mocks diff --git a/PAM/ssh/tests/passcode_auth_tests.cpp b/PAM/ssh/tests/passcode_auth_tests.cpp index 727ab02..c4c5dc2 100644 --- a/PAM/ssh/tests/passcode_auth_tests.cpp +++ b/PAM/ssh/tests/passcode_auth_tests.cpp @@ -3,6 +3,8 @@ #include #include "core_handler_mock.hpp" +#include "gtest_matchers.hpp" +#include "http_mock.hpp" #include "pam_info_mock.hpp" using namespace testing; @@ -21,8 +23,38 @@ class PasscodeBasedAuthTest : public Test { mocks::PamInfoMock pam; }; -TEST_F(PasscodeBasedAuthTest, wrongPasscodeShouldFail){ - EXPECT_THAT(coreHandler, request(_,_) ); +TEST_F(PasscodeBasedAuthTest, wrongPasscodeShouldFail) { + EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(RawCoreResponse{}.codeConfirmationMessage().withWrongCodeResponse())); + EXPECT_CALL(pam, scan_mock(_)).WillOnce(Return("123456")); - sut.handle(coreHandler, pam); + EXPECT_THAT(sut.handle(coreHandler, pam), Unexpected(WerificationError{WerificationError::WrongCode})); +} + +TEST_F(PasscodeBasedAuthTest, whenGivenBadPasscodeWeNeedToAskAgain) { + EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(RawCoreResponse{}.codeConfirmationMessage())); + + EXPECT_CALL(pam, scan_mock(_))// + .WillOnce(Return("1_3456")) + .WillOnce(Return("123456")); + + EXPECT_CALL(pam, print_mock(_) ); + + EXPECT_THAT(sut.handle(coreHandler, pam), HasValue() ); +} + +TEST_F(PasscodeBasedAuthTest, whenGiveenBadPasscodeMultipleTimesWeAbort) { + EXPECT_CALL(pam, scan_mock(_))// + .WillOnce(Return("1_3456")) + .WillOnce(Return("12345_")); + + EXPECT_CALL(pam, print_mock(_) ); + + EXPECT_THAT(sut.handle(coreHandler, pam), Unexpected(WerificationError{WerificationError::WrongCode})); +} + +TEST_F(PasscodeBasedAuthTest, goodPasscodeShouldPass){ + EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(RawCoreResponse{}.codeConfirmationMessage())); + EXPECT_CALL(pam, scan_mock(_)).WillOnce(Return("123456")); + + EXPECT_THAT(sut.handle(coreHandler, pam), HasValue() ); } diff --git a/PAM/ssh/tests/rublon_tests.cpp b/PAM/ssh/tests/rublon_tests.cpp deleted file mode 100644 index 0ed9da6..0000000 --- a/PAM/ssh/tests/rublon_tests.cpp +++ /dev/null @@ -1,7 +0,0 @@ -#include -#include - -#include - -using namespace testing; - diff --git a/PAM/ssh/tests/sign_tests.cpp b/PAM/ssh/tests/sign_tests.cpp index 394ea64..0f42c6a 100644 --- a/PAM/ssh/tests/sign_tests.cpp +++ b/PAM/ssh/tests/sign_tests.cpp @@ -2,9 +2,17 @@ #include -TEST(sign, AProperSignIsReturnedFromSignData){ +using namespace std::string_literals; +using namespace std::string_view_literals; + +TEST(sign, AProperSignIsReturnedFromSignData) { const auto signKey = "39e8d771d83a2ed3cc728811911c25"; - const auto sign = "a6385672ee92eb37fe661d0cf5eb37a8a496663d7c27943e226e56a451d4b3d6"; - EXPECT_STREQ(rublon::signData("{data}", signKey).data(), sign); + const auto sign = "a6385672ee92eb37fe661d0cf5eb37a8a496663d7c27943e226e56a451d4b3d6"s; + EXPECT_EQ((std::string{rublon::signData("{data}", signKey).data(), 64}), sign); } +TEST(sign, AZeroTerminatedStringIsReturned) { + const auto signKey = "39e8d771d83a2ed3cc728811911c25"; + const auto sign = "a6385672ee92eb37fe661d0cf5eb37a8a496663d7c27943e226e56a451d4b3d6"s; + EXPECT_TRUE(*rublon::signData("{data}", signKey).rbegin() == '\0'); +} diff --git a/PAM/ssh/tests/utils_tests.cpp b/PAM/ssh/tests/utils_tests.cpp index 0465a0c..9d04593 100644 --- a/PAM/ssh/tests/utils_tests.cpp +++ b/PAM/ssh/tests/utils_tests.cpp @@ -28,63 +28,86 @@ x-ratelimit-remaining: 299 {"status":"OK","result":{"tid":"2039132542F6465691BF8C41D7CC38C5","status":"pending","companyName":"rublon","applicationName":"Bartoszek_SSH","methods":["email","totp","qrcode","push"]}})http"; -static inline std::string_view ltrim(std::string_view s) { - while(std::isspace(*s.begin())) - s.remove_prefix(1); - return s; -} +class LoggingAllocator : public std::pmr::memory_resource { + public: + int count{}; + int bytes{}; -static inline std::string_view rtrim(std::string_view s) { - while(std::isspace(*s.rbegin())) - s.remove_suffix(1); - return s; -} - -static inline std::string_view trim(std::string_view s) { - return ltrim(rtrim(s)); -} - -std::pmr::map< std::pmr::string, std::pmr::string > headers(std::pmr::memory_resource * mr, std::string_view data) { - std::pmr::map< std::pmr::string, std::pmr::string > headers{mr}; - - std::pmr::string tmp{mr}; - tmp.reserve(256); - std::istringstream resp{}; - resp.rdbuf()->pubsetbuf(const_cast< char * >(data.data()), data.size()); - - while(std::getline(resp, tmp) && !(trim(tmp).empty())) { - auto line = std::string_view(tmp); - auto index = tmp.find(':', 0); - if(index != std::string::npos) { - headers.insert({std::pmr::string{trim(line.substr(0, index)), mr}, std::pmr::string{trim(line.substr(index + 1)), mr}}); - } + void clear() { + count = 0; + bytes = 0; + } + // memory_resource interface + private: + void * do_allocate(std::size_t __bytes, std::size_t __alignment) override { + bytes += __bytes; + count++; + return std::pmr::get_default_resource()->allocate(__bytes, __alignment); } - return headers; -} - -inline std::map< std::string, std::string > headers(std::string_view data) { - std::map< std::string, std::string > headers{}; - - std::string tmp{}; - std::istringstream resp{}; - resp.rdbuf()->pubsetbuf(const_cast< char * >(data.data()), data.size()); - - while(std::getline(resp, tmp) && !(trim(tmp).empty())) { - auto line = std::string_view(tmp); - auto index = tmp.find(':', 0); - if(index != std::string::npos) { - headers.insert({std::string{trim(line.substr(0, index))}, std::string{trim(line.substr(index + 1))}}); - } + void do_deallocate(void * __p, std::size_t __bytes, std::size_t __alignment) override { + bytes -= __bytes; + count--; + std::pmr::get_default_resource()->deallocate(__p, __bytes, __alignment); } - return headers; -} + bool do_is_equal(const memory_resource & __other) const noexcept override { + return std::pmr::get_default_resource()->is_equal(__other); + } +}; TEST(Utils, responseParser) { - auto sut = headers(response); + memory::MonotonicStackResource< 2048 > resource; + std::map sut{}; + rublon::details::headers(response, sut); EXPECT_THAT(sut, AllOf(Contains(Pair("date", "Thu, 22 Jun 2023 13:24:58 GMT")), Contains(Pair("x-rublon-signature", "1a01558bedaa2dd92ff659fb8ee3c65a89163d63e312fcb9b6f60463cce864d7")), Contains(Pair("x-ratelimit-remaining", "299")))); } + +class Memory : public testing::Test { + public: + // Test interface + protected: + void SetUp() override { + memory::set_default_resource(&alloc); + } + void TearDown() override { + memory::set_default_resource(std::pmr::get_default_resource()); + } + + LoggingAllocator alloc; +}; + +class X {}; + +TEST_F(Memory, setSetsProperResource) { + LoggingAllocator s; + memory::set_default_resource(&s); + EXPECT_EQ(&s, memory::default_resource()); +} + +TEST_F(Memory, stackResourceDoesNotUseHeapAtAll) { + memory::MonotonicStackResource< 1024 > mr; + std::pmr::string{"to large to fit in small buffer optimization", &mr}; + + EXPECT_EQ(alloc.count, 0); +} + +TEST_F(Memory, heapResourceWillAllocate1kWhenAsk) { + memory::set_default_resource(&alloc); + memory::StrictMonotonic_1k_HeapResource mr; + std::pmr::string{"to large to fit in small buffer optimization", &mr}; + + EXPECT_EQ(alloc.count, 1); + EXPECT_EQ(alloc.bytes, 1024); +} + +TEST_F(Memory, heapResourceWillAllocate8kWhenAsk) { + memory::StrictMonotonic_8k_HeapResource mr; + std::pmr::string{"to large to fit in small buffer optimization", &mr}; + + EXPECT_EQ(alloc.count, 1); + EXPECT_EQ(alloc.bytes, 8 * 1024); +}