Bwi/memory management (#2)

Improve memory management
This commit is contained in:
rublon-bwi 2023-09-21 16:52:20 +02:00 committed by GitHub
parent 6a4f2431fc
commit 51b14c57d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 919 additions and 510 deletions

View File

@ -1,8 +1,8 @@
#pragma once
#include <rublon/core_handler_interface.hpp>
#include <rublon/pam.hpp>
#include <rublon/pam_action.hpp>
#include <rublon/core_handler_interface.hpp>
#include <rublon/utils.hpp>
namespace rublon {
@ -14,7 +14,7 @@ class AuthenticationStep {
std::string _tid;
public:
AuthenticationStep() {}
AuthenticationStep() = default;
AuthenticationStep(std::string systemToken, std::string tid) : _systemToken{std::move(systemToken)}, _tid{std::move(tid)} {}
template < typename Handler_t >
@ -30,33 +30,28 @@ class AuthenticationStep {
}
protected:
void addSystemToken(Document & body, RapidJSONPMRAlloc & alloc) const {
void addSystemToken(Document & body) const {
auto & alloc = body.GetAllocator();
body.AddMember("systemToken", Value{this->_systemToken.c_str(), alloc}, alloc);
}
void addTid(Document & body, RapidJSONPMRAlloc & alloc) const {
void addTid(Document & body) const {
auto & alloc = body.GetAllocator();
body.AddMember("tid", Value{this->_tid.c_str(), alloc}, alloc);
}
template < typename HandlerReturn_t >
Error coreErrorHandler(const HandlerReturn_t & /*coreResponse*/) const {
// switch(coreResponse.error().errorClass) {
// case CoreHandlerError::ErrorClass::BadSigature:
// log(LogLevel::Error, "ErrorClass::BadSigature");
// return PamAction::decline;
// case CoreHandlerError::ErrorClass::CoreException: /// TODO exception handling
// log(LogLevel::Error, "ErrorClass::CoreException");
// return PamAction::decline; /// TODO accept?
// case CoreHandlerError::ErrorClass::ConnectionError:
// log(LogLevel::Error, "ErrorClass::ConnectionError");
// return PamAction::decline; /// TODO decline?
// case CoreHandlerError::ErrorClass::BrokenData:
// log(LogLevel::Error, "ErrorClass::BrokenData");
// return PamAction::decline;
// }
// return PamAction::decline;
return Error{Critical{}};
Error coreErrorHandler(const Error & coreResponse) const {
auto category = coreResponse.category();
switch(category) {
case Error::k_SockerError:
log(LogLevel::Error, "Socker error, forcing override");
return Error{PamBaypass{}};
case Error::k_CoreHandlerError:
return Error{PamBaypass{}};
default:
Critical{};
}
return Critical{};
}
};

View File

@ -26,31 +26,28 @@ class Configuration {
bool enablePasswdEmail;
bool logging;
bool autopushPrompt;
};
Parameters parameters;
bool bypass;
} parameters;
};
class ConfigurationFactory {
public:
ConfigurationFactory(){};
ConfigurationFactory() = default;
std::optional< Configuration > systemConfig() {
std::array< char, 8 * 1024 > configBuffer;
std::pmr::monotonic_buffer_resource mr{configBuffer.data(), configBuffer.size()};
memory::MonotonicStackResource< 8 * 1024 > stackResource;
using Params = Configuration::Parameters;
Params configValues;
std::ifstream file(std::filesystem::path{"/etc/rublon.config"});
std::pmr::string line{&mr};
std::pmr::string line{&stackResource};
line.reserve(100);
std::pmr::map< std::pmr::string, std::pmr::string > parameters{&mr};
std::pmr::map< std::pmr::string, std::pmr::string > parameters{&stackResource};
while(std::getline(file, line)) {
std::pmr::string key{&mr};
std::pmr::string value{&mr};
std::pmr::string key{&stackResource};
std::pmr::string value{&stackResource};
if(!line.length())
continue;
@ -62,7 +59,7 @@ class ConfigurationFactory {
key = line.substr(0, posEqual);
value = line.substr(posEqual + 1);
parameters.emplace(std::move(key), std::move(value));
parameters[std::move(key)] = std::move(value);
}
auto saveStr = [&](auto member) { //

View File

@ -18,13 +18,10 @@ template < typename HttpHandler = CURL >
class CoreHandler : public CoreHandlerInterface< CoreHandler< HttpHandler > > {
std::string secretKey;
std::string url;
bool bypass;
std::pmr::string xRublonSignature(std::pmr::memory_resource & mr, std::string_view body) const {
return {signData(body, secretKey.c_str()).data(), &mr};
}
void signRequest(std::pmr::monotonic_buffer_resource & mr, Request & request) const {
request.headers["X-Rublon-Signature"] = xRublonSignature(mr, request.body);
void signRequest(Request & request) const {
request.headers["X-Rublon-Signature"] = std::pmr::string{signData(request.body, secretKey).data(), request.headers.get_allocator()};
}
bool responseSigned(const Response & response) const {
@ -41,19 +38,20 @@ class CoreHandler : public CoreHandlerInterface< CoreHandler< HttpHandler > > {
public:
CoreHandler(const Configuration & config)
: secretKey{config.parameters.secretKey}, url{config.parameters.apiServer}, http{[]() { return Response{}; }} {}
: secretKey{config.parameters.secretKey}, url{config.parameters.apiServer}, bypass{config.parameters.bypass}, http{} {}
tl::expected< Response, Error > validateSignature(const Response & response) const {
tl::expected< std::reference_wrapper< const Response >, Error > validateSignature(const Response & response) const {
if(not responseSigned(response)) {
log(LogLevel::Error, "CoreHandlerError::BadSigature");
return tl::unexpected{CoreHandlerError{CoreHandlerError::BadSigature}};
}
return response;
}
tl::expected< Document, Error > validateResponse(const Response & response) const {
RapidJSONPMRStackAlloc< 4 * 1024 > alloc{};
Document resp{&alloc};
RapidJSONPMRAlloc alloc{memory::default_resource()};
Document resp{};
resp.Parse(response.body.c_str());
if(resp.HasParseError() or not resp.HasMember("result")) {
@ -64,37 +62,71 @@ class CoreHandler : public CoreHandlerInterface< CoreHandler< HttpHandler > > {
if(resp["result"].HasMember("exception")) {
const auto & exception = resp["result"]["exception"].GetString();
log(LogLevel::Error, "rublon Core exception %s", exception);
return tl::unexpected{CoreHandlerError{CoreHandlerError::CoreException, exception}};
return handleCoreException(exception);
}
return resp;
}
tl::unexpected< Error > handleCoreException(std::string_view exceptionString) const {
if(exceptionString == "UserBypassedException" or exceptionString == "UserNotFoundException")
return tl::unexpected{Error{PamBaypass{}}};
else
return tl::unexpected{
Error{CoreHandlerError{CoreHandlerError::CoreException, std::string{exceptionString.data(), exceptionString.size()}}}};
}
tl::expected< Document, Error > request(std::string_view path, const Document & body) const {
const auto validateSignature = [this](const auto & arg) { return this->validateSignature(arg); };
const auto validateResponse = [this](const auto & arg) { return this->validateResponse(arg); };
tl::unexpected< Error > handleHttpError() const {
if(bypass) {
log(LogLevel::Warning, "User login bypass");
return tl::unexpected{Error{PamBaypass{}}};
} else {
log(LogLevel::Warning, "User login deny due to HTTP error");
return tl::unexpected{Error{PamDeny{}}};
}
}
RapidJSONPMRStackAlloc< 1 * 1024 > alloc{};
tl::expected< Document, Error > handleError(const Error & error) const {
if(error.is< HttpError >() and error.hasClass(HttpError::Error)) {
return handleHttpError();
}
return tl::unexpected{Error{error}};
}
template < typename T >
static void stringifyTo(const Document & body, T & to) {
memory::Monotonic_2k_HeapResource tmpResource;
RapidJSONPMRAlloc alloc{&tmpResource};
StringBuffer jsonStr{&alloc};
Writer writer{jsonStr, &alloc};
body.Accept(writer);
to = jsonStr.GetString();
}
std::byte _buffer[2 * 1024];
std::pmr::monotonic_buffer_resource mr{_buffer, sizeof(_buffer)};
tl::expected< Document, Error > request(std::string_view path, const Document & body) const {
memory::StrictMonotonic_2k_HeapResource memoryResource;
Request request{mr};
request.headers["Content-Type"] = "application/json";
request.headers["Accept"] = "application/json";
const auto validateSignature = [this](const auto & arg) { return this->validateSignature(arg); };
const auto validateResponse = [this](const auto & arg) { return this->validateResponse(arg); };
const auto handleError = [this](const auto & error) { return this->handleError(error); };
const auto pmrs = [&](auto txt) { return std::pmr::string{txt, &memoryResource}; };
request.body = jsonStr.GetString();
Request request{&memoryResource};
signRequest(mr, request);
std::pmr::string uri{url + path.data(), &mr};
stringifyTo(body, request.body);
return http.request(uri, request)
request.headers["Content-Type"] = pmrs("application/json");
request.headers["Accept"] = pmrs("application/json");
signRequest(request);
std::pmr::string uri{url + path.data(), &memoryResource};
return http
.request(uri, request) //
.and_then(validateSignature)
.and_then(validateResponse)
.or_else([](const Error & e) -> tl::expected< Document, Error > { return tl::unexpected{e}; });
.or_else(handleError);
}
};

View File

@ -13,8 +13,8 @@
#include <memory_resource>
#include <rublon/utils.hpp>
#include <rublon/error.hpp>
#include <rublon/utils.hpp>
#include <tl/expected.hpp>
@ -29,32 +29,56 @@ namespace {
} // namespace
struct Request {
std::pmr::memory_resource * _mr;
std::pmr::map< std::pmr::string, std::pmr::string > headers;
std::pmr::string body;
Request(std::pmr::memory_resource & mr) : headers{&mr}, body{&mr} {}
public:
Request(std::pmr::memory_resource * mr) : _mr{mr}, headers{_mr}, body{_mr} {};
Request(const Request & res) = delete;
Request & operator=(const Request & res) = delete;
Request(Request && res) = delete;
Request & operator=(Request &&) = delete;
std::pmr::memory_resource * get_allocator() const noexcept {
return _mr;
}
};
struct Response {
std::map< std::string, std::string > headers;
std::string body;
std::pmr::memory_resource * _mr;
std::pmr::map< std::pmr::string, std::pmr::string > headers;
std::pmr::string body;
public:
Response(std::pmr::memory_resource * mr) : _mr{mr}, headers{_mr}, body{_mr} {};
Response(const Response & res) = delete;
Response & operator=(const Response & res) = delete;
Response(Response && res) noexcept = default;
Response & operator=(Response && res) noexcept = default;
std::pmr::memory_resource * get_allocator() const noexcept {
return _mr;
}
};
class CURL {
std::unique_ptr< ::CURL, void (*)(::CURL *) > curl;
std::function< Response() > _responseFactory;
public:
CURL(std::function< Response() > responseFactory)
: curl{std::unique_ptr< ::CURL, void (*)(::CURL *) >(curl_easy_init(), curl_easy_cleanup)},
_responseFactory{std::move(responseFactory)} {}
tl::expected< Response, Error > request(std::string_view uri, const Request & request) const {
std::array< char, 16 * 1024 > buffer = {};
std::pmr::monotonic_buffer_resource mr{buffer.data(), buffer.size()};
std::pmr::string response_data{&mr};
response_data.reserve(15000);
CURL() : curl{std::unique_ptr< ::CURL, void (*)(::CURL *) >(curl_easy_init(), curl_easy_cleanup)} {}
tl::expected< Response, HttpError > request(std::string_view uri, const Request & request) const {
memory::MonotonicStackResource< 8 * 1024 > stackResource;
std::pmr::string response_data{&stackResource};
response_data.reserve(6000);
/// TODO this can be done on stack using pmr
auto curl_headers = std::unique_ptr< curl_slist, void (*)(curl_slist *) >(nullptr, curl_slist_free_all);
@ -74,20 +98,30 @@ class CURL {
curl_easy_setopt(curl.get(), CURLOPT_WRITEDATA, &response_data);
log(LogLevel::Debug, "%s Request send, uri:%s body:\n%s\n", "CURL", uri.data(), request.body.c_str());
auto res = curl_easy_perform(curl.get());
const auto res = curl_easy_perform(curl.get());
if(res != CURLE_OK) {
log(LogLevel::Error, "%s No response from Rublon server err:{%s}", "CURL", curl_easy_strerror(res));
return tl::unexpected{Error{SocketError{SocketError::Timeout}}};
return tl::unexpected{HttpError{HttpError::Timeout, 0}};
}
long http_code = 0;
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, http_code);
if(http_code >= 400) {
log(LogLevel::Error, "%s response with code %d ", "CURL", http_code);
return tl::unexpected{HttpError{HttpError::Error, http_code}};
}
log(LogLevel::Debug, "Response:\n%s\n", response_data.c_str());
Response response;
long size;
curl_easy_getinfo(curl.get(), CURLINFO_HEADER_SIZE, &size);
response.headers = details::headers({response_data.data(), static_cast< std::size_t >(size)});
response.body = response_data.substr(size);
Response response{memory::default_resource()};
details::headers(response_data, response.headers);
response.body = response_data.substr(size);
return response;
}

View File

@ -5,76 +5,94 @@
#include <variant>
namespace rublon {
class SocketError {
class NoError {
public:
enum ErrorClass { Timeout };
static constexpr int errorClass = 0;
};
SocketError() : errorClass{Timeout} {}
SocketError(ErrorClass e) : errorClass{e} {}
SocketError(ErrorClass e, std::string r) : errorClass{e}, reson{std::move(r)} {}
class HttpError {
public:
enum ErrorClass { Timeout, Error };
constexpr HttpError() : errorClass{Timeout}, httpCode(200) {}
constexpr HttpError(ErrorClass e, long httpCode) : errorClass{e}, httpCode(httpCode) {}
ErrorClass errorClass;
std::string reson;
// error_category interface
public:
const char * name() const noexcept {
return "SockerError";
}
std::string message(int) const {
return "";
}
long httpCode;
};
class CoreHandlerError {
public:
enum ErrorClass { BadSigature, CoreException, ConnectionError, BrokenData };
enum ErrorClass { BadSigature, CoreException, BrokenData };
CoreHandlerError(ErrorClass e) : errorClass{e} {}
CoreHandlerError(ErrorClass e, std::string r) : errorClass{e}, reson{std::move(r)} {}
ErrorClass errorClass;
std::string reson;
const char * name() const noexcept {
return "CoreHandlerError";
}
std::string message(int) const {
return "";
}
};
class MethodError {
public:
enum ErrorClass { BadMethod };
MethodError(ErrorClass e) : errorClass{e} {}
MethodError(ErrorClass e, std::string r) : errorClass{e}, reson{std::move(r)} {}
constexpr MethodError(ErrorClass e) : errorClass{e} {}
ErrorClass errorClass;
std::string reson;
const char * name() const noexcept {
return "MethodError";
}
std::string message(int) const {
return "";
}
};
class Critical {};
class WerificationError {
public:
enum ErrorClass { WrongCode };
constexpr WerificationError(ErrorClass e) : errorClass{e} {}
ErrorClass errorClass;
};
class Critical {
public:
enum ErrorClass { Nok };
Critical(ErrorClass e = Nok) : errorClass{e} {}
ErrorClass errorClass;
};
class PamBaypass {
public:
enum ErrorClass { Nok };
constexpr PamBaypass(ErrorClass e = Nok) : errorClass{e} {}
ErrorClass errorClass;
};
class PamDeny {
public:
enum ErrorClass { Nok };
constexpr PamDeny(ErrorClass e = Nok) : errorClass{e} {}
ErrorClass errorClass;
};
class Error {
std::variant< std::monostate, CoreHandlerError, SocketError, Critical, MethodError > _error;
enum class Category { None, CoreHandlerError, SockerError, Criticat, MethodError };
using Error_t = std::variant< NoError, CoreHandlerError, HttpError, WerificationError, Critical, MethodError, PamBaypass, PamDeny >;
Error_t _error;
public:
enum Category { k_None, k_CoreHandlerError, k_SockerError, k_WerificationError, k_Critical, k_MethodError, k_PamBaypass, k_PamDeny };
Error() = default;
Error(CoreHandlerError error) : _error{error} {}
Error(SocketError error) : _error{error} {}
Error(Critical error) : _error{error} {}
Error(HttpError error) : _error{error} {}
Error(MethodError error) : _error{error} {}
Error(WerificationError error) : _error{error} {}
Error(Critical error) : _error{error} {}
Error(PamBaypass error) : _error{error} {}
Error(PamDeny error) : _error{error} {}
Error(const Error &) = default;
Error(Error &&) = default;
@ -90,9 +108,42 @@ class Error {
return _error.index() == 2;
}
Category category() const noexcept {
constexpr Category category() const noexcept {
return static_cast< Category >(_error.index());
}
constexpr int errorClass() const noexcept {
return std::visit([](const auto & e) { return static_cast< int >(e.errorClass); }, _error);
}
template < typename E >
constexpr bool is() const {
return category() == Error{E{}}.category();
}
template < typename E >
constexpr bool isSameCategoryAs(const E & e) const {
return category() == Error{e}.category();
}
constexpr bool hasClass(int _class) const {
return errorClass() == _class;
}
template < typename E >
constexpr bool hasSameErrorClassAs(E e) const {
assert(isSameCategoryAs(e));
return errorClass() == Error{e}.errorClass();
}
};
constexpr bool operator==(const Error & e, const HttpError & socket) {
return e.sockerError() && e.errorClass() == socket.errorClass;
}
template < typename T >
constexpr bool operator==(const Error & lhs, const T & rhs) {
return lhs.isSameCategoryAs(rhs) && lhs.hasSameErrorClassAs(rhs);
}
} // namespace rublon

View File

@ -19,6 +19,36 @@ class Init : public AuthenticationStep< Init< MethodSelect_t > > {
const char * apiPath = "/api/transaction/init";
tl::expected< MethodSelect_t, Error > createMethod(const Document & coreResponse) const {
const auto & rublonResponse = coreResponse["result"];
std::string tid = rublonResponse["tid"].GetString();
return MethodSelect_t{this->_systemToken, tid, rublonResponse["methods"]};
}
tl::expected< MethodSelect_t, Error > handleInitError(const Error & error) const {
return tl::unexpected{this->coreErrorHandler(error)};
}
template < typename PamInfo_t >
void addPamInfo(Document & body, const PamInfo_t & pam) const {
auto & alloc = body.GetAllocator();
body.AddMember("username", Value{pam.username().get(), alloc}, alloc);
body.AddMember("userEmail", "bwi@rublon.com", alloc); /// TODO proper useremail
}
template < typename PamInfo_t >
void addParams(Document & body, const PamInfo_t & pam) const {
auto & alloc = body.GetAllocator();
Value params{rapidjson::kObjectType};
params.AddMember("userIP", Value{pam.ip().get(), alloc}, alloc);
params.AddMember("appVer", "v.1.6", alloc); /// TODO add version to cmake
params.AddMember("os", "Ubuntu 23.04", alloc); /// TODO add version to cmake
body.AddMember("params", std::move(params), alloc);
}
public:
const char * name = "Initialization";
@ -27,31 +57,20 @@ class Init : public AuthenticationStep< Init< MethodSelect_t > > {
/// TODO add core handler interface
template < typename Hander_t, typename PamInfo_t = LinuxPam >
tl::expected< MethodSelect_t, Error > handle(const CoreHandlerInterface< Hander_t > & coreHandler, const PamInfo_t & pam) const {
const auto createMethod = [&](const auto & coreResponse) { return this->createMethod(coreResponse); };
const auto handleInitError = [&](const auto & error) { return this->handleInitError(error); };
RapidJSONPMRStackAlloc< 1024 > alloc{};
Document body{rapidjson::kObjectType, &alloc};
this->addSystemToken(body, alloc);
this->addSystemToken(body);
this->addPamInfo(body, pam);
this->addParams(body, pam);
body.AddMember("username", Value{pam.username().get(), alloc}, alloc);
body.AddMember("userEmail", "bwi@rublon.com", alloc); /// TODO proper username
Value params{rapidjson::kObjectType};
params.AddMember("userIP", Value{pam.ip().get(), alloc}, alloc);
params.AddMember("appVer", "v.1.6", alloc); /// TODO add version to cmake
params.AddMember("os", "Ubuntu 23.04", alloc); /// TODO add version to cmake
body.AddMember("params", std::move(params), alloc);
auto coreResponse = coreHandler.request(apiPath, body);
if(coreResponse.has_value()) {
log(LogLevel::Info, "[TMP] has response, processing", __PRETTY_FUNCTION__);
const auto & rublonResponse = coreResponse.value()["result"];
std::string tid = rublonResponse["tid"].GetString();
return MethodSelect_t{this->_systemToken, tid, rublonResponse["methods"]};
} else {
return tl::unexpected{this->coreErrorHandler(coreResponse)};
}
return coreHandler
.request(apiPath, body) //
.and_then(createMethod)
.or_else(handleInitError);
}
};
} // namespace rublon

View File

@ -1,6 +1,8 @@
#pragma once
#include "rapidjson/document.h"
#include "rapidjson/pointer.h"
#include "rapidjson/writer.h"
#include "rublon/utils.hpp"
#include <cstring>
#include <memory_resource>
@ -47,8 +49,7 @@ struct RapidJSONPMRAlloc {
}
void * newPtr = Malloc(newSize);
if(originalSize)
{
if(originalSize) {
std::memcpy(newPtr, origPtr, originalSize);
freePtr(origPtr, originalSize);
}
@ -81,6 +82,10 @@ struct RapidJSONPMRStackAlloc : public RapidJSONPMRAlloc {
public:
RapidJSONPMRStackAlloc() : RapidJSONPMRAlloc(&mr) {}
RapidJSONPMRStackAlloc(const RapidJSONPMRStackAlloc &) = delete;
RapidJSONPMRStackAlloc(RapidJSONPMRStackAlloc &&) = delete;
RapidJSONPMRStackAlloc & operator=(const RapidJSONPMRStackAlloc &) = delete;
RapidJSONPMRStackAlloc & operator=(RapidJSONPMRStackAlloc &&) = delete;
};
using Document = rapidjson::GenericDocument< rapidjson::UTF8<>, RapidJSONPMRAlloc >;
@ -88,4 +93,6 @@ using Value = rapidjson::GenericValue< rapidjson::UTF8<>, RapidJSONPMRAll
using StringBuffer = rapidjson::GenericStringBuffer< rapidjson::UTF8<>, RapidJSONPMRAlloc >;
using Writer = rapidjson::Writer< StringBuffer, rapidjson::UTF8<>, rapidjson::UTF8<>, RapidJSONPMRAlloc >;
using JSONPointer = rapidjson::GenericPointer< Value, RapidJSONPMRAlloc >;
} // namespace rublon

View File

@ -10,10 +10,10 @@
namespace rublon::method {
class OTP:public PasscodeBasedAuth{
class OTP : public PasscodeBasedAuth {
public:
OTP(std::string systemToken, std::string tid) : PasscodeBasedAuth(std::move(systemToken), std::move(tid), "OTP", "Mobile TOTP from Rublon Authenticator:") {}
OTP(std::string systemToken, std::string tid)
: PasscodeBasedAuth(std::move(systemToken), std::move(tid), "OTP", "Mobile TOTP from Rublon Authenticator:") {}
};
} // namespace rublon::method

View File

@ -10,10 +10,9 @@
namespace rublon::method {
class SMS:public PasscodeBasedAuth{
class SMS : public PasscodeBasedAuth {
public:
SMS(std::string systemToken, std::string tid) : PasscodeBasedAuth(std::move(systemToken), std::move(tid), "SMS", "SMS passcode:") {}
};
} // namespace rublon::method

View File

@ -1,5 +1,6 @@
#pragma once
#include <set>
#include <tl/expected.hpp>
#include <variant>
@ -63,7 +64,7 @@ class PostMethod : public rublon::AuthenticationStep< PostMethod > {
const char * uri = "/api/transaction/methodSSH";
std::string _method;
tl::expected< MethodProxy, Error > doCreateMethod(const Document & coreResponse) const {
tl::expected< MethodProxy, Error > createMethod(const Document & coreResponse) const {
const auto & rublonResponse = coreResponse["result"];
std::string tid = rublonResponse["tid"].GetString();
@ -77,6 +78,14 @@ class PostMethod : public rublon::AuthenticationStep< PostMethod > {
return tl::unexpected{MethodError{MethodError::BadMethod}};
}
void addParams(Document & body) const {
auto & alloc = body.GetAllocator();
body.AddMember("method", Value{_method.c_str(), alloc}, alloc);
body.AddMember("GDPRAccepted", "true", alloc);
body.AddMember("tosAccepted", "true", alloc);
}
public:
const char * name = "Confirm Method";
@ -85,17 +94,14 @@ class PostMethod : public rublon::AuthenticationStep< PostMethod > {
template < typename Hander_t, typename PamInfo_t = LinuxPam >
tl::expected< MethodProxy, Error > handle(const CoreHandlerInterface< Hander_t > & coreHandler) const {
auto createMethod = [&](const auto & coreResponse) { return doCreateMethod(coreResponse); };
auto createMethod = [&](const auto & coreResponse) { return this->createMethod(coreResponse); };
RapidJSONPMRStackAlloc< 1024 > alloc{};
Document body{rapidjson::kObjectType, &alloc};
this->addSystemToken(body, alloc);
this->addTid(body, alloc);
body.AddMember("method", Value{_method.c_str(), alloc}, alloc);
body.AddMember("GDPRAccepted", "true", alloc);
body.AddMember("tosAccepted", "true", alloc);
this->addSystemToken(body);
this->addTid(body);
this->addParams(body);
return coreHandler
.request(uri, body) //
@ -122,30 +128,23 @@ class MethodSelect {
template < typename Pam_t >
tl::expected< PostMethod, Error > create(Pam_t & pam) const {
std::pmr::map< int, std::string > methods_id;
memory::StrictMonotonic_1k_HeapResource memoryResource;
std::pmr::map< int, std::string > methods_id{&memoryResource};
pam.print("%s", "");
int i{};
for(const auto & method : _methods) {
rublon::log(LogLevel::Debug, "method %s found at pos %d", method.c_str(), i);
if(method == "email") {
pam.print("%d: Email Link", i + 1);
methods_id[++i] = "email";
}
if(method == "qrcode") {
pam.print("%d: QR Code", i + 1);
methods_id[++i] = "qrcode";
}
if(method == "totp") {
pam.print("%d: Mobile TOTP", i + 1);
methods_id[++i] = "totp";
continue;
}
if(method == "push") {
pam.print("%d: Mobile Push", i + 1);
methods_id[++i] = "push";
}
if(method == "sms") {
pam.print("%d: SMS code", i + 1);
methods_id[++i] = "sms";
continue;
}
}

View File

@ -13,10 +13,49 @@ class PasscodeBasedAuth : public AuthenticationStep< PasscodeBasedAuth > {
const char * uri = "/api/transaction/confirmCode";
const char * userMessage{nullptr};
constexpr static bool isdigit(char ch) {
return std::isdigit(static_cast< unsigned char >(ch));
}
static bool hasDigitsOnly(const std::string & userinput) {
return std::all_of(userinput.cbegin(), userinput.cend(), isdigit);
}
static bool isProperLength(const std::string & userInput) {
return userInput.size() == 6;
}
template < typename PamInfo_t = LinuxPam >
std::string readPasscode(const PamInfo_t & pam) const {
/// TODO handle bad user input / wrong code etc
return pam.scan([](const char * userInput) { return std::string{userInput}; }, userMessage).value();
tl::expected< std::reference_wrapper< Document >, Error > readPasscode(Document & body, const PamInfo_t & pam) const {
auto & alloc = body.GetAllocator();
auto code = pam.scan([](const char * userInput) { return std::string{userInput}; }, userMessage);
if(code.has_value()) {
const auto & vericode = code.value();
if(isProperLength(vericode) and hasDigitsOnly(vericode)) {
body.AddMember("vericode", Value{vericode.c_str(), alloc}, alloc);
return body;
}
}
return tl::unexpected{Error{WerificationError{WerificationError::WrongCode}}};
}
template < typename PamInfo_t = LinuxPam >
tl::expected< std::reference_wrapper< Document >, Error > askForPasscodeAgain(Document & body, const PamInfo_t & pam) const {
pam.print("passcode has wrong number of digits or illegal characters, please correct");
return readPasscode(body, pam);
}
tl::expected< AuthenticationStatus, Error > checkAuthenticationStatus(const Document & coreResponse) const {
RapidJSONPMRStackAlloc< 1024 > alloc;
auto error = JSONPointer{"/result/error", &alloc}.Get(coreResponse);
if(error) {
return tl::unexpected{Error{WerificationError{WerificationError::WrongCode}}};
}
return AuthenticationStatus{AuthenticationStatus::Action::Confirmed};
}
public:
@ -30,20 +69,17 @@ class PasscodeBasedAuth : public AuthenticationStep< PasscodeBasedAuth > {
RapidJSONPMRStackAlloc< 1024 > alloc{};
Document body{rapidjson::kObjectType, &alloc};
this->addSystemToken(body, alloc);
this->addTid(body, alloc);
body.AddMember("vericode", Value{readPasscode(pam).c_str(), alloc}, alloc); /// TODO proper username
const auto checkAuthenticationStatus = [this](const auto & coreResponse) { return this->checkAuthenticationStatus(coreResponse); };
const auto requestAuthorization = [&](const auto & body) { return coreHandler.request(uri, body); };
const auto askForPasscodeAgain = [&](const auto & /*error*/) { return this->askForPasscodeAgain(body, pam); };
auto coreResponse = coreHandler.request(uri, body);
this->addSystemToken(body);
this->addTid(body);
if(coreResponse.has_value()) {
// {"status":"OK","result":{"error":"Hmm, that's not the right code. Try again."}}
// const auto & rublonResponse = coreResponse.value()["result"];
return AuthenticationStatus{AuthenticationStatus::Action::Confirmed};
} else {
return tl::unexpected{this->coreErrorHandler(coreResponse)};
}
return readPasscode(body, pam) //
.or_else(askForPasscodeAgain)
.and_then(requestAuthorization)
.and_then(checkAuthenticationStatus);
}
};

View File

@ -49,7 +49,7 @@ class LinuxPam {
void print(const char * fmt, Ti... ti) const noexcept {
pam_prompt(pamh, PAM_TEXT_INFO, nullptr, fmt, std::forward< Ti >(ti)...);
}
template < typename Fun, typename... Ti >
auto scan(Fun && f, const char * fmt, Ti... ti) const noexcept {
char * responseBuffer = nullptr;
@ -61,6 +61,5 @@ class LinuxPam {
}
return std::optional< std::result_of_t< Fun(char *) > >();
}
};
} // namespace rublon

View File

@ -1,5 +1,3 @@
#pragma once
namespace rublon {
} // namespace rublon
namespace rublon {} // namespace rublon

View File

@ -8,16 +8,17 @@
namespace rublon {
inline std::array< char, 64 > signData(std::string_view data, std::string_view secretKey) {
std::array< char, 64 > xRublon;
unsigned char md[EVP_MAX_MD_SIZE] = {0};
unsigned int md_len;
// +1 for \0
inline std::array< char, 64 + 1 > signData(std::string_view data, std::string_view secretKey) {
std::array< char, 64 + 1 > xRublon;
std::array< unsigned char, EVP_MAX_MD_SIZE > md;
unsigned int md_len{};
HMAC(EVP_sha256(), secretKey.data(), secretKey.size(), ( unsigned const char * ) data.data(), data.size(), md, &md_len);
HMAC(EVP_sha256(), secretKey.data(), secretKey.size(), ( unsigned const char * ) data.data(), data.size(), md.data(), &md_len);
int i;
for(i = 0; i < 32; i++)
for(unsigned int i = 0; i < md_len; i++)
sprintf(&xRublon[i * 2], "%02x", ( unsigned int ) md[i]);
return xRublon;
}

View File

@ -4,6 +4,7 @@
#include <array>
#include <cctype>
#include <cstring>
#include <iostream>
#include <map>
#include <memory>
#include <memory_resource>
@ -17,8 +18,78 @@
#include <security/pam_modules.h>
#include <syslog.h>
#include <unistd.h>
#include <utility>
namespace rublon {
namespace memory {
struct holder {
static inline std::pmr::memory_resource * _mr = std::pmr::get_default_resource();
};
inline void set_default_resource(std::pmr::memory_resource * memory_resource) {
holder{}._mr = memory_resource;
}
inline std::pmr::memory_resource * default_resource() {
return holder{}._mr;
}
template < std::size_t N >
class MonotonicStackResource : public std::pmr::monotonic_buffer_resource {
char _buffer[N];
public:
MonotonicStackResource() : std::pmr::monotonic_buffer_resource{_buffer, N, std::pmr::null_memory_resource()} {}
};
template < std::size_t N >
class UnsynchronizedStackResource : public std::pmr::unsynchronized_pool_resource {
MonotonicStackResource< N > _upstream;
public:
UnsynchronizedStackResource() : std::pmr::unsynchronized_pool_resource{&_upstream} {}
};
class MonotonicHeapResourceBase {
public:
std::pmr::memory_resource * _upstream{};
std::size_t _size{};
void * _buffer{nullptr};
MonotonicHeapResourceBase(std::size_t size) : _upstream{default_resource()}, _size{size}, _buffer{_upstream->allocate(size)} {}
~MonotonicHeapResourceBase() {
if(_buffer)
_upstream->deallocate(_buffer, _size);
}
};
template < std::size_t N >
class MonotonicHeapResource : MonotonicHeapResourceBase, public std::pmr::monotonic_buffer_resource {
public:
MonotonicHeapResource()
: MonotonicHeapResourceBase{N}, std::pmr::monotonic_buffer_resource{this->_buffer, this->_size, default_resource()} {}
};
template < std::size_t N >
class StrictMonotonicHeapResource : MonotonicHeapResourceBase, public std::pmr::monotonic_buffer_resource {
public:
StrictMonotonicHeapResource()
: MonotonicHeapResourceBase{N},
std::pmr::monotonic_buffer_resource{this->_buffer, this->_size, std::pmr::null_memory_resource()} {}
};
using StrictMonotonic_1k_HeapResource = StrictMonotonicHeapResource< 1 * 1024 >;
using StrictMonotonic_2k_HeapResource = StrictMonotonicHeapResource< 2 * 1024 >;
using StrictMonotonic_4k_HeapResource = StrictMonotonicHeapResource< 4 * 1024 >;
using StrictMonotonic_8k_HeapResource = StrictMonotonicHeapResource< 8 * 1024 >;
using Monotonic_1k_HeapResource = MonotonicHeapResource< 1 * 1024 >;
using Monotonic_2k_HeapResource = MonotonicHeapResource< 2 * 1024 >;
using Monotonic_4k_HeapResource = MonotonicHeapResource< 4 * 1024 >;
using Monotonic_8k_HeapResource = MonotonicHeapResource< 8 * 1024 >;
} // namespace memory
enum class LogLevel { Debug, Info, Warning, Error };
inline auto dateStr() {
@ -29,22 +100,23 @@ inline auto dateStr() {
return date;
}
constexpr const char* LogLevelNames[] {"Debug", "Info", "Warning", "Error"};
constexpr const char * LogLevelNames[]{"Debug", "Info", "Warning", "Error"};
constexpr LogLevel g_level = LogLevel::Debug;
constexpr bool syncLogFile = true;
namespace details{
namespace details {
static void doLog(LogLevel level, const char * line) noexcept {
constexpr auto file_name = "/var/log/rublon-ssh.log";
auto fp = std::unique_ptr< FILE, int (*)(FILE *) >(fopen(file_name, "a"), fclose);
if(fp) {
fprintf(fp.get(), "%s [%s] %s\n", dateStr().data(), LogLevelNames[(int)level], line);
sync();
fprintf(fp.get(), "%s [%s] %s\n", dateStr().data(), LogLevelNames[( int ) level], line);
if(syncLogFile)
sync();
}
}
}
} // namespace details
inline void log(LogLevel level, const char * line) noexcept {
if(level < g_level)
@ -92,9 +164,9 @@ class NonOwningPtr {
};
namespace details {
inline bool to_bool(std::string_view value) {
auto * buf = ( char * ) alloca(value.size());
/// TODO change to global allocator
auto * buf = ( char * ) alloca(value.size() + 1);
buf[value.size()] = '\0';
auto asciitolower = [](char in) { return in - ((in <= 'Z' && in >= 'A') ? ('Z' - 'z') : 0); };
@ -118,46 +190,25 @@ namespace details {
return ltrim(rtrim(s));
}
inline std::map< std::string, std::string > headers(std::string_view data) {
std::map< std::string, std::string > headers{};
template < typename Headers >
inline void headers(std::string_view data, Headers & headers) {
memory::StrictMonotonic_4k_HeapResource stackResource;
std::pmr::string tmp{&stackResource};
std::string tmp{};
std::istringstream resp{};
resp.rdbuf()->pubsetbuf(const_cast< char * >(data.data()), data.size());
while(std::getline(resp, tmp)) {
while(std::getline(resp, tmp) && !(trim(tmp).empty())) {
auto line = std::string_view(tmp);
auto index = tmp.find(':', 0);
if(index != std::string::npos) {
headers.insert({std::string{trim(line.substr(0, index))}, std::string{trim(line.substr(index + 1))}});
headers.insert({//
typename Headers::key_type{trim(line.substr(0, index)), headers.get_allocator()},
typename Headers::mapped_type{trim(line.substr(index + 1)), headers.get_allocator()}});
}
}
return headers;
}
namespace pmr {
inline std::pmr::map< std::pmr::string, std::pmr::string > headers(std::pmr::memory_resource * mr, std::string_view data) {
char _buf[1024];
std::pmr::monotonic_buffer_resource tmr{_buf, sizeof(_buf)};
std::pmr::map< std::pmr::string, std::pmr::string > headers{mr};
std::pmr::string tmp{&tmr};
std::istringstream resp{};
resp.rdbuf()->pubsetbuf(const_cast< char * >(data.data()), data.size());
while(std::getline(resp, tmp) && !(trim(tmp).empty())) {
auto line = std::string_view(tmp);
auto index = tmp.find(':', 0);
if(index != std::string::npos) {
headers.insert({std::pmr::string{trim(line.substr(0, index)), mr}, std::pmr::string{trim(line.substr(index + 1)), mr}});
}
}
return headers;
}
} // namespace pmr
} // namespace details
} // namespace rublon

View File

@ -4,9 +4,8 @@
#include <security/pam_misc.h>
#include <security/pam_modules.h>
#include <rapidjson/rapidjson.h>
#include <rublon/init.hpp>
#include <rublon/json.hpp>
#include <rublon/pam.hpp>
#include <rublon/rublon.hpp>
#include <rublon/utils.hpp>
@ -15,13 +14,6 @@
using namespace std;
namespace {
template < typename T >
using Expected = tl::expected< T, rublon::Error >;
} // namespace
DLL_PUBLIC int pam_sm_setcred([[maybe_unused]] pam_handle_t * pamh,
[[maybe_unused]] int flags,
[[maybe_unused]] int argc,
@ -42,6 +34,11 @@ pam_sm_authenticate(pam_handle_t * pamh, [[maybe_unused]] int flags, [[maybe_unu
auto rublonConfig = ConfigurationFactory{}.systemConfig();
std::byte sharedMemory[32 * 1024] = {};
std::pmr::monotonic_buffer_resource mr{sharedMemory, std::size(sharedMemory)};
std::pmr::unsynchronized_pool_resource rublonPoolResource{&mr};
std::pmr::set_default_resource(&rublonPoolResource);
CoreHandler CH{rublonConfig.value()};
LinuxPam pam{pamh};
@ -54,10 +51,17 @@ pam_sm_authenticate(pam_handle_t * pamh, [[maybe_unused]] int flags, [[maybe_unu
.and_then(selectMethod)
.and_then(confirmMethod)
.and_then(verifi);
if(authStatus.has_value()) {
rublon::log(rublon::LogLevel::Info, "Auth OK");
return PAM_SUCCESS;
} else {
const auto & error = authStatus.error();
rublon::log(
LogLevel::Error, "auth failed due to %d class and %d category", error.errorClass(), static_cast< int >(error.category()));
}
return PAM_SUCCESS;
rublon::log(LogLevel::Warning, "User login failed");
return PAM_PERM_DENIED;
}

View File

@ -25,7 +25,6 @@ add_executable(rublon-tests
./init_test.cpp
./method_select_tests.cpp
./passcode_auth_tests.cpp
./rublon_tests.cpp
./sign_tests.cpp
./utils_tests.cpp

View File

@ -4,13 +4,14 @@
#include <numeric>
#include <set>
#include "gtest_matchers.hpp"
#include "http_mock.hpp"
#include <rublon/core_handler.hpp>
using namespace rublon;
class CoreHandlerTestable : public CoreHandler< HttpHandlerMock > {
public:
CoreHandlerTestable() : CoreHandler< HttpHandlerMock >{conf} {}
CoreHandlerTestable(rublon::Configuration _conf = conf) : CoreHandler< HttpHandlerMock >{_conf} {}
auto & _http() {
return http;
}
@ -30,43 +31,47 @@ class CoreHandlerTests : public testing::Test {
using namespace testing;
MATCHER(HasValue, "") {
return arg.has_value();
}
bool operator==(const Error & lhs, const Error & rhs) {
return lhs.category() == rhs.category();
}
MATCHER_P(Unexpected, error, "") {
return arg.error() == error;
}
auto ReturnError(SocketError::ErrorClass value) {
return Return(tl::unexpected{Error{SocketError{value}}});
}
TEST_F(CoreHandlerTests, coreShouldSetConnectionErrorOnBrokenConnection) {
EXPECT_CALL(http, request(_, _)).WillOnce(ReturnError(SocketError::Timeout));
EXPECT_CALL(http, request(_, _)).WillOnce(Return(RawHttpResponse{}.withTimeoutError()));
EXPECT_THAT(sut.request("", doc), //
AllOf(Not(HasValue()), Unexpected(CoreHandlerError{CoreHandlerError{CoreHandlerError::ConnectionError}})));
AllOf(Not(HasValue()), Unexpected(HttpError{})));
}
TEST_F(CoreHandlerTests, coreShouldCheckSignatureAndReturnBadSignatureBeforeAnythingElse) {
EXPECT_CALL(http, request(_, _)).WillOnce(Return(( Response ) http.brokenSignature().brokenBody()));
EXPECT_CALL(http, request(_, _)).WillOnce(Return(RawHttpResponse{}.initMessage().withBrokenBody().withBrokenSignature()));
EXPECT_THAT(sut.request("", doc), //
AllOf(Not(HasValue()), Unexpected(CoreHandlerError{CoreHandlerError::BadSigature})));
}
TEST_F(CoreHandlerTests, coreShouldSetBrokenDataWhenResultIsNotAvailable) {
EXPECT_CALL(http, request(_, _)).WillOnce(Return(( Response ) http.brokenBody()));
EXPECT_CALL(http, request(_, _)).WillOnce(Return(RawHttpResponse{}.initMessage().withBrokenBody()));
EXPECT_THAT(sut.request("", doc), //
AllOf(Not(HasValue()), Unexpected(CoreHandlerError{CoreHandlerError::BrokenData})));
}
TEST_F(CoreHandlerTests, coreSignatureIsBeingChecked) {
EXPECT_CALL(http, request(Eq(conf.parameters.apiServer + "/path"), _)).WillOnce(Return(( Response ) http.statusPending()));
auto val = sut.request("/path", doc);
EXPECT_TRUE(val.value().IsObject());
EXPECT_CALL(http, request(Eq(conf.parameters.apiServer + "/path"), _)).WillOnce(Return(RawHttpResponse{}.initMessage()));
EXPECT_THAT(sut.request("/path", doc), //
AllOf(HasValue(), IsObject(), HasKey("/result/tid")));
}
TEST_F(CoreHandlerTests, onHttpProblemsAccessShouldBeDeniedByDefault) {
EXPECT_CALL(http, request(_, _)).WillOnce(Return(RawHttpResponse{}.initMessage().withServiceUnavailableError()));
EXPECT_THAT(sut.request("/path", doc), //
AllOf(Not(HasValue()), Unexpected(PamDeny{})));
}
class CoreHandlerWithBypassTests : public testing::Test {
public:
CoreHandlerWithBypassTests() : sut{confBypass}, http{sut._http()} {
}
CoreHandlerTestable sut;
HttpHandlerMock & http;
};
TEST_F(CoreHandlerWithBypassTests, onHttpProblemsAccessShouldBeBypassedWhenEnabled) {
EXPECT_CALL(http, request(_, _)).WillOnce(Return(RawHttpResponse{}.initMessage().withServiceUnavailableError()));
EXPECT_THAT(sut.request("/path", Document{}), //
AllOf(Not(HasValue()), Unexpected(PamBaypass{})));
}

View File

@ -1,9 +1,14 @@
#pragma once
#include <tl/expected.hpp>
#include <rublon/error.hpp>
#include <rublon/json.hpp>
#include <algorithm>
#include <string>
#include <set>
#include <array>
#include <set>
#include <string>
namespace io {
template < class X >
@ -18,7 +23,7 @@ constexpr T forward_or_transform(T t) {
}
template < class T, class = is_string< T > >
constexpr const char * forward_or_transform(const T &t) {
constexpr const char * forward_or_transform(const T & t) {
return t.c_str();
}
@ -28,72 +33,15 @@ int sprintf(char * _s, const std::string & format, Ti... t) {
}
} // namespace io
namespace {
std::string gen_random(const int len) {
static const char alphanum[] =
"0123456789"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ";
std::string tmp_s;
tmp_s.resize(len);
std::for_each(tmp_s.begin(), tmp_s.end(), [](auto & chr) { chr = alphanum[rand() % (sizeof(alphanum) - 1)]; });
return tmp_s;
}
static constexpr const char * result_ok_template =
R"json({"status":"OK","result":{"tid":"%s","status":"%s","companyName":"%s","applicationName":"%s","methods":[%s]}})json";
static constexpr const char * result_broken_template = R"json({"some":"random","json":"notrublon"})json";
static constexpr const char * result_broken_template = //
R"json({"some":"random","json":"notrublon"})json";
static constexpr const char * wrongPasscode = //
R"json({"status":"OK","result":{"error":"Hmm, that's not the right code. Try again."}})json";
} // namespace
class CoreResponseGenerator {
public:
std::string generateBody() {
std::array< char, 2048 > _buf;
io::sprintf(_buf.data(),
generateBrokenData ? result_broken_template : result_ok_template,
_tid,
status,
companyName,
applicationName,
print_methods());
return _buf.data();
}
CoreResponseGenerator & withMethods(std::initializer_list< std::string > newMethods) {
_methods.clear();
std::copy(newMethods.begin(), newMethods.end(), std::inserter(_methods, _methods.begin()));
return *this;
}
CoreResponseGenerator & withTid(std::string tid) {
_tid = tid;
return *this;
}
std::string _tid{generateTid()};
std::string status;
std::string companyName{"rublon"};
std::string applicationName{"test_app"};
std::set< std::string > _methods{"email", "totp", "qrcode", "push"};
bool skipSignatureGeneration{false};
bool generateBrokenData{false};
private:
std::string print_methods() {
std::string ret;
for(const auto & m : _methods)
ret += "\"" + m + "\",";
ret.pop_back();
return ret;
}
static std::string generateTid() {
return gen_random(32);
}
};

View File

@ -0,0 +1,39 @@
#pragma once
#include <gmock/gmock.h>
#include <rublon/json.hpp>
#include <rublon/error.hpp>
using namespace testing;
inline bool operator==(const rublon::Error & lhs, const rublon::Error & rhs) {
return lhs.category() == rhs.category();
}
MATCHER(HasValue, "") {
return arg.has_value();
}
MATCHER(IsObject, "") {
return arg->IsObject();
}
MATCHER(IsPAMBaypass,""){
return arg.error().category() == rublon::Error::k_PamBaypass;
}
MATCHER(IsPAMDeny,""){
return arg.error().category() == rublon::Error::k_PamDeny;
}
MATCHER_P(HasKey, key, "") {
return rublon::JSONPointer{key}.Get(arg.value()) != nullptr;
}
MATCHER_P(Unexpected, error, "") {
rublon::Error e = arg.error();
return e == error;
}

View File

@ -5,8 +5,8 @@
#include <set>
#include <string>
#include <rublon/curl.hpp>
#include <rublon/configuration.hpp>
#include <rublon/curl.hpp>
#include "core_response_generator.hpp"
#include "rublon/sign.hpp"
@ -19,43 +19,219 @@ rublon::Configuration conf{rublon::Configuration::Parameters{//
1,
true,
true,
false,
false}};
rublon::Configuration confBypass{rublon::Configuration::Parameters{//
"320BAB778C4D4262B54CD243CDEFFAFD",
"39e8d771d83a2ed3cc728811911c25",
"https://staging-core.rublon.net",
1,
true,
true,
false,
true}};
inline std::string randomTID() {
static const char alphanum[] =
"0123456789"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ";
std::string tmp_s;
const auto len = 32;
tmp_s.resize(len);
std::for_each(tmp_s.begin(), tmp_s.end(), [](auto & chr) { chr = alphanum[rand() % (sizeof(alphanum) - 1)]; });
return tmp_s;
}
std::string join_methods(std::set< std::string > _methods) {
std::string ret;
for(const auto & m : _methods)
ret += "\"" + m + "\",";
ret.pop_back();
return ret;
}
} // namespace
class HttpHandlerMock : public CoreResponseGenerator {
auto signResponse(rublon::Response & res) {
struct ResponseMockOptions {
bool generateBrokenSigneture{false};
bool generateBrokenBody{false};
bool generateBrokenCode{false};
std::string coreException{};
std::set< std::string > methods{};
rublon::Error error;
};
class InitCoreResponse {
static constexpr const char * result_ok_template =
R"json({"status":"OK","result":{"tid":"%s","status":"%s","companyName":"%s","applicationName":"%s","methods":[%s]}})json";
static constexpr const char * result_broken_template = //
R"json({"some":"random","json":"notrublon"})json";
public:
static std::string generate(const ResponseMockOptions & options) {
if(options.generateBrokenBody) {
return result_broken_template;
}
std::array< char, 2048 > _buf;
auto tid = randomTID();
auto status = "pending";
std::string companyName{"rublon"};
std::string applicationName{"test_app"};
std::set< std::string > methods{"email", "totp", "qrcode", "push"};
if(options.methods.size())
methods = options.methods;
io::sprintf(_buf.data(), result_ok_template, tid, status, companyName, applicationName, join_methods(methods));
return _buf.data();
}
};
class MethodSelectCoreResponse {
static constexpr const char * result_ok_template =
R"json({"status":"OK","result":{"tid":"%s","status":"%s","companyName":"%s","applicationName":"%s","methods":[%s]}})json";
static constexpr const char * result_broken_template = //
R"json({"some":"random","json":"notrublon"})json";
public:
static std::string generate(const ResponseMockOptions & options) {
if(options.generateBrokenBody) {
return result_broken_template;
}
std::array< char, 2048 > _buf;
auto tid = randomTID();
auto status = "pending";
std::string companyName{"rublon"};
std::string applicationName{"test_app"};
std::set< std::string > methods{"email", "totp", "qrcode", "push"};
io::sprintf(_buf.data(), result_ok_template, tid, status, companyName, applicationName, join_methods(methods));
return _buf.data();
}
};
class CodeVerificationResponse {
static constexpr const char * wrongPasscode =
R"json({"status":"OK","result":{"error":"Hmm, that's not the right code. Try again."}})json";
static constexpr const char * goodPasscode = R"json({"status":"OK","result":true})json";
public:
static std::string generate(const ResponseMockOptions & options) {
return options.generateBrokenCode ? wrongPasscode : goodPasscode;
}
};
template < typename Generator >
class ResponseBase {
public:
ResponseBase & initMessage() {
_coreGenerator = InitCoreResponse{};
return *this;
}
ResponseBase & selectMethodMessage() {
_coreGenerator = MethodSelectCoreResponse{};
return *this;
}
ResponseBase & codeConfirmationMessage() {
_coreGenerator = CodeVerificationResponse{};
return *this;
}
ResponseBase & withBrokenBody() {
options.generateBrokenBody = true;
return *this;
}
ResponseBase & withCoreException() {
options.coreException = "some exception";
return *this;
}
ResponseBase & withBrokenSignature() {
options.generateBrokenSigneture = true;
return *this;
}
ResponseBase & withWrongCodeResponse() {
options.generateBrokenCode = true;
return *this;
}
ResponseBase & withTimeoutError() {
options.error = rublon::Error{rublon::HttpError{}};
return *this;
}
ResponseBase & withServiceUnavailableError() {
options.error = rublon::Error{rublon::HttpError{rublon::HttpError::Error, 405}};
return *this;
}
template < typename... T >
ResponseBase & withMethods(T... methods) {
(options.methods.insert(methods), ...);
return *this;
}
ResponseBase & withConnectionError() {
// options.error = rublon::Error{rublon::CoreHandlerError{rublon::CoreHandlerError::ConnectionError}};
return *this;
}
operator auto() {
return options.error.category() == rublon::Error::k_None ? static_cast< Generator * >(this)->generate() : tl::unexpected{error()};
}
rublon::Error error() {
return rublon::Error{options.error};
}
std::variant< InitCoreResponse, MethodSelectCoreResponse, CodeVerificationResponse > _coreGenerator;
ResponseMockOptions options;
};
class RawCoreResponse : public ResponseBase< RawCoreResponse > {
public:
tl::expected< rublon::Document, rublon::Error > generate() {
auto jsonString = std::visit([&](const auto generator) { return generator.generate(options); }, _coreGenerator);
rublon::Document doc;
doc.Parse(jsonString.c_str());
return doc;
}
};
class RawHttpResponse : public ResponseBase< RawHttpResponse > {
static auto signResponse(rublon::Response & res, ResponseMockOptions opts) {
const auto & sign =
skipSignatureGeneration ? std::array< char, 64 >{} : rublon::signData(res.body, conf.parameters.secretKey.c_str());
opts.generateBrokenSigneture ? std::array< char, 65 >{} : rublon::signData(res.body, conf.parameters.secretKey.c_str());
res.headers["x-rublon-signature"] = sign.data();
}
public:
template < typename... Args >
HttpHandlerMock(const Args &...) {}
MOCK_METHOD((tl::expected< rublon::Response, rublon::Error >), request, ( std::string_view, const rublon::Request & ), (const));
HttpHandlerMock & statusPending() {
status = "pending";
return *this;
}
HttpHandlerMock & brokenBody() {
generateBrokenData = true;
return *this;
}
HttpHandlerMock & brokenSignature() {
skipSignatureGeneration = true;
return *this;
}
operator rublon::Response() {
rublon::Response res;
res.body = generateBody();
signResponse(res);
return res;
tl::expected< rublon::Response, rublon::Error > generate() {
rublon::Response response{std::pmr::get_default_resource()};
response.body = std::visit([&](const auto generator) { return generator.generate(options); }, _coreGenerator);
signResponse(response, options);
return response;
}
};
class HttpHandlerMock {
public:
template < typename... Args >
HttpHandlerMock(const Args &...) {}
MOCK_METHOD(( tl::expected< rublon::Response, rublon::Error > ), request, ( std::string_view, const rublon::Request & ), (const));
};

View File

@ -3,65 +3,21 @@
#include <rublon/init.hpp>
#include "core_response_generator.hpp"
#include "core_handler_mock.hpp"
#include "gtest_matchers.hpp"
#include "http_mock.hpp"
#include "pam_info_mock.hpp"
using namespace rublon;
using namespace testing;
namespace {
Configuration conf;
}
class CoreHandlerMock : public CoreHandlerInterface< CoreHandlerMock > {
public:
CoreHandlerMock() {}
MOCK_METHOD(( tl::expected< Document, Error > ), request, ( std::string_view, const Document & ), (const));
CoreHandlerMock & statusPending() {
gen.status = "pending";
return *this;
}
CoreHandlerMock & brokenBody() {
gen.generateBrokenData = true;
return *this;
}
CoreHandlerMock & methods(std::initializer_list< std::string > methods) {
gen.withMethods(methods);
return *this;
}
CoreHandlerMock & tid(std::string tid) {
gen.withTid(tid);
return *this;
}
operator tl::expected< Document, Error >() {
auto body = gen.generateBody();
rublon::Document doc;
doc.Parse(body.c_str());
return doc;
}
auto create() {
return static_cast< tl::expected< Document, Error > >(*this);
}
CoreResponseGenerator gen;
};
class MethodFactoryMock {
public:
using MethodFactoryCreate_t = tl::expected< MethodProxy, Error >;
template < typename... Args >
MethodFactoryMock(Args &&...) {}
MOCK_METHOD(MethodFactoryCreate_t, create, (const mocks::PamInfoMock &, std::string tid), (const));
};
@ -76,48 +32,42 @@ class RublonHttpInitTest : public testing::Test {
expectDefaultPamInfo();
}
CoreHandlerMock coreHandler;
mocks::CoreHandlerMock coreHandler;
mocks::PamInfoMock pam;
Init< MethodFactoryMock > sut;
// MethodFactoryMock &methodFactoryMock;
};
using CoreReturn = tl::expected< Document, CoreHandlerError >;
TEST_F(RublonHttpInitTest, initializationSendsRequestOnGoodPath) {
EXPECT_CALL(coreHandler, request("/api/transaction/init", _))
.WillOnce(Return(tl::unexpected{CoreHandlerError{CoreHandlerError::BadSigature}}));
EXPECT_CALL(coreHandler, request("/api/transaction/init", _)).WillOnce(Return(RawCoreResponse{}.initMessage()));
sut.handle(coreHandler, pam);
}
/// TODO fix
//MATCHER_P(HoldsPamAction, action, "") {
// return not arg.has_value() && arg.error() == action;
//}
TEST_F(RublonHttpInitTest, rublon_Accept_pamLoginWhenThereIsNoConnection) {
EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(RawCoreResponse{}.withTimeoutError()));
EXPECT_THAT(sut.handle(coreHandler, pam), //
AllOf(Not(HasValue()), IsPAMBaypass()));
}
//TEST_F(RublonHttpInitTest, rublon_Accept_pamLoginWhenThereIsNoConnection) {
// EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(tl::unexpected{CoreHandlerError{CoreHandlerError::ConnectionError}}));
// EXPECT_THAT(sut.handle(coreHandler, pam), HoldsPamAction(PamAction::decline));
//}
// TEST_F(RublonHttpInitTest, rublon_Decline_pamLoginWhenServerHasBadSignature) {
// EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(tl::unexpected{CoreHandlerError{CoreHandlerError::BadSigature}}));
// EXPECT_THAT(sut.handle(coreHandler, pam), HoldsPamAction(PamAction::decline));
// }
//TEST_F(RublonHttpInitTest, rublon_Decline_pamLoginWhenServerHasBadSignature) {
// EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(tl::unexpected{CoreHandlerError{CoreHandlerError::BadSigature}}));
// EXPECT_THAT(sut.handle(coreHandler, pam), HoldsPamAction(PamAction::decline));
//}
// TEST_F(RublonHttpInitTest, rublon_Decline_pamLoginWhenServerReturnsBrokenData) {
// EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(tl::unexpected{CoreHandlerError{CoreHandlerError::BrokenData}}));
// EXPECT_THAT(sut.handle(coreHandler, pam), HoldsPamAction(PamAction::decline));
// }
//TEST_F(RublonHttpInitTest, rublon_Decline_pamLoginWhenServerReturnsBrokenData) {
// EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(tl::unexpected{CoreHandlerError{CoreHandlerError::BrokenData}}));
// EXPECT_THAT(sut.handle(coreHandler, pam), HoldsPamAction(PamAction::decline));
//}
// TEST_F(RublonHttpInitTest, rublon_Decline_pamLoginWhenServerReturnsCoreException) {
// EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(tl::unexpected{CoreHandlerError{CoreHandlerError::CoreException}}));
// EXPECT_THAT(sut.handle(coreHandler, pam), HoldsPamAction(PamAction::decline));
// }
//TEST_F(RublonHttpInitTest, rublon_Decline_pamLoginWhenServerReturnsCoreException) {
// EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(tl::unexpected{CoreHandlerError{CoreHandlerError::CoreException}}));
// EXPECT_THAT(sut.handle(coreHandler, pam), HoldsPamAction(PamAction::decline));
//}
//TEST_F(RublonHttpInitTest, AllNeededInformationNeedsToBePassedToMethodFactory) {
// EXPECT_CALL(coreHandler, request(_, _))
// .WillOnce(Return(coreHandler.statusPending().methods({"sms", "otp"}).tid("transaction ID").create()));
// // EXPECT_CALL(methodFactoryMock, create_mocked(_,"transaction ID") );
// sut.handle(coreHandler, pam);
//}
TEST_F(RublonHttpInitTest, AllNeededInformationNeedsToBePassedToMethodFactory) {
EXPECT_CALL(coreHandler, request(_, _))
.WillOnce(Return( RawCoreResponse{}.initMessage().withMethods("sms", "otp")));
// EXPECT_CALL(methodFactoryMock, create_mocked(_,"transaction ID") );
sut.handle(coreHandler, pam);
}

View File

@ -1,6 +1,6 @@
#pragma once
#include <gmock/gmock.h>>
#include <gmock/gmock.h>
#include <rublon/utils.hpp>
@ -9,5 +9,19 @@ class PamInfoMock {
public:
MOCK_METHOD(rublon::NonOwningPtr< const char >, ip, (), (const));
MOCK_METHOD(rublon::NonOwningPtr< const char >, username, (), (const));
MOCK_METHOD(std::string, scan_mock, (const char * fmt), (const));
MOCK_METHOD(void, print_mock, (const char * fmt), (const));
template < typename Fun, typename... Ti >
auto scan(Fun && f, const char * fmt, Ti...) const noexcept -> std::optional< std::result_of_t< Fun(char *) > > {
const auto & responseBuffer = scan_mock(fmt);
return responseBuffer.empty() ? std::nullopt : std::optional{f(responseBuffer.c_str())};
}
template < typename... Ti >
void print(const char * fmt, Ti...) const noexcept {
print_mock(fmt);
}
};
} // namespace mocks

View File

@ -3,6 +3,8 @@
#include <rublon/method/passcode_based_auth.hpp>
#include "core_handler_mock.hpp"
#include "gtest_matchers.hpp"
#include "http_mock.hpp"
#include "pam_info_mock.hpp"
using namespace testing;
@ -21,8 +23,38 @@ class PasscodeBasedAuthTest : public Test {
mocks::PamInfoMock pam;
};
TEST_F(PasscodeBasedAuthTest, wrongPasscodeShouldFail){
EXPECT_THAT(coreHandler, request(_,_) );
TEST_F(PasscodeBasedAuthTest, wrongPasscodeShouldFail) {
EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(RawCoreResponse{}.codeConfirmationMessage().withWrongCodeResponse()));
EXPECT_CALL(pam, scan_mock(_)).WillOnce(Return("123456"));
sut.handle(coreHandler, pam);
EXPECT_THAT(sut.handle(coreHandler, pam), Unexpected(WerificationError{WerificationError::WrongCode}));
}
TEST_F(PasscodeBasedAuthTest, whenGivenBadPasscodeWeNeedToAskAgain) {
EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(RawCoreResponse{}.codeConfirmationMessage()));
EXPECT_CALL(pam, scan_mock(_))//
.WillOnce(Return("1_3456"))
.WillOnce(Return("123456"));
EXPECT_CALL(pam, print_mock(_) );
EXPECT_THAT(sut.handle(coreHandler, pam), HasValue() );
}
TEST_F(PasscodeBasedAuthTest, whenGiveenBadPasscodeMultipleTimesWeAbort) {
EXPECT_CALL(pam, scan_mock(_))//
.WillOnce(Return("1_3456"))
.WillOnce(Return("12345_"));
EXPECT_CALL(pam, print_mock(_) );
EXPECT_THAT(sut.handle(coreHandler, pam), Unexpected(WerificationError{WerificationError::WrongCode}));
}
TEST_F(PasscodeBasedAuthTest, goodPasscodeShouldPass){
EXPECT_CALL(coreHandler, request(_, _)).WillOnce(Return(RawCoreResponse{}.codeConfirmationMessage()));
EXPECT_CALL(pam, scan_mock(_)).WillOnce(Return("123456"));
EXPECT_THAT(sut.handle(coreHandler, pam), HasValue() );
}

View File

@ -1,7 +0,0 @@
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <rublon/rublon.hpp>
using namespace testing;

View File

@ -2,9 +2,17 @@
#include <rublon/sign.hpp>
TEST(sign, AProperSignIsReturnedFromSignData){
using namespace std::string_literals;
using namespace std::string_view_literals;
TEST(sign, AProperSignIsReturnedFromSignData) {
const auto signKey = "39e8d771d83a2ed3cc728811911c25";
const auto sign = "a6385672ee92eb37fe661d0cf5eb37a8a496663d7c27943e226e56a451d4b3d6";
EXPECT_STREQ(rublon::signData("{data}", signKey).data(), sign);
const auto sign = "a6385672ee92eb37fe661d0cf5eb37a8a496663d7c27943e226e56a451d4b3d6"s;
EXPECT_EQ((std::string{rublon::signData("{data}", signKey).data(), 64}), sign);
}
TEST(sign, AZeroTerminatedStringIsReturned) {
const auto signKey = "39e8d771d83a2ed3cc728811911c25";
const auto sign = "a6385672ee92eb37fe661d0cf5eb37a8a496663d7c27943e226e56a451d4b3d6"s;
EXPECT_TRUE(*rublon::signData("{data}", signKey).rbegin() == '\0');
}

View File

@ -28,63 +28,86 @@ x-ratelimit-remaining: 299
{"status":"OK","result":{"tid":"2039132542F6465691BF8C41D7CC38C5","status":"pending","companyName":"rublon","applicationName":"Bartoszek_SSH","methods":["email","totp","qrcode","push"]}})http";
static inline std::string_view ltrim(std::string_view s) {
while(std::isspace(*s.begin()))
s.remove_prefix(1);
return s;
}
class LoggingAllocator : public std::pmr::memory_resource {
public:
int count{};
int bytes{};
static inline std::string_view rtrim(std::string_view s) {
while(std::isspace(*s.rbegin()))
s.remove_suffix(1);
return s;
}
static inline std::string_view trim(std::string_view s) {
return ltrim(rtrim(s));
}
std::pmr::map< std::pmr::string, std::pmr::string > headers(std::pmr::memory_resource * mr, std::string_view data) {
std::pmr::map< std::pmr::string, std::pmr::string > headers{mr};
std::pmr::string tmp{mr};
tmp.reserve(256);
std::istringstream resp{};
resp.rdbuf()->pubsetbuf(const_cast< char * >(data.data()), data.size());
while(std::getline(resp, tmp) && !(trim(tmp).empty())) {
auto line = std::string_view(tmp);
auto index = tmp.find(':', 0);
if(index != std::string::npos) {
headers.insert({std::pmr::string{trim(line.substr(0, index)), mr}, std::pmr::string{trim(line.substr(index + 1)), mr}});
}
void clear() {
count = 0;
bytes = 0;
}
// memory_resource interface
private:
void * do_allocate(std::size_t __bytes, std::size_t __alignment) override {
bytes += __bytes;
count++;
return std::pmr::get_default_resource()->allocate(__bytes, __alignment);
}
return headers;
}
inline std::map< std::string, std::string > headers(std::string_view data) {
std::map< std::string, std::string > headers{};
std::string tmp{};
std::istringstream resp{};
resp.rdbuf()->pubsetbuf(const_cast< char * >(data.data()), data.size());
while(std::getline(resp, tmp) && !(trim(tmp).empty())) {
auto line = std::string_view(tmp);
auto index = tmp.find(':', 0);
if(index != std::string::npos) {
headers.insert({std::string{trim(line.substr(0, index))}, std::string{trim(line.substr(index + 1))}});
}
void do_deallocate(void * __p, std::size_t __bytes, std::size_t __alignment) override {
bytes -= __bytes;
count--;
std::pmr::get_default_resource()->deallocate(__p, __bytes, __alignment);
}
return headers;
}
bool do_is_equal(const memory_resource & __other) const noexcept override {
return std::pmr::get_default_resource()->is_equal(__other);
}
};
TEST(Utils, responseParser) {
auto sut = headers(response);
memory::MonotonicStackResource< 2048 > resource;
std::map<std::string, std::string> sut{};
rublon::details::headers(response, sut);
EXPECT_THAT(sut,
AllOf(Contains(Pair("date", "Thu, 22 Jun 2023 13:24:58 GMT")),
Contains(Pair("x-rublon-signature", "1a01558bedaa2dd92ff659fb8ee3c65a89163d63e312fcb9b6f60463cce864d7")),
Contains(Pair("x-ratelimit-remaining", "299"))));
}
class Memory : public testing::Test {
public:
// Test interface
protected:
void SetUp() override {
memory::set_default_resource(&alloc);
}
void TearDown() override {
memory::set_default_resource(std::pmr::get_default_resource());
}
LoggingAllocator alloc;
};
class X {};
TEST_F(Memory, setSetsProperResource) {
LoggingAllocator s;
memory::set_default_resource(&s);
EXPECT_EQ(&s, memory::default_resource());
}
TEST_F(Memory, stackResourceDoesNotUseHeapAtAll) {
memory::MonotonicStackResource< 1024 > mr;
std::pmr::string{"to large to fit in small buffer optimization", &mr};
EXPECT_EQ(alloc.count, 0);
}
TEST_F(Memory, heapResourceWillAllocate1kWhenAsk) {
memory::set_default_resource(&alloc);
memory::StrictMonotonic_1k_HeapResource mr;
std::pmr::string{"to large to fit in small buffer optimization", &mr};
EXPECT_EQ(alloc.count, 1);
EXPECT_EQ(alloc.bytes, 1024);
}
TEST_F(Memory, heapResourceWillAllocate8kWhenAsk) {
memory::StrictMonotonic_8k_HeapResource mr;
std::pmr::string{"to large to fit in small buffer optimization", &mr};
EXPECT_EQ(alloc.count, 1);
EXPECT_EQ(alloc.bytes, 8 * 1024);
}