Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add checks for invalid fields & index bounds #9

Merged
merged 4 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 30 additions & 21 deletions examples/ulog_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ int main(int argc, char** argv)
for (const auto& sub : subscription_names) {
std::cout << sub << "\n";
}
std::cout << "\n";

// Get a particular subscription
if (subscription_names.find("vehicle_status") != subscription_names.end()) {
Expand All @@ -62,30 +63,34 @@ int main(int argc, char** argv)
std::cout << "Field names: "
<< "\n";
for (const std::string& field : subscription->fieldNames()) {
std::cout << field << "\n";
std::cout << " " << field << "\n";
}

// Get particular field
auto nav_state_field = subscription->field("nav_state");
try {
auto nav_state_field = subscription->field("nav_state");

// Iterate over all samples
std::cout << "nav_state values: \n";
for (const auto& sample : *subscription) {
// always correctly extracts the type as defined in the message definition,
// gets cast to the value you put in int.
// This also works for arrays and strings.
auto nav_state = sample[nav_state_field].as<int>();
std::cout << nav_state << ", ";
}
std::cout << "\n";
// Iterate over all samples
std::cout << "nav_state values: \n ";
for (const auto& sample : *subscription) {
// always correctly extracts the type as defined in the message definition,
// gets cast to the value you put in int.
// This also works for arrays and strings.
auto nav_state = sample[nav_state_field].as<int>();
std::cout << nav_state << ", ";
}
std::cout << "\n";

// get a specific sample
auto sample_12 = subscription->at(12);
// get a specific sample
auto sample_12 = subscription->at(12);

// access values by name
auto timestamp = sample_12["timestamp"].as<uint64_t>();
// access values by name
auto timestamp = sample_12["timestamp"].as<uint64_t>();

std::cout << timestamp << "\n";
std::cout << "timestamp at sample 12: " << timestamp << "\n";
} catch (const ulog_cpp::AccessException& exception) {
std::cout << "AccessException: " << exception.what() << "\n";
}
} else {
std::cout << "No vehicle_status subscription found\n";
}
Expand All @@ -95,16 +100,20 @@ int main(int argc, char** argv)
const auto& message_format = data_container->messageFormats().at("esc_status");
std::cout << "Message format: " << message_format->name() << "\n";
for (const auto& field_name : message_format->fieldNames()) {
std::cout << field_name << "\n";
std::cout << " " << field_name << "\n";
}
} else {
std::cout << "No esc_status message format found\n";
}

if (subscription_names.find("esc_status") != subscription_names.end()) {
auto esc_status = data_container->subscription("esc_status");
for (const auto& sample : *esc_status) {
std::cout << "esc_power: " << sample["esc"][7]["esc_power"].as<int>() << "\n";
try {
auto esc_status = data_container->subscription("esc_status");
for (const auto& sample : *esc_status) {
std::cout << "esc_power: " << sample["esc"][7]["esc_power"].as<int>() << "\n";
}
} catch (const ulog_cpp::AccessException& exception) {
std::cout << "AccessException: " << exception.what() << "\n";
}
} else {
std::cout << "No esc_status subscription found\n";
Expand Down
10 changes: 8 additions & 2 deletions test/read_api_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ TEST_CASE("Write complicated, nested data format, then read it")

const uint64_t t00 = 0xdeadbeefdeadbeef;
const int32_t t01 = -123456;
const char t02[] = "Hello World!";
const char t02[] = "Hello World!----";
const double t03 = 3.14159265358979323846;
const uint32_t t04 = 0xdeadbeef;
const char t05 = 'a';
const char t06[] = "Hello World! 2";
const char t06[] = "Hello World! 2----";
const int32_t t07 = 123456;
const uint8_t t08 = 0x12;
const uint8_t t09 = 0x34;
Expand Down Expand Up @@ -259,6 +259,12 @@ TEST_CASE("Write complicated, nested data format, then read it")
sample[f_child_1][f_c1_c1_2][2][f_c1_c1_2_byte_b].asNativeTypeVariant()));
CHECK(std::holds_alternative<std::vector<uint64_t>>(
sample[f_child_1][f_c1_unsigned_long].asNativeTypeVariant()));

// Check exceptions
CHECK_THROWS_AS(sample["non_existent"], ulog_cpp::AccessException);
CHECK_THROWS_AS(sample[f_child_1][f_c1_unsigned_long][100], ulog_cpp::AccessException);
CHECK_THROWS_AS(data_container->subscription("non_existent_subscription"),
ulog_cpp::AccessException);
}

TEST_SUITE_END();
13 changes: 6 additions & 7 deletions ulog_cpp/data_container.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,13 @@ class DataContainer : public DataHandlerInterface {
return names;
}

std::shared_ptr<Subscription> subscription(const std::string& name, int multi_id) const
std::shared_ptr<Subscription> subscription(const std::string& name, int multi_id = 0) const
{
return _subscriptions_by_name_and_multi_id.at({name, multi_id});
}

std::shared_ptr<Subscription> subscription(const std::string& name) const
{
return _subscriptions_by_name_and_multi_id.at({name, 0});
const auto it = _subscriptions_by_name_and_multi_id.find({name, multi_id});
if (it == _subscriptions_by_name_and_multi_id.end()) {
throw AccessException("Subscription not found: " + name);
}
return it->second;
}

protected:
Expand Down
8 changes: 8 additions & 0 deletions ulog_cpp/exception.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,12 @@ class UsageException : public ExceptionBase {
explicit UsageException(std::string reason) : ExceptionBase(std::move(reason)) {}
};

/**
* Some field/subscription does not exist or index is out of range
*/
class AccessException : public ExceptionBase {
public:
explicit AccessException(std::string reason) : ExceptionBase(std::move(reason)) {}
};

} // namespace ulog_cpp
40 changes: 20 additions & 20 deletions ulog_cpp/messages.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,15 @@ void Field::resolveDefinition(int offset)
std::shared_ptr<MessageFormat> Field::nestedFormat() const
{
if (_type.type != Field::BasicType::NESTED) {
throw ParsingException("Not a nested type");
throw AccessException("Not a nested type");
}
return _type.nested_message;
}

std::shared_ptr<Field> Field::nestedField(const std::string& name) const
{
if (_type.type != Field::BasicType::NESTED) {
throw ParsingException("Not a nested type");
throw AccessException("Not a nested type");
}
return _type.nested_message->field(name);
}
Expand Down Expand Up @@ -190,7 +190,7 @@ std::string Field::encode() const
Value::NativeTypeVariant Value::asNativeTypeVariant() const
{
if (_array_index >= 0 && _field_ref.arrayLength() < 0) {
throw ParsingException("Can not access array element of non-array field");
throw AccessException("Can not access array element of non-array field");
}

if (_field_ref.arrayLength() == -1 || _array_index >= 0) {
Expand Down Expand Up @@ -235,7 +235,7 @@ Value::NativeTypeVariant Value::asNativeTypeVariant() const
return deserialize<char>(_backing_ref_begin, _backing_ref_end, _field_ref.offsetInMessage(),
array_offset);
case Field::BasicType::NESTED:
throw ParsingException("Can't get nested field as basic type. Field " + _field_ref.name());
throw AccessException("Can't get nested field as basic type. Field " + _field_ref.name());
}
} else {
// decode as an array
Expand Down Expand Up @@ -276,14 +276,14 @@ Value::NativeTypeVariant Value::asNativeTypeVariant() const
case Field::BasicType::CHAR: {
auto string_start_iterator = _backing_ref_begin + _field_ref.offsetInMessage();
if (_backing_ref_end - string_start_iterator < _field_ref.arrayLength()) {
throw ParsingException("Decoding fault, memory too short");
throw AccessException("Decoding fault, memory too short");
}
int string_length = strnlen(string_start_iterator.base(), _field_ref.arrayLength());
return std::string(string_start_iterator, string_start_iterator + string_length);
}

case Field::BasicType::NESTED:
throw ParsingException("Can't get nested field as basic type. Field " + _field_ref.name());
throw AccessException("Can't get nested field as basic type. Field " + _field_ref.name());
}
}
return deserialize<uint8_t>(_backing_ref_begin, _backing_ref_end, _field_ref.offsetInMessage(),
Expand All @@ -293,10 +293,10 @@ Value::NativeTypeVariant Value::asNativeTypeVariant() const
Value Value::operator[](const Field& field) const
{
if (_field_ref.type().type != Field::BasicType::NESTED) {
throw ParsingException("Cannot access field of non-nested type");
throw AccessException("Cannot access field of non-nested type");
}
if (!_field_ref.definitionResolved()) {
throw ParsingException("Cannot access field of unresolved type");
throw AccessException("Cannot access field of unresolved type");
}
int submessage_offset = _field_ref.offsetInMessage() +
((_array_index >= 0) ? _field_ref.type().size * _array_index : 0);
Expand All @@ -312,10 +312,10 @@ Value Value::operator[](const std::shared_ptr<Field>& field) const
Value Value::operator[](const std::string& field_name) const
{
if (_field_ref.type().type != Field::BasicType::NESTED) {
throw ParsingException("Cannot access field of non-nested type");
throw AccessException("Cannot access field of non-nested type");
}
if (!_field_ref.definitionResolved()) {
throw ParsingException("Cannot access field of unresolved type");
throw AccessException("Cannot access field of unresolved type");
}
const auto& field = _field_ref.type().nested_message->field(field_name);
return operator[](*field);
Expand All @@ -324,10 +324,10 @@ Value Value::operator[](const std::string& field_name) const
Value Value::operator[](size_t index) const
{
if (_field_ref.arrayLength() < 0) {
throw ParsingException("Cannot access field of non-array type");
throw AccessException("Cannot access field of non-array type");
}
if (index >= static_cast<size_t>(_field_ref.arrayLength())) {
throw ParsingException("Index out of bounds");
throw AccessException("Index out of bounds");
}
return Value(_field_ref, _backing_ref_begin, _backing_ref_end, index);
}
Expand Down Expand Up @@ -410,8 +410,8 @@ MessageFormat::MessageFormat(const uint8_t* msg)
MessageFormat::MessageFormat(std::string name, const std::vector<Field>& fields)
: _name(std::move(name))
{
for (const auto& field : fields) {
auto f = std::make_shared<Field>(field);
for (const auto& current_field : fields) {
auto f = std::make_shared<Field>(current_field);
_fields.insert({f->name(), f});
_fields_ordered.push_back(f);
}
Expand All @@ -430,20 +430,20 @@ void MessageFormat::resolveDefinition(
const std::map<std::string, std::shared_ptr<MessageFormat>>& existing_formats) const
{
int offset = 0;
for (const auto& field : _fields_ordered) {
if (!field->definitionResolved()) {
field->resolveDefinition(existing_formats, offset);
for (const auto& current_field : _fields_ordered) {
if (!current_field->definitionResolved()) {
current_field->resolveDefinition(existing_formats, offset);
}
offset += field->sizeBytes();
offset += current_field->sizeBytes();
}
}

void MessageFormat::serialize(const DataWriteCB& writer) const
{
std::string format_str = _name + ':';

for (const auto& field : _fields_ordered) {
format_str += field->encode() + ';';
for (const auto& current_field : _fields_ordered) {
format_str += current_field->encode() + ';';
}

ulog_message_format_s format;
Expand Down
15 changes: 11 additions & 4 deletions ulog_cpp/messages.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ class Value {
res = arg;
} else {
// one is string, the other is not
throw ParsingException("Assign strings and non-string types");
throw AccessException("Assign strings and non-string types");
}
} else if constexpr (is_vector<NativeType>::value) {
// this is natively a vector
Expand All @@ -369,7 +369,7 @@ class Value {
if (arg.size() > 0) {
res = staticCastEnsureUnsignedChar<ReturnType>(arg[0]);
} else {
throw ParsingException("Cannot convert empty vector to non-vector type");
throw AccessException("Cannot convert empty vector to non-vector type");
}
}
} else {
Expand Down Expand Up @@ -431,7 +431,7 @@ class Value {
int total_offset = offset + array_offset * sizeof(T);
if (backing_start > backing_end ||
backing_end - backing_start - total_offset < static_cast<int64_t>(sizeof(v))) {
throw ParsingException("Unexpected data type size");
throw AccessException("Unexpected data type size");
}
std::copy(backing_start + total_offset, backing_start + total_offset + sizeof(v),
reinterpret_cast<uint8_t*>(&v));
Expand Down Expand Up @@ -605,7 +605,14 @@ class MessageFormat {
* @param name the name of the field
* @return the requested field
*/
std::shared_ptr<Field> field(const std::string& name) const { return _fields.at(name); }
std::shared_ptr<Field> field(const std::string& name) const
{
const auto it = _fields.find(name);
if (it == _fields.end()) {
throw AccessException("Field not found: " + name);
}
return it->second;
}

private:
std::string _name;
Expand Down
8 changes: 7 additions & 1 deletion ulog_cpp/subscription.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,13 @@ class Subscription {
_message_format);
}

TypedDataView at(std::size_t n) const { return begin()[n]; }
TypedDataView at(std::size_t n) const
{
if (n >= size()) {
throw AccessException("Index out of range: " + std::to_string(n));
}
return begin()[n];
}

TypedDataView operator[](std::size_t n) { return at(n); }

Expand Down
4 changes: 2 additions & 2 deletions ulog_cpp/writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ void Writer::messageInfo(const MessageInfo& message_info)
void Writer::messageFormat(const MessageFormat& message_format)
{
if (_header_complete) {
throw ParsingException("Header completed, cannot write formats");
throw UsageException("Header completed, cannot write formats");
}
message_format.serialize(_data_write_cb);
}
Expand All @@ -48,7 +48,7 @@ void Writer::parameterDefault(const ParameterDefault& parameter_default)
void Writer::addLoggedMessage(const AddLoggedMessage& add_logged_message)
{
if (!_header_complete) {
throw ParsingException("Header not yet completed, cannot write AddLoggedMessage");
throw UsageException("Header not yet completed, cannot write AddLoggedMessage");
}
add_logged_message.serialize(_data_write_cb);
}
Expand Down
Loading