/** * Copyright 2012 Facebook * @author Tudor Bosman (tudorb@fb.com) */ #ifndef THRIFT_LIB_CPP_PROTOCOL_NEUTRONIUM_ENCODER_INL_H_ #define THRIFT_LIB_CPP_PROTOCOL_NEUTRONIUM_ENCODER_INL_H_ #ifndef THRIFT_INCLUDE_ENCODER_INL #error This file may only be included from Encoder.h #endif #include "folly/Conv.h" namespace apache { namespace thrift { namespace protocol { namespace neutronium { inline void Encoder::writeFieldBegin(const char* name, TType fieldType, int16_t fieldId) { auto& s = top(); DCHECK(s.state == IN_STRUCT); s.field = s.dataType->fields.at(fieldId); s.tag = fieldId; s.state = IN_FIELD; } inline void Encoder::writeFieldEnd() { auto& s = top(); DCHECK(s.state == DONE_FIELD); s.tag = 0; s.field.clear(); s.state = IN_STRUCT; } inline void Encoder::writeFieldStop() { auto& s = top(); DCHECK(s.state == IN_STRUCT); flush(); } inline bool Encoder::EncoderState::inDataState() const { return (state == IN_FIELD || state == IN_MAP_KEY || state == IN_MAP_VALUE || state == IN_LIST_VALUE || state == IN_SET_VALUE); } inline bool Encoder::EncoderState::inFlushableState() const { return (state == IN_STRUCT || state == IN_MAP_KEY || state == IN_LIST_VALUE || state == IN_SET_VALUE); } inline void Encoder::EncoderState::dataWritten() { switch (state) { case IN_FIELD: state = DONE_FIELD; break; case IN_MAP_KEY: state = IN_MAP_VALUE; field.type = dataType->valueType; break; case IN_MAP_VALUE: state = IN_MAP_KEY; field.type = dataType->mapKeyType; break; case IN_LIST_VALUE: break; case IN_SET_VALUE: break; default: LOG(FATAL) << "Invalid state " << state; } } template inline size_t findIndex(const C& container, T val) { auto pos = container.find(val); if (pos == container.end()) throw std::out_of_range("not found"); return pos - container.begin(); } inline void Encoder::EncoderState::markFieldSet() { // only useful for structs if (state != IN_FIELD || field.isRequired) return; optionalSet[findIndex(dataType->optionalFields, tag)] = true; } template inline void Encoder::EncoderState::checkType() const { throw TLibraryException(folly::to("Invalid type ", field.type)); } template inline void Encoder::EncoderState::checkType(reflection::Type t, Args... tail) const { if (t == reflection::getType(field.type)) return; checkType(tail...); } template inline void Encoder::EncoderState::appendToOutput(const Vec& vec) { size_t bytes = vec.size() * sizeof(typename Vec::value_type::second_type); for (auto& p : vec) { appender.writeBE(p.second); } bytesWritten += bytes; } inline void Encoder::writeMapBegin(TType keyType, TType valType, uint32_t size) { push(reflection::TYPE_MAP, topType(), size); } inline void Encoder::writeMapEnd() { DCHECK_EQ(top().state, IN_MAP_KEY); flush(); pop(); } inline void Encoder::writeListBegin(TType elemType, uint32_t size) { push(reflection::TYPE_LIST, topType(), size); } inline void Encoder::writeListEnd() { DCHECK_EQ(top().state, IN_LIST_VALUE); flush(); pop(); } inline void Encoder::writeSetBegin(TType elemType, uint32_t size) { push(reflection::TYPE_SET, topType(), size); } inline void Encoder::writeSetEnd() { DCHECK_EQ(top().state, IN_SET_VALUE); flush(); pop(); } inline void Encoder::writeStructBegin(const char* name) { push(reflection::TYPE_STRUCT, topType(), 0); } inline void Encoder::writeStructEnd() { DCHECK_EQ(top().state, FLUSHED); // writeFieldStop() called flush() pop(); } inline void Encoder::push(reflection::Type rtype, int64_t type, uint32_t size) { DCHECK(stack_.empty() || top().inDataState()); if (reflection::getType(type) != rtype) { throw TLibraryException(folly::to( "Invalid aggregate type ", reflection::getType(type), " expected ", rtype)); } const DataType* dt = &(schema_->map().at(type)); stack_.emplace_back(new EncoderState(type, dt, size)); } inline void Encoder::pop() { auto& s = top(); DCHECK_EQ(s.state, FLUSHED); int64_t size = s.dataType->fixedSize; DCHECK(size == -1 || size == s.buf->computeChainDataLength()); size_t bytesWritten = s.bytesWritten; DCHECK_EQ(bytesWritten, s.buf->computeChainDataLength()); auto buf = std::move(s.buf); stack_.pop_back(); if (stack_.empty()) { outputBuf_->prependChain(std::move(buf)); bytesWritten_ = bytesWritten; } else { writeData(std::move(buf)); top().bytesWritten += bytesWritten; } } inline Encoder::EncoderState& Encoder::top() { DCHECK(!stack_.empty()); DCHECK(stack_.back()); return *stack_.back(); } inline const Encoder::EncoderState& Encoder::top() const { DCHECK(!stack_.empty()); DCHECK(stack_.back()); return *stack_.back(); } inline int64_t Encoder::topType() const { return (stack_.empty() ? rootType_ : top().field.type); } inline void Encoder::writeBool(bool v) { auto& s = top(); DCHECK(s.inDataState()); s.checkType(reflection::TYPE_BOOL); s.markFieldSet(); s.bools.emplace_back(s.tag, v); s.dataWritten(); } inline void Encoder::writeByte(int8_t v) { auto& s = top(); DCHECK(s.inDataState()); s.markFieldSet(); s.checkType(reflection::TYPE_BYTE); s.bytes.emplace_back(s.tag, v); s.dataWritten(); } inline void Encoder::writeI16(int16_t v) { auto& s = top(); DCHECK(s.inDataState()); s.checkType(reflection::TYPE_I16); s.markFieldSet(); if (s.field.isFixed) { s.fixedInt16s.emplace_back(s.tag, v); } else { s.varInts.emplace_back(s.tag, v); } s.dataWritten(); } inline void Encoder::writeI32(int32_t v) { auto& s = top(); DCHECK(s.inDataState()); s.checkType(reflection::TYPE_I32, reflection::TYPE_ENUM); s.markFieldSet(); if (reflection::getType(s.field.type) == reflection::TYPE_ENUM && s.field.isStrictEnum) { auto& dt = schema_->map().at(s.field.type); uint8_t nbits = dt.enumBits(); uint32_t value = findIndex(dt.enumValues, v); s.strictEnums.push_back({s.tag, {nbits, value}}); s.totalStrictEnumBits += nbits; } else { if (s.field.isFixed) { s.fixedInt32s.emplace_back(s.tag, v); } else { s.varInts.emplace_back(s.tag, v); } } s.dataWritten(); } inline void Encoder::writeI64(int64_t v) { innerWriteI64(v, reflection::TYPE_I64); } inline void Encoder::writeDouble(double v) { innerWriteI64(bitwise_cast(v), reflection::TYPE_DOUBLE); } inline void Encoder::innerWriteI64(int64_t v, reflection::Type expected) { auto& s = top(); DCHECK(s.inDataState()); s.markFieldSet(); s.checkType(expected); if (s.field.isFixed) { s.fixedInt64s.emplace_back(s.tag, v); } else { s.varInt64s.emplace_back(s.tag, v); } s.dataWritten(); } namespace { std::unique_ptr copyBufferToSize( const void* data, size_t dataSize, size_t outSize, char pad) { auto buf = folly::IOBuf::create(outSize); dataSize = std::min(dataSize, outSize); memcpy(buf->writableData(), data, dataSize); if (dataSize < outSize) { memset(buf->writableData() + dataSize, pad, outSize - dataSize); } buf->append(outSize); return buf; } std::unique_ptr copyBufferAndTerminate( const void* data, size_t dataSize, char terminator) { if (memchr(data, terminator, dataSize)) { throw TProtocolException("terminator found in terminated string"); } // 1 byte of tailroom for the terminator auto buf = folly::IOBuf::copyBuffer(data, dataSize, 0, 1); buf->writableTail()[0] = terminator; buf->append(1); return buf; } } // namespace // TODO(tudorb): Zero-copy version inline void Encoder::writeBytes(folly::StringPiece data) { auto& s = top(); DCHECK(s.inDataState()); s.checkType(reflection::TYPE_STRING); s.markFieldSet(); if (s.field.isInterned) { CHECK(internTable_); s.varInternIds.emplace_back(s.tag, internTable_->add(data)); } else if (s.field.isFixed) { s.strings.emplace_back( s.tag, copyBufferToSize(data.data(), data.size(), s.field.fixedStringSize, s.field.pad)); s.bytesWritten += s.field.fixedStringSize; } else if (s.field.isTerminated) { s.strings.emplace_back( s.tag, copyBufferAndTerminate(data.data(), data.size(), s.field.terminator)); s.bytesWritten += data.size() + 1; } else { s.varLengths.emplace_back(s.tag, data.size()); s.strings.emplace_back( s.tag, folly::IOBuf::copyBuffer(data.data(), data.size())); s.bytesWritten += data.size(); } s.dataWritten(); } inline void Encoder::writeData(std::unique_ptr&& data) { auto& s = top(); DCHECK(s.inDataState()); s.markFieldSet(); s.strings.emplace_back(s.tag, std::move(data)); s.dataWritten(); } inline void Encoder::flush() { auto& s = top(); DCHECK(s.inFlushableState()); if (s.state == IN_STRUCT) { // TODO(tudorb): Check that all required fields were actually specified. flushStruct(); } flushData(s.state == IN_STRUCT); s.state = FLUSHED; } } // namespace neutronium } // namespace protocol } // namespace thrift } // namespace apache #endif /* THRIFT_LIB_CPP_PROTOCOL_NEUTRONIUM_ENCODER_INL_H_ */