ZMQ Hardening (#96)

This commit is contained in:
Lee *!* Clagett
2024-03-16 21:40:48 -04:00
committed by Lee *!* Clagett
parent ffdd8da2a9
commit f66943dce1
30 changed files with 564 additions and 295 deletions

View File

@@ -76,7 +76,7 @@ namespace wire
template<typename R, typename T, std::size_t N>
inline void read_bytes(R& source, std::array<T, N>& dest)
{
std::size_t count = source.start_array();
std::size_t count = source.start_array(0);
const bool json = (count == 0);
if (!json && count != dest.size())
WIRE_DLOG_THROW(wire::error::schema::array, "Expected array of size " << dest.size());

View File

@@ -42,6 +42,10 @@ namespace wire
return "No schema errors";
case schema::array:
return "Schema expected array";
case schema::array_max_element:
return "Schema expected array size to be smaller";
case schema::array_min_size:
return "Schema expected minimum wire size per array element to be larger";
case schema::binary:
return "Schema expected binary value of variable size";
case schema::boolean:

View File

@@ -54,6 +54,8 @@ namespace wire
{
none = 0, //!< Must be zero for `expect<..>`
array, //!< Expected an array value
array_max_element,//!< Exceeded max array count
array_min_size, //!< Below min element wire size
binary, //!< Expected a binary value of variable length
boolean, //!< Expected a boolean value
enumeration, //!< Expected a value from a specific set

View File

@@ -48,13 +48,13 @@ namespace
};
//! \throw std::system_error by converting `code` into a std::error_code
[[noreturn]] void throw_json_error(const epee::span<char> source, const rapidjson::Reader& reader, const wire::error::schema expected)
[[noreturn]] void throw_json_error(const epee::span<const std::uint8_t> source, const rapidjson::Reader& reader, const wire::error::schema expected)
{
const std::size_t offset = std::min(source.size(), reader.GetErrorOffset());
const std::size_t start = offset;//std::max(snippet_size / 2, offset) - (snippet_size / 2);
const std::size_t end = start + std::min(snippet_size, source.size() - start);
const boost::string_ref text{source.data() + start, end - start};
const boost::string_ref text{reinterpret_cast<const char*>(source.data()) + start, end - start};
const rapidjson::ParseErrorCode parse_error = reader.GetParseErrorCode();
switch (parse_error)
{
@@ -178,17 +178,19 @@ namespace wire
//! Parse the next JSON value from the unread input, reporting it through `handler`.
//! On parse failure `throw_json_error` is invoked with `handler.expected_`; on
//! success the bytes consumed by the parse are dropped from the unread span.
void json_reader::read_next_value(rapidjson_sax& handler)
{
rapidjson::InsituStringStream stream{current_.data()};
if (!reader_.Parse<rapidjson::kParseStopWhenDoneFlag>(stream, handler))
throw_json_error(current_, reader_, handler.expected_);
current_.remove_prefix(stream.Tell());
// stream is constructed from pointer + size of `remaining_`, so it is length-bounded
rapidjson::MemoryStream stream{reinterpret_cast<const char*>(remaining_.data()), remaining_.size()};
rapidjson::EncodedInputStream<rapidjson::UTF8<>, rapidjson::MemoryStream> istream{stream};
if (!reader_.Parse<rapidjson::kParseStopWhenDoneFlag>(istream, handler))
throw_json_error(remaining_, reader_, handler.expected_);
// `Tell()` is used as the number of bytes consumed by the parse
remaining_.remove_prefix(istream.Tell());
}
//! Skip whitespace at the front of the unread input and drop it from the unread span.
//! \return The character now at the front of the input; it is not removed here -
//!   callers inspect it and then remove it themselves.
char json_reader::get_next_token()
{
rapidjson::InsituStringStream stream{current_.data()};
rapidjson::SkipWhitespace(stream);
current_.remove_prefix(stream.Tell());
rapidjson::MemoryStream stream{reinterpret_cast<const char*>(remaining_.data()), remaining_.size()};
rapidjson::EncodedInputStream<rapidjson::UTF8<>, rapidjson::MemoryStream> istream{stream};
rapidjson::SkipWhitespace(istream);
remaining_.remove_prefix(istream.Tell());
// NOTE(review): peeks `stream`, not `istream`; presumably the encoded wrapper
// advances the underlying MemoryStream in place so both agree - confirm
return stream.Peek();
}
@@ -196,15 +198,15 @@ namespace wire
{
if (get_next_token() != '"')
WIRE_DLOG_THROW_(error::schema::string);
current_.remove_prefix(1);
remaining_.remove_prefix(1);
void const* const end = std::memchr(current_.data(), '"', current_.size());
void const* const end = std::memchr(remaining_.data(), '"', remaining_.size());
if (!end)
WIRE_DLOG_THROW_(error::rapidjson_e(rapidjson::kParseErrorStringMissQuotationMark));
char const* const begin = current_.data();
const std::size_t length = current_.remove_prefix(static_cast<const char*>(end) - current_.data() + 1);
return {begin, length - 1};
std::uint8_t const* const begin = remaining_.data();
const std::size_t length = remaining_.remove_prefix(static_cast<const std::uint8_t*>(end) - remaining_.data() + 1);
return {reinterpret_cast<const char*>(begin), length - 1};
}
void json_reader::skip_value()
@@ -214,11 +216,12 @@ namespace wire
}
//! Take ownership of `source` and prepare to parse it.
//! The base `reader` is first given a null span; `remaining_` is then pointed at
//! the owned `source_` inside the constructor body, i.e. after `source_` has
//! been move-constructed.
json_reader::json_reader(std::string&& source)
: reader(),
: reader(nullptr),
source_(std::move(source)),
current_(std::addressof(source_[0]), source_.size()),
reader_()
{}
{
// view over the bytes owned by `source_`
remaining_ = {reinterpret_cast<const std::uint8_t*>(source_.data()), source_.size()};
}
void json_reader::check_complete() const
{
@@ -271,13 +274,13 @@ namespace wire
{
if (get_next_token() != '"')
WIRE_DLOG_THROW_(error::schema::string);
current_.remove_prefix(1);
remaining_.remove_prefix(1);
const std::uintmax_t out = unsigned_integer();
if (get_next_token() != '"')
WIRE_DLOG_THROW_(error::rapidjson_e(rapidjson::kParseErrorStringMissQuotationMark));
current_.remove_prefix(1);
remaining_.remove_prefix(1);
return out;
}
@@ -316,11 +319,11 @@ namespace wire
WIRE_DLOG_THROW(error::schema::fixed_binary, "of size" << dest.size() * 2 << " but got " << value.size());
}
//! Consume the opening `[` of an array and enter one nesting level.
//! The size argument is unnamed and not used by this implementation.
//! \throw wire::exception with `error::schema::array` if the next token is not `[`.
//! \return Always 0.
std::size_t json_reader::start_array()
std::size_t json_reader::start_array(std::size_t)
{
if (get_next_token() != '[')
WIRE_DLOG_THROW_(error::schema::array);
current_.remove_prefix(1);
// drop the `[` that `get_next_token()` left at the front
remaining_.remove_prefix(1);
increment_depth();
return 0;
}
@@ -332,7 +335,7 @@ namespace wire
WIRE_DLOG_THROW_(error::rapidjson_e(rapidjson::kParseErrorArrayMissCommaOrSquareBracket));
if (next == ']')
{
current_.remove_prefix(1);
remaining_.remove_prefix(1);
return true;
}
@@ -340,7 +343,7 @@ namespace wire
{
if (next != ',')
WIRE_DLOG_THROW_(error::rapidjson_e(rapidjson::kParseErrorArrayMissCommaOrSquareBracket));
current_.remove_prefix(1);
remaining_.remove_prefix(1);
}
return false;
}
@@ -349,7 +352,7 @@ namespace wire
{
if (get_next_token() != '{')
WIRE_DLOG_THROW_(error::schema::object);
current_.remove_prefix(1);
remaining_.remove_prefix(1);
increment_depth();
return 0;
}
@@ -377,7 +380,7 @@ namespace wire
WIRE_DLOG_THROW_(error::rapidjson_e(rapidjson::kParseErrorObjectMissCommaOrCurlyBracket));
if (next == '}')
{
current_.remove_prefix(1);
remaining_.remove_prefix(1);
return false;
}
@@ -386,7 +389,7 @@ namespace wire
{
if (next != ',')
WIRE_DLOG_THROW_(error::rapidjson_e(rapidjson::kParseErrorObjectMissCommaOrCurlyBracket));
current_.remove_prefix(1);
remaining_.remove_prefix(1);
}
++state;
@@ -395,7 +398,7 @@ namespace wire
index = process_key(json_key.value.string);
if (get_next_token() != ':')
WIRE_DLOG_THROW_(error::rapidjson_e(rapidjson::kParseErrorObjectMissColon));
current_.remove_prefix(1);
remaining_.remove_prefix(1);
// parse value
if (index != map.size())

View File

@@ -48,7 +48,6 @@ namespace wire
struct rapidjson_sax;
std::string source_;
epee::span<char> current_;
rapidjson::Reader reader_;
void read_next_value(rapidjson_sax& handler);
@@ -90,7 +89,7 @@ namespace wire
//! \throw wire::exception if next token not `[`.
std::size_t start_array() override final;
std::size_t start_array(std::size_t) override final;
//! Skips whitespace to next token. \return True if next token is eof or ']'.
bool is_array_end(std::size_t count) override final;

View File

@@ -43,8 +43,12 @@ namespace error
return "Unable to encode integer in msgpack";
case msgpack::invalid:
return "Invalid msgpack encoding";
case msgpack::max_tree_size:
return "Exceeded tag tracking amount";
case msgpack::not_enough_bytes:
return "Expected more bytes in the msgpack stream";
case msgpack::underflow_tree:
return "Expected more tags";
}
return "Unknown msgpack error";

View File

@@ -40,7 +40,9 @@ namespace error
incomplete,
integer_encoding,
invalid,
not_enough_bytes
max_tree_size,
not_enough_bytes,
underflow_tree
};
//! \return Static string describing error `value`.

View File

@@ -77,7 +77,7 @@ namespace
//! \return Integer `T` encoded as big endian in `source`.
template<typename T>
T read_endian(epee::byte_slice& source)
T read_endian(epee::span<const std::uint8_t>& source)
{
static_assert(std::is_integral<T>::value, "must be integral type");
static constexpr const std::size_t bits = 8 * sizeof(T);
@@ -95,12 +95,12 @@ namespace
//! \return Integer `T` encoded as big endian in `source`.
// Overload taking a `wire::msgpack::type<T, U>` object so `T` can be deduced
// from the argument; the value itself is unused and the call forwards to
// `read_endian<T>(source)`.
template<typename T, wire::msgpack::tag U>
T read_endian(epee::byte_slice& source, const wire::msgpack::type<T, U>)
T read_endian(epee::span<const std::uint8_t>& source, const wire::msgpack::type<T, U>)
{ return read_endian<T>(source); }
//! \return Integer `T` whose encoding is specified by tag `next`
template<typename T>
T read_integer(epee::byte_slice& source, const wire::msgpack::tag next)
T read_integer(epee::span<const std::uint8_t>& source, const wire::msgpack::tag next)
{
try
{
@@ -135,20 +135,21 @@ namespace
WIRE_DLOG_THROW_(wire::error::schema::integer);
}
//! Remove the first `bytes` bytes from `source` and return them.
//! \throw wire::exception with `error::msgpack::not_enough_bytes` if `source`
//!   holds fewer than `bytes` bytes.
epee::byte_slice read_raw(epee::byte_slice& source, const std::size_t bytes)
epee::span<const std::uint8_t> read_raw(epee::span<const std::uint8_t>& source, const std::size_t bytes)
{
if (source.size() < bytes)
WIRE_DLOG_THROW_(wire::error::msgpack::not_enough_bytes);
return source.take_slice(bytes);
const std::size_t actual = source.remove_prefix(bytes);
// `source` now starts after the removed prefix; step back `actual` bytes to view it
return {source.data() - actual, actual};
}
//! Read a length prefix of integer type `T` from `source` via `read_endian<T>`,
//! then remove and return that many bytes using the two-argument `read_raw`.
template<typename T>
epee::byte_slice read_raw(epee::byte_slice& source)
epee::span<const std::uint8_t> read_raw(epee::span<const std::uint8_t>& source)
{
return read_raw(source, wire::integer::cast_unsigned<std::size_t>(read_endian<T>(source)));
}
epee::byte_slice read_string(epee::byte_slice& source, const wire::msgpack::tag next)
epee::span<const std::uint8_t> read_string(epee::span<const std::uint8_t>& source, const wire::msgpack::tag next)
{
switch (next)
{
@@ -170,7 +171,7 @@ namespace
}
//! \return Binary blob encoded message
epee::byte_slice read_binary(epee::byte_slice& source, const wire::msgpack::tag next)
epee::span<const std::uint8_t> read_binary(epee::span<const std::uint8_t>& source, const wire::msgpack::tag next)
{
switch (next)
{
@@ -189,21 +190,21 @@ namespace
namespace wire
{
//! Out-of-line throw helper; `throw_wire_exception` raises
//! `error::msgpack::underflow_tree` through `WIRE_DLOG_THROW_`.
void msgpack_reader::throw_logic_error()
void msgpack_reader::throw_wire_exception()
{
throw std::logic_error{"Bug in msgpack_reader usage"};
WIRE_DLOG_THROW_(error::msgpack::underflow_tree);
}
void msgpack_reader::skip_value()
{
assert(remaining_);
if (limits<std::size_t>::max() == remaining_)
throw std::runtime_error{"msgpack_reader exceeded tree tracking"};
assert(tags_remaining_);
if (limits<std::size_t>::max() == tags_remaining_)
WIRE_DLOG_THROW_(error::msgpack::max_tree_size);
const std::size_t initial = remaining_;
const std::size_t initial = tags_remaining_;
do
{
const std::size_t size = source_.size();
const std::size_t size = remaining_.size();
const msgpack::tag next = peek_tag();
switch (next)
{
@@ -213,59 +214,59 @@ namespace wire
case msgpack::tag::unused:
case msgpack::tag::False:
case msgpack::tag::True:
source_.remove_prefix(1);
remaining_.remove_prefix(1);
break;
case msgpack::tag::binary8:
case msgpack::tag::binary16:
case msgpack::tag::binary32:
source_.remove_prefix(1);
read_binary(source_, next);
remaining_.remove_prefix(1);
read_binary(remaining_, next);
break;
case msgpack::tag::extension8:
source_.remove_prefix(1);
read_raw<std::uint8_t>(source_);
source_.remove_prefix(1);
remaining_.remove_prefix(1);
read_raw<std::uint8_t>(remaining_);
remaining_.remove_prefix(1);
break;
case msgpack::tag::extension16:
source_.remove_prefix(1);
read_raw<std::uint16_t>(source_);
source_.remove_prefix(1);
remaining_.remove_prefix(1);
read_raw<std::uint16_t>(remaining_);
remaining_.remove_prefix(1);
break;
case msgpack::tag::extension32:
source_.remove_prefix(1);
read_raw<std::uint32_t>(source_);
source_.remove_prefix(1);
remaining_.remove_prefix(1);
read_raw<std::uint32_t>(remaining_);
remaining_.remove_prefix(1);
break;
case msgpack::tag::int8:
case msgpack::tag::uint8:
source_.remove_prefix(2);
remaining_.remove_prefix(2);
break;
case msgpack::tag::int16:
case msgpack::tag::uint16:
case msgpack::tag::fixed_extension1:
source_.remove_prefix(3);
remaining_.remove_prefix(3);
break;
case msgpack::tag::int32:
case msgpack::tag::uint32:
case msgpack::tag::float32:
source_.remove_prefix(5);
remaining_.remove_prefix(5);
break;
case msgpack::tag::int64:
case msgpack::tag::uint64:
case msgpack::tag::float64:
source_.remove_prefix(9);
remaining_.remove_prefix(9);
break;
case msgpack::tag::fixed_extension2:
source_.remove_prefix(4);
remaining_.remove_prefix(4);
break;
case msgpack::tag::fixed_extension4:
source_.remove_prefix(6);
remaining_.remove_prefix(6);
break;
case msgpack::tag::fixed_extension8:
source_.remove_prefix(10);
remaining_.remove_prefix(10);
break;
case msgpack::tag::fixed_extension16:
source_.remove_prefix(18);
remaining_.remove_prefix(18);
break;
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wswitch"
@@ -273,8 +274,8 @@ namespace wire
case msgpack::tag::string8:
case msgpack::tag::string16:
case msgpack::tag::string32:
source_.remove_prefix(1);
read_string(source_, next);
remaining_.remove_prefix(1);
read_string(remaining_, next);
break;
case msgpack::tag(0x90): case msgpack::tag(0x91): case msgpack::tag(0x92):
case msgpack::tag(0x93): case msgpack::tag(0x94): case msgpack::tag(0x95):
@@ -284,7 +285,7 @@ namespace wire
case msgpack::tag(0x9f):
case msgpack::tag::array16:
case msgpack::tag::array32:
start_array();
start_array(0);
break;
case msgpack::tag(0x80): case msgpack::tag(0x81): case msgpack::tag(0x82):
case msgpack::tag(0x83): case msgpack::tag(0x84): case msgpack::tag(0x85):
@@ -299,27 +300,27 @@ namespace wire
#pragma GCC diagnostic pop
};
if (size == source_.size())
if (size == remaining_.size())
{
if (!msgpack::ftag_unsigned::matches(next) && !msgpack::ftag_signed::matches(next))
WIRE_DLOG_THROW_(error::msgpack::invalid);
source_.remove_prefix(1);
remaining_.remove_prefix(1);
}
update_remaining();
} while (initial <= remaining_);
update_tags_remaining();
} while (initial <= tags_remaining_);
}
//! \return The first unread byte interpreted as a `msgpack::tag`, without consuming it.
//! \throw wire::exception with `error::msgpack::not_enough_bytes` if no bytes are left.
msgpack::tag msgpack_reader::peek_tag()
{
if (source_.empty())
if (remaining_.empty())
WIRE_DLOG_THROW_(error::msgpack::not_enough_bytes);
return msgpack::tag(*source_.data());
return msgpack::tag(*remaining_.data());
}
//! \return The next tag (as given by `peek_tag()`), after removing its single byte from the unread input.
msgpack::tag msgpack_reader::get_tag()
{
// `peek_tag()` throws when no bytes are left, so the removal below is not reached on empty input
const msgpack::tag next = peek_tag();
source_.remove_prefix(1);
remaining_.remove_prefix(1);
return next;
}
@@ -327,12 +328,12 @@ namespace wire
{
if (msgpack::ftag_signed::matches(next))
return *reinterpret_cast<const std::int8_t*>(std::addressof(next)); // special case
return read_integer<std::intmax_t>(source_, next);
return read_integer<std::intmax_t>(remaining_, next);
}
//! \return Unsigned integer read from the unread input, with the encoding selected by tag `next`
//!   (delegates to `read_integer<std::uintmax_t>`).
std::uintmax_t msgpack_reader::do_unsigned_integer(const msgpack::tag next)
{
return read_integer<std::uintmax_t>(source_, next);
return read_integer<std::uintmax_t>(remaining_, next);
}
template<typename T, typename U>
@@ -347,7 +348,7 @@ namespace wire
{
if (type.Tag() == next)
{
out = integer::cast_unsigned<std::size_t>(read_endian(source_, type));
out = integer::cast_unsigned<std::size_t>(read_endian(remaining_, type));
return true;
}
return false;
@@ -361,13 +362,13 @@ namespace wire
//! \throw wire::exception with `error::msgpack::incomplete` if the expected-tag counter is still non-zero.
void msgpack_reader::check_complete() const
{
if (remaining_)
if (tags_remaining_)
WIRE_DLOG_THROW_(error::msgpack::incomplete);
}
bool msgpack_reader::boolean()
{
update_remaining();
update_tags_remaining();
switch (get_tag())
{
case msgpack::tag::True:
@@ -382,14 +383,14 @@ namespace wire
double msgpack_reader::real()
{
update_remaining();
update_tags_remaining();
const auto read_float = [this](auto value)
{
if (source_.size() < sizeof(value))
if (remaining_.size() < sizeof(value))
WIRE_DLOG_THROW_(error::msgpack::not_enough_bytes);
std::memcpy(std::addressof(value), source_.data(), sizeof(value));
source_.remove_prefix(sizeof(value));
std::memcpy(std::addressof(value), remaining_.data(), sizeof(value));
remaining_.remove_prefix(sizeof(value));
return value;
};
@@ -407,34 +408,38 @@ namespace wire
//! Consume one expected tag and read the next value as a string.
//! \return A `std::string` copy of the bytes returned by `read_string`.
std::string msgpack_reader::string()
{
update_remaining();
const epee::byte_slice bytes = read_string(source_, get_tag());
update_tags_remaining();
const epee::span<const std::uint8_t> bytes = read_string(remaining_, get_tag());
return std::string{reinterpret_cast<const char*>(bytes.data()), bytes.size()};
}
//! Consume one expected tag and read the next value as a binary blob.
//! \return A vector holding a copy of the bytes returned by `read_binary`.
std::vector<std::uint8_t> msgpack_reader::binary()
{
update_remaining();
const epee::byte_slice bytes = read_binary(source_, get_tag());
update_tags_remaining();
const epee::span<const std::uint8_t> bytes = read_binary(remaining_, get_tag());
return std::vector<std::uint8_t>{bytes.begin(), bytes.end()};
}
//! Consume one expected tag and read the next value as a binary blob of exactly `dest.size()` bytes.
//! \throw wire::exception with `error::schema::fixed_binary` if the encoded blob size differs from `dest.size()`.
void msgpack_reader::binary(epee::span<std::uint8_t> dest)
{
update_remaining();
const epee::byte_slice bytes = read_binary(source_, get_tag());
update_tags_remaining();
const epee::span<const std::uint8_t> bytes = read_binary(remaining_, get_tag());
if (dest.size() != bytes.size())
WIRE_DLOG_THROW(error::schema::fixed_binary, "of size " << dest.size() << " but got " << bytes.size());
// sizes were verified equal above, so copying `dest.size()` bytes stays inside `bytes`
std::memcpy(dest.data(), bytes.data(), dest.size());
}
//! Read an array header and add its element count to the expected-tag counter.
//! \param min_element_size Minimum wire bytes per element; zero skips the
//!   available-bytes check.
//! \throw wire::exception with `error::msgpack::max_tree_size` if adding the
//!   count would overflow the expected-tag counter.
//! \throw wire::exception with `error::schema::array` if the declared count
//!   times `min_element_size` cannot fit in the unread bytes.
//! \return Element count produced by `read_count`.
std::size_t msgpack_reader::start_array()
std::size_t msgpack_reader::start_array(const std::size_t min_element_size)
{
// NOTE(review): `error::schema::array` is presumably what `read_count` reports on a non-array tag - confirm
const std::size_t upcoming =
read_count<msgpack::ftag_array, msgpack::array_types>(error::schema::array);
if (limits<std::size_t>::max() - remaining_ < upcoming)
throw std::runtime_error{"Exceeded max tree tracking for msgpack_reader"};
remaining_ += upcoming;
if (limits<std::size_t>::max() - tags_remaining_ < upcoming)
WIRE_DLOG_THROW_(error::msgpack::max_tree_size);
// division (not multiplication) keeps the comparison free of overflow
if (min_element_size && (remaining_.size() / min_element_size) < upcoming)
WIRE_DLOG_THROW(error::schema::array, upcoming << " array elements of at least " << min_element_size << " bytes each exceeds " << remaining_.size() << " remaining bytes");
tags_remaining_ += upcoming;
increment_depth();
return upcoming;
}
@@ -442,7 +447,7 @@ namespace wire
{
if (count)
return false;
update_remaining();
update_tags_remaining();
return true;
}
@@ -451,10 +456,11 @@ namespace wire
const std::size_t upcoming =
read_count<msgpack::ftag_object, msgpack::object_types>(error::schema::object);
if (limits<std::size_t>::max() / 2 < upcoming)
throw std::runtime_error{"Exceeded max object tracking for msgpack_reader"};
if (limits<std::size_t>::max() - remaining_ < upcoming * 2)
throw std::runtime_error{"Exceeded msgpack_reader:: tree tracking"};
remaining_ += upcoming * 2;
WIRE_DLOG_THROW_(error::msgpack::max_tree_size);
if (limits<std::size_t>::max() - tags_remaining_ < upcoming * 2)
WIRE_DLOG_THROW_(error::msgpack::max_tree_size);
tags_remaining_ += upcoming * 2;
increment_depth();
return upcoming;
}
@@ -463,14 +469,14 @@ namespace wire
index = map.size();
for ( ;state; --state)
{
update_remaining(); // for key
update_tags_remaining(); // for key
const msgpack::tag next = get_tag();
const bool single = msgpack::ftag_unsigned::matches(next);
if (single || matches<msgpack::unsigned_types>(next))
{
unsigned key = std::uint8_t(next);
if (!single)
key = read_integer<unsigned>(source_, next);
key = read_integer<unsigned>(remaining_, next);
for (const key_map& elem : map)
{
if (elem.id == key)
@@ -482,7 +488,7 @@ namespace wire
}
else if (msgpack::ftag_string::matches(next) || matches<msgpack::string_types>(next))
{
const epee::byte_slice key = read_string(source_, next);
const epee::span<const std::uint8_t> key = read_string(remaining_, next);
for (const key_map& elem : map)
{
const boost::string_ref elem_{elem.name};
@@ -503,7 +509,7 @@ namespace wire
}
skip_value();
} // until state == 0
update_remaining(); // for end of object
update_tags_remaining(); // for end of object
return false;
}
}

View File

@@ -45,30 +45,30 @@ namespace wire
class msgpack_reader : public reader
{
epee::byte_slice source_;
std::size_t remaining_; //!< Expected number of elements remaining
std::size_t tags_remaining_; //!< Expected number of elements remaining
//! \throw std::logic_error
[[noreturn]] void throw_logic_error();
//! Decrement remaining_ if not zero, \throw std::logic_error when `remaining_ == 0`.
void update_remaining()
//! \throw wire::exception with `error::msgpack::underflow_tree`
[[noreturn]] void throw_wire_exception();
//! Decrement tags_remaining_ if not zero, \throw wire::exception when `tags_remaining_ == 0`.
void update_tags_remaining()
{
if (remaining_)
--remaining_;
if (tags_remaining_)
--tags_remaining_;
else
throw_logic_error();
throw_wire_exception();
}
//! Skips next value. \throw wire::exception if invalid msgpack encoding.
void skip_value();
//! \return Next tag but leave `source_` untouched.
//! \return Next tag but leave `remaining_` untouched.
msgpack::tag peek_tag();
//! \return Next tag and remove first byte from `source_`.
//! \return Next tag and remove first byte from `remaining_`.
msgpack::tag get_tag();
//! \return Integer from `soure_` where positive fixed tag has been checked.
//! \return Integer from `remaining_` where positive fixed tag has been checked.
std::intmax_t do_integer(msgpack::tag);
//! \return Integer from `source_` where fixed tag has been checked.
//! \return Integer from `remaining_` where fixed tag has been checked.
std::uintmax_t do_unsigned_integer(msgpack::tag);
//! \return Number of items determined by `T` fixed tag and `U` tuple of tags.
@@ -77,8 +77,10 @@ namespace wire
public:
explicit msgpack_reader(epee::byte_slice&& source)
: reader(), source_(std::move(source)), remaining_(1)
{}
: reader(nullptr), source_(std::move(source)), tags_remaining_(1)
{
remaining_ = {source_.data(), source_.size()};
}
//! \throw wire::exception if msgpack parsing is incomplete.
void check_complete() const override final;
@@ -89,7 +91,7 @@ namespace wire
//! \throw wire::exception if next token not an integer.
std::intmax_t integer() override final
{
update_remaining();
update_tags_remaining();
const msgpack::tag next = get_tag();
if (std::uint8_t(next) <= msgpack::ftag_unsigned::max())
return std::uint8_t(next);
@@ -99,7 +101,7 @@ namespace wire
//! \throw wire::exception if next token not an unsigned integer.
std::uintmax_t unsigned_integer() override final
{
update_remaining();
update_tags_remaining();
const msgpack::tag next = get_tag();
if (std::uint8_t(next) <= msgpack::ftag_unsigned::max())
return std::uint8_t(next);
@@ -120,7 +122,7 @@ namespace wire
//! \throw wire::exception if next tag is not an array, or too few bytes remain for the element count.
std::size_t start_array() override final;
std::size_t start_array(std::size_t min_element_size) override final;
//! \return true when `count == 0`.
bool is_array_end(const std::size_t count) override final;

View File

@@ -35,6 +35,13 @@ void wire::reader::increment_depth()
WIRE_DLOG_THROW_(error::schema::maximum_depth);
}
//! Leave one level of object/array nesting.
//! \throw std::logic_error if the depth counter is already zero.
void wire::reader::decrement_depth()
{
  if (depth_ == 0)
    throw std::logic_error{"reader::decrement_depth() already at zero"};
  depth_ -= 1;
}
[[noreturn]] void wire::integer::throw_exception(std::intmax_t source, std::intmax_t min, std::intmax_t max)
{
static_assert(

View File

@@ -44,19 +44,27 @@ namespace wire
{
//! Interface for converting "wire" (byte) formats to C/C++ objects without a DOM.
class reader
{
{
std::size_t depth_; //!< Tracks number of recursive objects and arrays
protected:
//! \throw wire::exception if max depth is reached
void increment_depth();
void decrement_depth() noexcept { --depth_; }
epee::span<const std::uint8_t> remaining_; //!< Derived class tracks unprocessed bytes here
reader(const epee::span<const std::uint8_t> remaining) noexcept
: depth_(0), remaining_(remaining)
{}
reader(const reader&) = default;
reader(reader&&) = default;
reader& operator=(const reader&) = default;
reader& operator=(reader&&) = default;
//! \throw wire::exception if max depth is reached
void increment_depth();
//! \throw std::logic_error if already `depth() == 0`.
void decrement_depth();
public:
struct key_map
{
@@ -70,16 +78,15 @@ namespace wire
//! \return Assume delimited arrays in generic interface (some optimizations disabled)
static constexpr std::true_type delimited_arrays() noexcept { return {}; }
reader() noexcept
: depth_(0)
{}
virtual ~reader() noexcept
{}
//! \return Number of recursive objects and arrays
std::size_t depth() const noexcept { return depth_; }
//! \return Unprocessed bytes
epee::span<const std::uint8_t> remaining() const noexcept { return remaining_; }
//! \throw wire::exception if parsing is incomplete.
virtual void check_complete() const = 0;
@@ -104,14 +111,20 @@ namespace wire
//! \throw wire::exception if next value cannot be read as binary into `dest`.
virtual void binary(epee::span<std::uint8_t> dest) = 0;
//! \throw wire::exception if next value not array
virtual std::size_t start_array() = 0;
/* \param min_element_size of each array element in any format - if known.
Derived types with explicit element count should verify available
space, and throw a `wire::exception` on issues.
\throw wire::exception if next value not array
\throw wire::exception if not enough bytes for all array elements
(with epee/msgpack which has specified number of elements).
\return Number of values to read before calling `is_array_end()` */
virtual std::size_t start_array(std::size_t min_element_size) = 0;
//! \return True if there are no more elements to read (end of array reached).
virtual bool is_array_end(std::size_t count) = 0;
//! Leaves one nesting level via `decrement_depth()`; no delimiter is checked here.
void end_array() noexcept { decrement_depth(); }
void end_array() { decrement_depth(); }
//! \throw wire::exception if not object begin. \return State to be given to `key(...)` function.
@@ -134,7 +147,7 @@ namespace wire
*/
virtual bool key(epee::span<const key_map> map, std::size_t& state, std::size_t& index) = 0;
void end_object() noexcept { decrement_depth(); }
void end_object() { decrement_depth(); }
};
template<typename R>
@@ -247,28 +260,84 @@ namespace wire_read
return {};
}
// Trap objects that do not have standard insertion functions
// Fallback overload accepting any arguments by const reference. The assertion
// depends on `R`, so it only fires if this overload is actually instantiated.
// NOTE(review): presumably selected only when neither the `emplace_hint` nor
// the `emplace_back` overload is viable for the container - confirm.
template<typename R, typename... T>
void array_insert(const R&, const T&...) noexcept
{
static_assert(std::is_same<R, void>::value, "type T does not have a valid insertion function");
}
// Insert to sorted containers
//! Read one element from `source` into a temporary and insert it into `dest`
//! with `emplace_hint`, hinting at `dest.end()`. Only participates in overload
//! resolution when that `emplace_hint` call is well-formed. \return Always true.
template<typename R, typename T, typename V = typename T::value_type>
inline auto array_insert(R& source, T& dest) -> decltype(dest.emplace_hint(dest.end(), std::declval<V>()), bool(true))
{
  V element{};
  wire_read::bytes(source, element);

  const auto hint = dest.end();
  dest.emplace_hint(hint, std::move(element));
  return true;
}
// Insert into unsorted containers
//! Append a value-initialized element to `dest` and read into it in place
//! (avoids building a temporary and moving it). Only participates in overload
//! resolution when `dest.emplace_back()` and `dest.back()` are well-formed.
//! \return Always true.
template<typename R, typename T, typename V = typename T::value_type>
inline auto array_insert(R& source, T& dest) -> decltype(dest.emplace_back(), dest.back(), bool(true))
{
  dest.emplace_back();
  auto& slot = dest.back();
  wire_read::bytes(source, slot);
  return true;
}
// no compile-time checks for the array constraints
//! Read an array from `source` into `dest`, enforcing runtime limits.
//! \param min_element_size Minimum wire bytes each element must use; zero
//!   disables the up-front available-bytes check.
//! \param max_element_count Maximum number of elements accepted.
//! Failures are reported through `throw_exception` with
//! `array_max_element` or `array_min_size`.
template<typename R, typename T>
inline void array(R& source, T& dest)
inline void array_unchecked(R& source, T& dest, const std::size_t min_element_size, const std::size_t max_element_count)
{
using value_type = typename T::value_type;
static_assert(!std::is_same<value_type, char>::value, "read array of chars as binary");
static_assert(!std::is_same<value_type, char>::value, "read array of chars as string");
static_assert(!std::is_same<value_type, std::int8_t>::value, "read array of signed chars as binary");
static_assert(!std::is_same<value_type, std::uint8_t>::value, "read array of unsigned chars as binary");
std::size_t count = source.start_array();
std::size_t count = source.start_array(min_element_size);
// quick check for epee/msgpack formats
if (max_element_count < count)
throw_exception(wire::error::schema::array_max_element, "", nullptr);
// also checked by derived formats when count is known
if (min_element_size && (source.remaining().size() / min_element_size) < count)
throw_exception(wire::error::schema::array_min_size, "", nullptr);
dest.clear();
wire::reserve(dest, count);
// `more` stays true while a declared count is still outstanding; once it is
// false the loop relies on `is_array_end` alone
bool more = count;
const std::size_t start_bytes = source.remaining().size();
while (more || !source.is_array_end(count))
{
dest.emplace_back();
read_bytes(source, dest.back());
// check for json/cbor formats
if (source.delimited_arrays() && max_element_count <= dest.size())
throw_exception(wire::error::schema::array_max_element, "", nullptr);
wire_read::array_insert(source, dest);
--count;
more &= bool(count);
// running average: wire bytes consumed so far per element held in `dest`
// NOTE(review): `dest.size()` counts stored elements; for containers that
// reject duplicates this may differ from elements read - confirm intent
if (((start_bytes - source.remaining().size()) / dest.size()) < min_element_size)
throw_exception(wire::error::schema::array_min_size, "", nullptr);
}
return source.end_array();
source.end_array();
}
template<typename R, typename T, std::size_t M, std::size_t N = std::numeric_limits<std::size_t>::max()>
inline void array(R& source, T& dest, wire::min_element_size<M> min_element_size, wire::max_element_count<N> max_element_count = {})
{
using value_type = typename T::value_type;
static_assert(
min_element_size.template check<value_type>() || max_element_count.template check<value_type>(),
"array unpacking memory issues"
);
// each set of template args generates unique ASM, merge them down
array_unchecked(source, dest, min_element_size, max_element_count);
}
template<typename T, unsigned I>
@@ -413,7 +482,14 @@ namespace wire
//! Default array read for types where `is_array<T>` is true.
//! The per-element minimum wire size is taken from
//! `default_min_element_size<R, value_type>`; compilation fails when that
//! trait yields zero (no default constraint for the reader / element pair).
template<typename R, typename T>
inline std::enable_if_t<is_array<T>::value> read_bytes(R& source, T& dest)
{
wire_read::array(source, dest);
static constexpr const std::size_t wire_size =
default_min_element_size<R, typename T::value_type>::value;
static_assert(
wire_size != 0,
"no sane default array constraints for the reader / value_type pair"
);
// max element count is left at the default of `wire_read::array`
wire_read::array(source, dest, min_element_size<wire_size>{});
}
template<typename R, typename... T>

View File

@@ -84,6 +84,52 @@ namespace wire
: is_array<T> // all array types in old output engine were optional when empty
{};
//! A constraint for `wire_read::array` where a max of `N` elements can be read.
template<std::size_t N>
struct max_element_count
  : std::integral_constant<std::size_t, N>
{
  //! \return Byte budget used by `check()`. Kept low on purpose, since
  //!   `min_element_size` is the better constraint metric.
  static constexpr std::size_t max_bytes() noexcept
  {
    return 512 * 1024; // 512 KiB
  }

  //! \return True if `N` C++ objects of type `T` fit within `max_bytes()`.
  template<typename T>
  static constexpr bool check() noexcept
  {
    // equivalent to `N <= budget`, where budget is the whole number of `T` that fit
    return !((max_bytes() / sizeof(T)) < N);
  }
};
//! A constraint for `wire_read::array` where each element must use at least `N` bytes on the wire.
template<std::size_t N>
struct min_element_size
  : std::integral_constant<std::size_t, N>
{
  //! \return Largest allowed ratio of in-memory size to minimum wire size.
  static constexpr std::size_t max_ratio() noexcept
  {
    return 4;
  }

  //! \return True if `N` is non-zero and `sizeof(T) / N` does not exceed
  //!   `max_ratio()`; always false when `N` is zero.
  template<typename T>
  static constexpr bool check() noexcept
  {
    // short-circuit keeps the division from being evaluated when `N == 0`
    return N != 0 && (sizeof(T) / N) <= max_ratio();
  }
};
/*! Trait used in `wire/read.h` for default `min_element_size` behavior based
on an array of `T` objects and `R` reader type. This trait can be used
instead of the `wire::array(...)` (and associated macros) functionality, as
it sets a global value. The last argument is for `enable_if`. */
// Primary template: value 0, which the array `read_bytes` overload treats as
// "no default" and rejects at compile time.
template<typename R, typename T, typename = void>
struct default_min_element_size
: std::integral_constant<std::size_t, 0>
{};
//! If `T` is a blob, a safe default for all formats is the size of the blob
// Specialization chosen when `is_blob<T>::value` is true; the default minimum
// wire size is then `sizeof(T)` for every reader type `R`.
template<typename R, typename T>
struct default_min_element_size<R, T, std::enable_if_t<is_blob<T>::value>>
: std::integral_constant<std::size_t, sizeof(T)>
{};
// example usage : `wire::sum(std::size_t(wire::available(fields))...)`
inline constexpr int sum() noexcept
@@ -96,6 +142,9 @@ namespace wire
return head + sum(tail...);
}
//! `min_element_size` whose byte count is the sum of `sizeof` over every type in `T...`.
template<typename... T>
using min_element_sizeof = min_element_size<sum(sizeof(T)...)>;
//! If container has no `reserve(0)` function, this function is used
template<typename... T>
inline void reserve(const T&...) noexcept

View File

@@ -44,7 +44,25 @@ namespace wire
{
// see constraints directly above `array_` definition
static_assert(std::is_same<R, void>::value, "array_ must have a read constraint for memory purposes");
wire_read::array(source, wrapper.get_read_object());
}
template<typename R, typename T, std::size_t N>
inline void read_bytes(R& source, array_<T, max_element_count<N>>& wrapper)
{
using array_type = array_<T, max_element_count<N>>;
using value_type = typename array_type::value_type;
using constraint = typename array_type::constraint;
static_assert(constraint::template check<value_type>(), "max reserve bytes exceeded for element");
wire_read::array(source, wrapper.get_read_object(), min_element_size<0>{}, constraint{});
}
template<typename R, typename T, std::size_t N>
inline void read_bytes(R& source, array_<T, min_element_size<N>>& wrapper)
{
using array_type = array_<T, min_element_size<N>>;
using value_type = typename array_type::value_type;
using constraint = typename array_type::constraint;
static_assert(constraint::template check<value_type>(), "max compression ratio exceeded for element");
wire_read::array(source, wrapper.get_read_object(), constraint{});
}
template<typename W, typename T, typename C>