diff --git a/db/db_impl_write.cc b/db/db_impl_write.cc
index 8f625f839..9e83df861 100644
--- a/db/db_impl_write.cc
+++ b/db/db_impl_write.cc
@@ -1395,7 +1395,10 @@ Status DB::Put(const WriteOptions& opt, ColumnFamilyHandle* column_family,
   // 8 bytes are taken by header, 4 bytes for count, 1 byte for type,
   // and we allocate 11 extra bytes for key length, as well as value length.
   WriteBatch batch(key.size() + value.size() + 24);
-  batch.Put(column_family, key, value);
+  Status s = batch.Put(column_family, key, value);
+  if (!s.ok()) {
+    return s;
+  }
   return Write(opt, &batch);
 }
 
@@ -1424,7 +1427,10 @@ Status DB::DeleteRange(const WriteOptions& opt,
 Status DB::Merge(const WriteOptions& opt, ColumnFamilyHandle* column_family,
                  const Slice& key, const Slice& value) {
   WriteBatch batch;
-  batch.Merge(column_family, key, value);
+  Status s = batch.Merge(column_family, key, value);
+  if (!s.ok()) {
+    return s;
+  }
   return Write(opt, &batch);
 }
 }  // namespace rocksdb
diff --git a/db/db_test.cc b/db/db_test.cc
index 61ddd7d58..5f2f62f99 100644
--- a/db/db_test.cc
+++ b/db/db_test.cc
@@ -480,6 +480,36 @@ TEST_F(DBTest, SingleDeletePutFlush) {
                          kSkipUniversalCompaction | kSkipMergePut));
 }
 
+// Disable because not all platform can run it.
+// It requires more than 9GB memory to run it, With single allocation
+// of more than 3GB.
+TEST_F(DBTest, DISABLED_SanitizeVeryVeryLargeValue) {
+  const size_t kValueSize = 4 * size_t{1024 * 1024 * 1024};  // 4GB value
+  std::string raw(kValueSize, 'v');
+  Options options = CurrentOptions();
+  options.env = env_;
+  options.merge_operator = MergeOperators::CreatePutOperator();
+  options.write_buffer_size = 100000;  // Small write buffer
+  options.paranoid_checks = true;
+  DestroyAndReopen(options);
+
+  ASSERT_OK(Put("boo", "v1"));
+  ASSERT_TRUE(Put("foo", raw).IsInvalidArgument());
+  ASSERT_TRUE(Merge("foo", raw).IsInvalidArgument());
+
+  WriteBatch wb;
+  ASSERT_TRUE(wb.Put("foo", raw).IsInvalidArgument());
+  ASSERT_TRUE(wb.Merge("foo", raw).IsInvalidArgument());
+
+  Slice value_slice = raw;
+  Slice key_slice = "foo";
+  SliceParts sp_key(&key_slice, 1);
+  SliceParts sp_value(&value_slice, 1);
+
+  ASSERT_TRUE(wb.Put(sp_key, sp_value).IsInvalidArgument());
+  ASSERT_TRUE(wb.Merge(sp_key, sp_value).IsInvalidArgument());
+}
+
 // Disable because not all platform can run it.
 // It requires more than 9GB memory to run it, With single allocation
 // of more than 3GB.
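For context, here is how the propagated status surfaces to a caller once `DB::Put` forwards the result of `WriteBatch::Put` instead of discarding it. This is a minimal standalone sketch, not part of the patch; the database path is illustrative, and actually running it needs the multi-GB allocation the disabled test above warns about.

```cpp
#include <cassert>
#include <string>

#include "rocksdb/db.h"

int main() {
  rocksdb::DB* db = nullptr;
  rocksdb::Options options;
  options.create_if_missing = true;
  rocksdb::Status s =
      rocksdb::DB::Open(options, "/tmp/size_check_demo", &db);  // illustrative path
  assert(s.ok());

  // A value whose length exceeds the 32-bit limit is now rejected up
  // front with InvalidArgument rather than being written incorrectly.
  std::string huge(size_t{4} * 1024 * 1024 * 1024, 'v');  // 4GB
  s = db->Put(rocksdb::WriteOptions(), "key", huge);
  assert(s.IsInvalidArgument());

  delete db;
  return 0;
}
```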
diff --git a/db/write_batch.cc b/db/write_batch.cc
index 7963c326b..e5072b1d2 100644
--- a/db/write_batch.cc
+++ b/db/write_batch.cc
@@ -574,6 +574,13 @@ size_t WriteBatchInternal::GetFirstOffset(WriteBatch* b) {
 
 Status WriteBatchInternal::Put(WriteBatch* b, uint32_t column_family_id,
                                const Slice& key, const Slice& value) {
+  if (key.size() > size_t{port::kMaxUint32}) {
+    return Status::InvalidArgument("key is too large");
+  }
+  if (value.size() > size_t{port::kMaxUint32}) {
+    return Status::InvalidArgument("value is too large");
+  }
+
   LocalSavePoint save(b);
   WriteBatchInternal::SetCount(b, WriteBatchInternal::Count(b) + 1);
   if (column_family_id == 0) {
@@ -596,8 +603,33 @@ Status WriteBatch::Put(ColumnFamilyHandle* column_family, const Slice& key,
                                  value);
 }
 
+Status WriteBatchInternal::CheckSlicePartsLength(const SliceParts& key,
+                                                 const SliceParts& value) {
+  size_t total_key_bytes = 0;
+  for (int i = 0; i < key.num_parts; ++i) {
+    total_key_bytes += key.parts[i].size();
+  }
+  if (total_key_bytes >= size_t{port::kMaxUint32}) {
+    return Status::InvalidArgument("key is too large");
+  }
+
+  size_t total_value_bytes = 0;
+  for (int i = 0; i < value.num_parts; ++i) {
+    total_value_bytes += value.parts[i].size();
+  }
+  if (total_value_bytes >= size_t{port::kMaxUint32}) {
+    return Status::InvalidArgument("value is too large");
+  }
+  return Status::OK();
+}
+
 Status WriteBatchInternal::Put(WriteBatch* b, uint32_t column_family_id,
                                const SliceParts& key, const SliceParts& value) {
+  Status s = CheckSlicePartsLength(key, value);
+  if (!s.ok()) {
+    return s;
+  }
+
   LocalSavePoint save(b);
   WriteBatchInternal::SetCount(b, WriteBatchInternal::Count(b) + 1);
   if (column_family_id == 0) {
@@ -814,6 +846,13 @@ Status WriteBatch::DeleteRange(ColumnFamilyHandle* column_family,
 
 Status WriteBatchInternal::Merge(WriteBatch* b, uint32_t column_family_id,
                                  const Slice& key, const Slice& value) {
+  if (key.size() > size_t{port::kMaxUint32}) {
+    return Status::InvalidArgument("key is too large");
+  }
+  if (value.size() > size_t{port::kMaxUint32}) {
+    return Status::InvalidArgument("value is too large");
+  }
+
   LocalSavePoint save(b);
   WriteBatchInternal::SetCount(b, WriteBatchInternal::Count(b) + 1);
   if (column_family_id == 0) {
@@ -839,6 +878,11 @@ Status WriteBatch::Merge(ColumnFamilyHandle* column_family, const Slice& key,
 
 Status WriteBatchInternal::Merge(WriteBatch* b, uint32_t column_family_id,
                                  const SliceParts& key, const SliceParts& value) {
+  Status s = CheckSlicePartsLength(key, value);
+  if (!s.ok()) {
+    return s;
+  }
+
   LocalSavePoint save(b);
   WriteBatchInternal::SetCount(b, WriteBatchInternal::Count(b) + 1);
   if (column_family_id == 0) {
diff --git a/db/write_batch_internal.h b/db/write_batch_internal.h
index aeaf0e1ea..9a200f3cb 100644
--- a/db/write_batch_internal.h
+++ b/db/write_batch_internal.h
@@ -138,6 +138,9 @@ class WriteBatchInternal {
 
   static Status SetContents(WriteBatch* batch, const Slice& contents);
 
+  static Status CheckSlicePartsLength(const SliceParts& key,
+                                      const SliceParts& value);
+
   // Inserts batches[i] into memtable, for i in 0..num_batches-1 inclusive.
   //
   // If ignore_missing_column_families == true.
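The `kMaxUint32` bound is not arbitrary: a WriteBatch record length-prefixes each key and value with a 32-bit varint (`PutLengthPrefixedSlice` casts `size()` to `uint32_t`), so a longer slice would have its length silently truncated at encode time. A standalone sketch of the narrowing the new checks prevent:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  const size_t four_gb = size_t{4} * 1024 * 1024 * 1024;
  // Without the guard, this is the narrowing the record encoder performs:
  const uint32_t encoded_len = static_cast<uint32_t>(four_gb);
  std::cout << "actual length:  " << four_gb << "\n"       // 4294967296
            << "encoded length: " << encoded_len << "\n";  // 0 -> corrupt record
  return 0;
}
```

Note that the single-`Slice` overloads reject sizes strictly greater than `kMaxUint32`, while `CheckSlicePartsLength` rejects totals greater than or equal to it; both keep the encoded length within 32 bits, the `SliceParts` path simply being one byte stricter.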
diff --git a/port/port_posix.h b/port/port_posix.h
index fe0d42644..2d2a7a79c 100644
--- a/port/port_posix.h
+++ b/port/port_posix.h
@@ -85,6 +85,7 @@ namespace rocksdb {
 namespace port {
 
 // For use at db/file_indexer.h kLevelMaxIndex
+const uint32_t kMaxUint32 = std::numeric_limits<uint32_t>::max();
 const int kMaxInt32 = std::numeric_limits<int>::max();
 const uint64_t kMaxUint64 = std::numeric_limits<uint64_t>::max();
 const int64_t kMaxInt64 = std::numeric_limits<int64_t>::max();
diff --git a/port/win/port_win.h b/port/win/port_win.h
index b16a70521..6da9f955d 100644
--- a/port/win/port_win.h
+++ b/port/win/port_win.h
@@ -92,6 +92,7 @@ namespace port {
 // therefore, use the same limits
 
 // For use at db/file_indexer.h kLevelMaxIndex
+const uint32_t kMaxUint32 = UINT32_MAX;
 const int kMaxInt32 = INT32_MAX;
 const int64_t kMaxInt64 = INT64_MAX;
 const uint64_t kMaxUint64 = UINT64_MAX;
@@ -107,6 +108,7 @@ const size_t kMaxSizet = UINT_MAX;
 #define ROCKSDB_NOEXCEPT noexcept
 
 // For use at db/file_indexer.h kLevelMaxIndex
+const uint32_t kMaxUint32 = std::numeric_limits<uint32_t>::max();
 const int kMaxInt32 = std::numeric_limits<int>::max();
 const uint64_t kMaxUint64 = std::numeric_limits<uint64_t>::max();
 const int64_t kMaxInt64 = std::numeric_limits<int64_t>::max();
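A quick compile-time sanity check of the new constant; this is a hypothetical standalone test, not part of the patch, and the include is illustrative (in-tree code would normally go through `port/port.h`, which selects the platform header):

```cpp
#include <cstdint>
#include <limits>

#include "port/port.h"  // resolves to port_posix.h or port/win/port_win.h

// kMaxUint32 must agree with the 32-bit varint length limit enforced in
// WriteBatchInternal, on every platform.
static_assert(rocksdb::port::kMaxUint32 ==
                  std::numeric_limits<uint32_t>::max(),
              "kMaxUint32 must match the 32-bit record length limit");

int main() { return 0; }
```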