rublon-ssh/PAM/ssh/include/rublon/method/method_select.hpp
Bartosz Wieczorek 700845e17a refactor
2023-08-22 13:34:40 +02:00

174 lines
5.4 KiB
C++

#pragma once
#include <climits>
#include <cstdlib>
#include <variant>

#include <tl/expected.hpp>

#include <rublon/core_handler.hpp>
#include <rublon/pam.hpp>
#include <rublon/pam_action.hpp>
#include <rublon/method/OTP.hpp>
#include <rublon/method/SMS.hpp>
/// Trait extracting the return type `R` from a function-pointer type.
/// Only function-pointer types are supported; for any other type `return_type`
/// stays incomplete, so `return_type_t<T>` fails to compile.
template < class F >
struct return_type;

template < class R, class... A >
struct return_type< R (*)(A...) > {
    using type = R;
};

// Since C++17 `noexcept` is part of the function type, so a pointer to a
// noexcept function is a distinct type and needs its own specialisation.
template < class R, class... A >
struct return_type< R (*)(A...) noexcept > {
    using type = R;
};

/// Convenience alias: `return_type_t<int (*)(double)>` is `int`.
template < typename T >
using return_type_t = typename return_type< T >::type;
namespace std {
inline ::rublon::Value::ConstValueIterator begin(const rublon::Value & __ils) noexcept {
return __ils.Begin();
}
inline ::rublon::Value::ConstValueIterator end(const rublon::Value & __ils) noexcept {
return __ils.End();
}
[[nodiscard]] inline std::size_t size(const rublon::Value & __cont) {
return __cont.Size();
}
} // namespace std
namespace rublon {
/// Type-erased holder for one concrete authentication method
/// (currently method::OTP or method::SMS).
class MethodProxy {
  public:
    /// Stores `method`, which must be one of the variant's alternatives.
    template < typename Method_t >
    MethodProxy(Method_t method) : _impl{std::move(method)} {}

    /// Logs the name of the held method and forwards to its `fire`.
    /// @param coreHandler  handler used by the method to talk to the Rublon core
    /// @param pam          PAM information/context passed through to the method
    /// @return whatever the held method's `fire` returns, converted to
    ///         tl::expected<AuthenticationStatus, Error>
    template < typename Handler_t, typename PamInfo_t = LinuxPam >
    tl::expected< AuthenticationStatus, Error > fire(const CoreHandlerInterface< Handler_t > & coreHandler, const PamInfo_t & pam) const {
        const auto invokeSelected = [&coreHandler, &pam](const auto & selected) {
            rublon::log(LogLevel::Info, "Using '%s' method", selected.name);
            return selected.fire(coreHandler, pam);
        };
        return std::visit(invokeSelected, _impl);
    }

  private:
    std::variant< method::OTP, method::SMS > _impl;
};
/// Authentication step that confirms the user's chosen method with the Rublon
/// core (POST /api/transaction/methodSSH) and builds the matching method handler
/// from the core's reply.
class PostMethod : public rublon::AuthenticationStep< PostMethod > {
    using base_t = rublon::AuthenticationStep< PostMethod >;
    const char * uri = "/api/transaction/methodSSH";
    std::string _method;  // method id chosen by the user, e.g. "totp" or "sms"

    /// Builds the handler for `_method` from the core reply.
    /// Reads the transaction id from `result.tid`.
    /// @return a MethodProxy for "totp"/"sms"; Critical when the reply lacks a
    ///         string `result.tid`; MethodError::BadMethod for any other method.
    tl::expected< MethodProxy, Error > doCreateMethod(const Document & coreResponse) const {
        // Validate before dereferencing: a missing member or non-string tid would
        // otherwise reach GetString() and construct std::string from a null pointer.
        if(!coreResponse.IsObject() || !coreResponse.HasMember("result")) {
            return tl::unexpected{Critical{}};
        }
        const auto & rublonResponse = coreResponse["result"];
        if(!rublonResponse.IsObject() || !rublonResponse.HasMember("tid") || !rublonResponse["tid"].IsString()) {
            return tl::unexpected{Critical{}};
        }

        std::string tid = rublonResponse["tid"].GetString();
        if(_method == "totp") {
            return MethodProxy{method::OTP{this->_systemToken, std::move(tid)}};
        }
        if(_method == "sms") {
            return MethodProxy{method::SMS{this->_systemToken, std::move(tid)}};
        }
        return tl::unexpected{MethodError{MethodError::BadMethod}};
    }

  public:
    const char * name = "Confirm Method";

    PostMethod(std::string systemToken, std::string tid, std::string method)
        : base_t(std::move(systemToken), std::move(tid)), _method{std::move(method)} {}

    /// Sends the chosen method to the core and, on success, creates its handler.
    /// The request body carries the system token, tid, `method`, and the
    /// "GDPRAccepted"/"tosAccepted" fields set to the string "true".
    /// @return the created method handler, or the error from the request/creation.
    template < typename Handler_t, typename PamInfo_t = LinuxPam >
    tl::expected< MethodProxy, Error > handle(const CoreHandlerInterface< Handler_t > & coreHandler) const {
        auto createMethod = [&](const auto & coreResponse) { return doCreateMethod(coreResponse); };

        RapidJSONPMRStackAlloc< 1024 > alloc{};
        Document body{rapidjson::kObjectType, &alloc};
        this->addSystemToken(body, alloc);
        this->addTid(body, alloc);
        body.AddMember("method", Value{_method.c_str(), alloc}, alloc);
        // NOTE(review): these flags are sent as the string "true", not a JSON
        // boolean - confirm this is what the API expects.
        body.AddMember("GDPRAccepted", "true", alloc);
        body.AddMember("tosAccepted", "true", alloc);

        return coreHandler
            .request(uri, body)  //
            .and_then(createMethod);
    }
};
class MethodSelect {
std::string _systemToken;
std::string _tid;
std::vector< std::string > _methods;
public:
template < typename Array_t >
MethodSelect(std::string systemToken, std::string tid, const Array_t & methodsAvailableForUser)
: _systemToken{std::move(systemToken)}, _tid{std::move(tid)} {
_methods.reserve(std::size(methodsAvailableForUser));
std::transform(
std::begin(methodsAvailableForUser), std::end(methodsAvailableForUser), std::back_inserter(_methods), [](const auto & method) {
return method.GetString();
});
}
template < typename Pam_t >
tl::expected< PostMethod, Error > create(Pam_t & pam) const {
std::pmr::map< int, std::string > methods_id;
pam.print("%s", "");
int i{};
for(const auto & method : _methods) {
rublon::log(LogLevel::Debug, "method %s found at pos %d", method.c_str(), i);
if(method == "email") {
pam.print("%d: Email Link", i + 1);
methods_id[++i] = "email";
}
if(method == "qrcode") {
pam.print("%d: QR Code", i + 1);
methods_id[++i] = "qrcode";
}
if(method == "totp") {
pam.print("%d: Mobile TOTP", i + 1);
methods_id[++i] = "totp";
}
if(method == "push") {
pam.print("%d: Mobile Push", i + 1);
methods_id[++i] = "push";
}
if(method == "sms") {
pam.print("%d: SMS code", i + 1);
methods_id[++i] = "sms";
}
}
auto methodid = pam.scan(
[](char * userinput) {
rublon::log(LogLevel::Debug, "User input: %s", userinput);
return std::stoi(userinput);
},
"\nSelect method [1-%d]: ",
_methods.size());
pam.print(
"you selected: %s", methods_id.count(methodid.value_or(0)) ? methods_id.at(methodid.value_or(0)).c_str() : "unknown option");
/// TODO check if valid method
if(auto it = methods_id.find(methodid.value()); it != methods_id.end()) {
return PostMethod{_systemToken, _tid, it->second};
}
return tl::unexpected{Critical{}};
}
};
} // namespace rublon