// Copyright (c) 2011-present, Facebook, Inc.  All rights reserved.
//  This source code is licensed under both the GPLv2 (found in the
//  COPYING file in the root directory) and Apache 2.0 License
//  (found in the LICENSE.Apache file in the root directory).
//
// This file implements the callback "bridge" between Java and C++ for
// ROCKSDB_NAMESPACE::Comparator.

#include "rocksjni/comparatorjnicallback.h"
#include "rocksjni/portal.h"

namespace ROCKSDB_NAMESPACE {
ComparatorJniCallback::ComparatorJniCallback(
    JNIEnv* env, jobject jcomparator,
    const ComparatorJniCallbackOptions* options)
    : JniCallback(env, jcomparator),
    m_options(options) {

  // cache the AbstractComparatorJniBridge class as we will reuse it many times for each callback
  m_abstract_comparator_jni_bridge_clazz =
      static_cast<jclass>(env->NewGlobalRef(AbstractComparatorJniBridge::getJClass(env)));

  // Note: The name of a Comparator will not change during it's lifetime,
  // so we cache it in a global var
  jmethodID jname_mid = AbstractComparatorJni::getNameMethodId(env);
  if (jname_mid == nullptr) {
    // exception thrown: NoSuchMethodException or OutOfMemoryError
    return;
  }
  jstring js_name = (jstring)env->CallObjectMethod(m_jcallback_obj, jname_mid);
  if (env->ExceptionCheck()) {
    // exception thrown
    return;
  }
  jboolean has_exception = JNI_FALSE;
  m_name = JniUtil::copyString(env, js_name,
      &has_exception);  // also releases jsName
  if (has_exception == JNI_TRUE) {
    // exception thrown
    return;
  }

  // cache the ByteBuffer class as we will reuse it many times for each callback
  m_jbytebuffer_clazz =
      static_cast<jclass>(env->NewGlobalRef(ByteBufferJni::getJClass(env)));

  m_jcompare_mid = AbstractComparatorJniBridge::getCompareInternalMethodId(
      env, m_abstract_comparator_jni_bridge_clazz);
  if (m_jcompare_mid == nullptr) {
    // exception thrown: NoSuchMethodException or OutOfMemoryError
    return;
  }

  m_jshortest_mid =
    AbstractComparatorJniBridge::getFindShortestSeparatorInternalMethodId(
        env, m_abstract_comparator_jni_bridge_clazz);
  if (m_jshortest_mid == nullptr) {
    // exception thrown: NoSuchMethodException or OutOfMemoryError
    return;
  }

  m_jshort_mid =
    AbstractComparatorJniBridge::getFindShortSuccessorInternalMethodId(env,
        m_abstract_comparator_jni_bridge_clazz);
  if (m_jshort_mid == nullptr) {
    // exception thrown: NoSuchMethodException or OutOfMemoryError
    return;
  }

  // do we need reusable buffers?
  if (m_options->max_reused_buffer_size > -1) {

    if (m_options->reused_synchronisation_type
        == ReusedSynchronisationType::THREAD_LOCAL) {
      // buffers reused per thread
      UnrefHandler unref = [](void* ptr) {
        ThreadLocalBuf* tlb = reinterpret_cast<ThreadLocalBuf*>(ptr);
        jboolean attached_thread = JNI_FALSE;
        JNIEnv* _env = JniUtil::getJniEnv(tlb->jvm, &attached_thread);
        if (_env != nullptr) {
          if (tlb->direct_buffer) {
            void* buf = _env->GetDirectBufferAddress(tlb->jbuf);
            delete[] static_cast<char*>(buf);
          }
          _env->DeleteGlobalRef(tlb->jbuf);
          JniUtil::releaseJniEnv(tlb->jvm, attached_thread);
        }
      };

      m_tl_buf_a = new ThreadLocalPtr(unref);
      m_tl_buf_b = new ThreadLocalPtr(unref);

      m_jcompare_buf_a = nullptr;
      m_jcompare_buf_b = nullptr;
      m_jshortest_buf_start = nullptr;
      m_jshortest_buf_limit = nullptr;
      m_jshort_buf_key = nullptr;

    } else {
      //buffers reused and shared across threads
      const bool adaptive =
          m_options->reused_synchronisation_type == ReusedSynchronisationType::ADAPTIVE_MUTEX;
      mtx_compare = std::unique_ptr<port::Mutex>(new port::Mutex(adaptive));
      mtx_shortest = std::unique_ptr<port::Mutex>(new port::Mutex(adaptive));
      mtx_short = std::unique_ptr<port::Mutex>(new port::Mutex(adaptive));

      m_jcompare_buf_a = env->NewGlobalRef(ByteBufferJni::construct(
          env, m_options->direct_buffer, m_options->max_reused_buffer_size,
          m_jbytebuffer_clazz));
      if (m_jcompare_buf_a == nullptr) {
        // exception thrown: OutOfMemoryError
        return;
      }

      m_jcompare_buf_b = env->NewGlobalRef(ByteBufferJni::construct(
          env, m_options->direct_buffer, m_options->max_reused_buffer_size,
          m_jbytebuffer_clazz));
      if (m_jcompare_buf_b == nullptr) {
        // exception thrown: OutOfMemoryError
        return;
      }

      m_jshortest_buf_start = env->NewGlobalRef(ByteBufferJni::construct(
          env, m_options->direct_buffer, m_options->max_reused_buffer_size,
          m_jbytebuffer_clazz));
      if (m_jshortest_buf_start == nullptr) {
        // exception thrown: OutOfMemoryError
        return;
      }

      m_jshortest_buf_limit = env->NewGlobalRef(ByteBufferJni::construct(
          env, m_options->direct_buffer, m_options->max_reused_buffer_size,
          m_jbytebuffer_clazz));
      if (m_jshortest_buf_limit == nullptr) {
        // exception thrown: OutOfMemoryError
        return;
      }

      m_jshort_buf_key = env->NewGlobalRef(ByteBufferJni::construct(
          env, m_options->direct_buffer, m_options->max_reused_buffer_size,
          m_jbytebuffer_clazz));
      if (m_jshort_buf_key == nullptr) {
        // exception thrown: OutOfMemoryError
        return;
      }

      m_tl_buf_a = nullptr;
      m_tl_buf_b = nullptr;
    }

  } else {
    m_jcompare_buf_a = nullptr;
    m_jcompare_buf_b = nullptr;
    m_jshortest_buf_start = nullptr;
    m_jshortest_buf_limit = nullptr;
    m_jshort_buf_key = nullptr;

    m_tl_buf_a = nullptr;
    m_tl_buf_b = nullptr;
  }
}

ComparatorJniCallback::~ComparatorJniCallback() {
  jboolean attached_thread = JNI_FALSE;
  JNIEnv* env = getJniEnv(&attached_thread);
  assert(env != nullptr);

  env->DeleteGlobalRef(m_abstract_comparator_jni_bridge_clazz);

  env->DeleteGlobalRef(m_jbytebuffer_clazz);

  if (m_jcompare_buf_a != nullptr) {
    if (m_options->direct_buffer) {
      void* buf = env->GetDirectBufferAddress(m_jcompare_buf_a);
      delete[] static_cast<char*>(buf);
    }
    env->DeleteGlobalRef(m_jcompare_buf_a);
  }

  if (m_jcompare_buf_b != nullptr) {
    if (m_options->direct_buffer) {
      void* buf = env->GetDirectBufferAddress(m_jcompare_buf_b);
      delete[] static_cast<char*>(buf);
    }
    env->DeleteGlobalRef(m_jcompare_buf_b);
  }

  if (m_jshortest_buf_start != nullptr) {
    if (m_options->direct_buffer) {
      void* buf = env->GetDirectBufferAddress(m_jshortest_buf_start);
      delete[] static_cast<char*>(buf);
    }
    env->DeleteGlobalRef(m_jshortest_buf_start);
  }

  if (m_jshortest_buf_limit != nullptr) {
    if (m_options->direct_buffer) {
      void* buf = env->GetDirectBufferAddress(m_jshortest_buf_limit);
      delete[] static_cast<char*>(buf);
    }
    env->DeleteGlobalRef(m_jshortest_buf_limit);
  }

  if (m_jshort_buf_key != nullptr) {
    if (m_options->direct_buffer) {
      void* buf = env->GetDirectBufferAddress(m_jshort_buf_key);
      delete[] static_cast<char*>(buf);
    }
    env->DeleteGlobalRef(m_jshort_buf_key);
  }

  if (m_tl_buf_a != nullptr) {
    delete m_tl_buf_a;
  }

  if (m_tl_buf_b != nullptr) {
    delete m_tl_buf_b;
  }

  releaseJniEnv(attached_thread);
}

const char* ComparatorJniCallback::Name() const {
  return m_name.get();
}

int ComparatorJniCallback::Compare(const Slice& a, const Slice& b) const {
  jboolean attached_thread = JNI_FALSE;
  JNIEnv* env = getJniEnv(&attached_thread);
  assert(env != nullptr);

  const bool reuse_jbuf_a =
      static_cast<int64_t>(a.size()) <= m_options->max_reused_buffer_size;
  const bool reuse_jbuf_b =
      static_cast<int64_t>(b.size()) <= m_options->max_reused_buffer_size;

  MaybeLockForReuse(mtx_compare, reuse_jbuf_a || reuse_jbuf_b);

  jobject jcompare_buf_a = GetBuffer(env, a, reuse_jbuf_a, m_tl_buf_a, m_jcompare_buf_a);
  if (jcompare_buf_a == nullptr) {
    // exception occurred
    MaybeUnlockForReuse(mtx_compare, reuse_jbuf_a || reuse_jbuf_b);
    env->ExceptionDescribe(); // print out exception to stderr
    releaseJniEnv(attached_thread);
    return 0;
  }

  jobject jcompare_buf_b = GetBuffer(env, b, reuse_jbuf_b, m_tl_buf_b, m_jcompare_buf_b);
  if (jcompare_buf_b == nullptr) {
    // exception occurred
    if (!reuse_jbuf_a) {
      DeleteBuffer(env, jcompare_buf_a);
    }
    MaybeUnlockForReuse(mtx_compare, reuse_jbuf_a || reuse_jbuf_b);
    env->ExceptionDescribe(); // print out exception to stderr
    releaseJniEnv(attached_thread);
    return 0;
  }

  jint result =
    env->CallStaticIntMethod(
      m_abstract_comparator_jni_bridge_clazz, m_jcompare_mid,
      m_jcallback_obj,
      jcompare_buf_a, reuse_jbuf_a ? a.size() : -1,
      jcompare_buf_b, reuse_jbuf_b ? b.size() : -1);

  if (env->ExceptionCheck()) {
    // exception thrown from CallIntMethod
    env->ExceptionDescribe(); // print out exception to stderr
    result = 0; // we could not get a result from java callback so use 0
  }

  if (!reuse_jbuf_a) {
    DeleteBuffer(env, jcompare_buf_a);
  }
  if (!reuse_jbuf_b) {
    DeleteBuffer(env, jcompare_buf_b);
  }

  MaybeUnlockForReuse(mtx_compare, reuse_jbuf_a || reuse_jbuf_b);

  releaseJniEnv(attached_thread);

  return result;
}

void ComparatorJniCallback::FindShortestSeparator(
    std::string* start, const Slice& limit) const {
  if (start == nullptr) {
    return;
  }

  jboolean attached_thread = JNI_FALSE;
  JNIEnv* env = getJniEnv(&attached_thread);
  assert(env != nullptr);

  const bool reuse_jbuf_start =
      static_cast<int64_t>(start->length()) <= m_options->max_reused_buffer_size;
  const bool reuse_jbuf_limit =
      static_cast<int64_t>(limit.size()) <= m_options->max_reused_buffer_size;

  MaybeLockForReuse(mtx_shortest, reuse_jbuf_start || reuse_jbuf_limit);

  Slice sstart(start->data(), start->length());
  jobject j_start_buf = GetBuffer(env, sstart, reuse_jbuf_start, m_tl_buf_a, m_jshortest_buf_start);
  if (j_start_buf == nullptr) {
    // exception occurred
    MaybeUnlockForReuse(mtx_shortest, reuse_jbuf_start || reuse_jbuf_limit);
    env->ExceptionDescribe(); // print out exception to stderr
    releaseJniEnv(attached_thread);
    return;
  }

  jobject j_limit_buf = GetBuffer(env, limit, reuse_jbuf_limit, m_tl_buf_b, m_jshortest_buf_limit);
  if (j_limit_buf == nullptr) {
    // exception occurred
    if (!reuse_jbuf_start) {
      DeleteBuffer(env, j_start_buf);
    }
    MaybeUnlockForReuse(mtx_shortest, reuse_jbuf_start || reuse_jbuf_limit);
    env->ExceptionDescribe(); // print out exception to stderr
    releaseJniEnv(attached_thread);
    return;
  }

  jint jstart_len = env->CallStaticIntMethod(
      m_abstract_comparator_jni_bridge_clazz, m_jshortest_mid,
      m_jcallback_obj,
      j_start_buf, reuse_jbuf_start ? start->length() : -1,
      j_limit_buf, reuse_jbuf_limit ? limit.size() : -1);

  if (env->ExceptionCheck()) {
    // exception thrown from CallIntMethod
    env->ExceptionDescribe(); // print out exception to stderr

  } else if (static_cast<size_t>(jstart_len) != start->length()) {
    // start buffer has changed in Java, so update `start` with the result
    bool copy_from_non_direct = false;
    if (reuse_jbuf_start) {
        // reused a buffer
        if (m_options->direct_buffer) {
          // reused direct buffer
          void* start_buf = env->GetDirectBufferAddress(j_start_buf);
          if (start_buf == nullptr) {
            if (!reuse_jbuf_start) {
              DeleteBuffer(env, j_start_buf);
            }
            if (!reuse_jbuf_limit) {
              DeleteBuffer(env, j_limit_buf);
            }
            MaybeUnlockForReuse(mtx_shortest, reuse_jbuf_start || reuse_jbuf_limit);
            ROCKSDB_NAMESPACE::RocksDBExceptionJni::ThrowNew(
                env, "Unable to get Direct Buffer Address");
            env->ExceptionDescribe();  // print out exception to stderr
            releaseJniEnv(attached_thread);
            return;
          }
          start->assign(static_cast<const char*>(start_buf), jstart_len);

        } else {

          // reused non-direct buffer
          copy_from_non_direct = true;
        }
    } else {
        // there was a new buffer
        if (m_options->direct_buffer) {
          // it was direct... don't forget to potentially truncate the `start` string
          start->resize(jstart_len);
        } else {
          // it was non-direct
          copy_from_non_direct = true;
        }
    }

    if (copy_from_non_direct) {
      jbyteArray jarray = ByteBufferJni::array(env, j_start_buf,
          m_jbytebuffer_clazz);
      if (jarray == nullptr) {
        if (!reuse_jbuf_start) {
          DeleteBuffer(env, j_start_buf);
        }
        if (!reuse_jbuf_limit) {
          DeleteBuffer(env, j_limit_buf);
        }
        MaybeUnlockForReuse(mtx_shortest, reuse_jbuf_start || reuse_jbuf_limit);
        env->ExceptionDescribe();  // print out exception to stderr
        releaseJniEnv(attached_thread);
        return;
      }
      jboolean has_exception = JNI_FALSE;
      JniUtil::byteString<std::string>(env, jarray, [start, jstart_len](const char* data, const size_t) {
        return start->assign(data, static_cast<size_t>(jstart_len));
      }, &has_exception);
      env->DeleteLocalRef(jarray);
      if (has_exception == JNI_TRUE) {
        if (!reuse_jbuf_start) {
          DeleteBuffer(env, j_start_buf);
        }
        if (!reuse_jbuf_limit) {
          DeleteBuffer(env, j_limit_buf);
        }
        env->ExceptionDescribe();  // print out exception to stderr
        MaybeUnlockForReuse(mtx_shortest, reuse_jbuf_start || reuse_jbuf_limit);
        releaseJniEnv(attached_thread);
        return;
      }
    }
  }

  if (!reuse_jbuf_start) {
    DeleteBuffer(env, j_start_buf);
  }
  if (!reuse_jbuf_limit) {
    DeleteBuffer(env, j_limit_buf);
  }

  MaybeUnlockForReuse(mtx_shortest, reuse_jbuf_start || reuse_jbuf_limit);

  releaseJniEnv(attached_thread);
}

void ComparatorJniCallback::FindShortSuccessor(
    std::string* key) const {
  if (key == nullptr) {
    return;
  }

  jboolean attached_thread = JNI_FALSE;
  JNIEnv* env = getJniEnv(&attached_thread);
  assert(env != nullptr);

  const bool reuse_jbuf_key =
      static_cast<int64_t>(key->length()) <= m_options->max_reused_buffer_size;

  MaybeLockForReuse(mtx_short, reuse_jbuf_key);

  Slice skey(key->data(), key->length());
  jobject j_key_buf = GetBuffer(env, skey, reuse_jbuf_key, m_tl_buf_a, m_jshort_buf_key);
  if (j_key_buf == nullptr) {
    // exception occurred
    MaybeUnlockForReuse(mtx_short, reuse_jbuf_key);
    env->ExceptionDescribe(); // print out exception to stderr
    releaseJniEnv(attached_thread);
    return;
  }

  jint jkey_len = env->CallStaticIntMethod(
      m_abstract_comparator_jni_bridge_clazz, m_jshort_mid,
      m_jcallback_obj,
      j_key_buf, reuse_jbuf_key ? key->length() : -1);

  if (env->ExceptionCheck()) {
    // exception thrown from CallObjectMethod
    if (!reuse_jbuf_key) {
      DeleteBuffer(env, j_key_buf);
    }
    MaybeUnlockForReuse(mtx_short, reuse_jbuf_key);
    env->ExceptionDescribe();  // print out exception to stderr
    releaseJniEnv(attached_thread);
    return;

  }

  if (static_cast<size_t>(jkey_len) != key->length()) {
    // key buffer has changed in Java, so update `key` with the result
    bool copy_from_non_direct = false;
    if (reuse_jbuf_key) {
        // reused a buffer
        if (m_options->direct_buffer) {
          // reused direct buffer
          void* key_buf = env->GetDirectBufferAddress(j_key_buf);
          if (key_buf == nullptr) {
            ROCKSDB_NAMESPACE::RocksDBExceptionJni::ThrowNew(
                env, "Unable to get Direct Buffer Address");
            if (!reuse_jbuf_key) {
              DeleteBuffer(env, j_key_buf);
            }
            MaybeUnlockForReuse(mtx_short, reuse_jbuf_key);
            env->ExceptionDescribe();  // print out exception to stderr
            releaseJniEnv(attached_thread);
            return;
          }
          key->assign(static_cast<const char*>(key_buf), jkey_len);
        } else {
          // reused non-direct buffer
          copy_from_non_direct = true;
        }
    } else {
        // there was a new buffer
        if (m_options->direct_buffer) {
          // it was direct... don't forget to potentially truncate the `key` string
          key->resize(jkey_len);
        } else {
          // it was non-direct
          copy_from_non_direct = true;
        }
    }

    if (copy_from_non_direct) {
      jbyteArray jarray = ByteBufferJni::array(env, j_key_buf,
          m_jbytebuffer_clazz);
      if (jarray == nullptr) {

        if (!reuse_jbuf_key) {
          DeleteBuffer(env, j_key_buf);
        }
        MaybeUnlockForReuse(mtx_short, reuse_jbuf_key);
        env->ExceptionDescribe();  // print out exception to stderr
        releaseJniEnv(attached_thread);
        return;
      }
      jboolean has_exception = JNI_FALSE;
      JniUtil::byteString<std::string>(env, jarray, [key, jkey_len](const char* data, const size_t) {
        return key->assign(data, static_cast<size_t>(jkey_len));
      }, &has_exception);
      env->DeleteLocalRef(jarray);
      if (has_exception == JNI_TRUE) {
        if (!reuse_jbuf_key) {
          DeleteBuffer(env, j_key_buf);
        }
        MaybeUnlockForReuse(mtx_short, reuse_jbuf_key);
        env->ExceptionDescribe();  // print out exception to stderr
        releaseJniEnv(attached_thread);
        return;
      }
    }
  }

  if (!reuse_jbuf_key) {
    DeleteBuffer(env, j_key_buf);
  }

  MaybeUnlockForReuse(mtx_short, reuse_jbuf_key);

  releaseJniEnv(attached_thread);
}

inline void ComparatorJniCallback::MaybeLockForReuse(
    const std::unique_ptr<port::Mutex>& mutex, const bool cond) const {
  // no need to lock if using thread_local
  if (m_options->reused_synchronisation_type != ReusedSynchronisationType::THREAD_LOCAL
      && cond) {
    mutex.get()->Lock();
  }
}

inline void ComparatorJniCallback::MaybeUnlockForReuse(
    const std::unique_ptr<port::Mutex>& mutex, const bool cond) const {
  // no need to unlock if using thread_local
  if (m_options->reused_synchronisation_type != ReusedSynchronisationType::THREAD_LOCAL
      && cond) {
    mutex.get()->Unlock();
  }
}

jobject ComparatorJniCallback::GetBuffer(JNIEnv* env, const Slice& src,
    bool reuse_buffer, ThreadLocalPtr* tl_buf, jobject jreuse_buffer) const {
  if (reuse_buffer) {
    if (m_options->reused_synchronisation_type
        == ReusedSynchronisationType::THREAD_LOCAL) {

      // reuse thread-local bufffer
      ThreadLocalBuf* tlb = reinterpret_cast<ThreadLocalBuf*>(tl_buf->Get());
      if (tlb == nullptr) {
        // thread-local buffer has not yet been created, so create it
        jobject jtl_buf = env->NewGlobalRef(ByteBufferJni::construct(
            env, m_options->direct_buffer, m_options->max_reused_buffer_size,
            m_jbytebuffer_clazz));
        if (jtl_buf == nullptr) {
          // exception thrown: OutOfMemoryError
          return nullptr;
        }
        tlb = new ThreadLocalBuf(m_jvm, m_options->direct_buffer, jtl_buf);
        tl_buf->Reset(tlb);
      }
      return ReuseBuffer(env, src, tlb->jbuf);
    } else {

      // reuse class member buffer
      return ReuseBuffer(env, src, jreuse_buffer);
    }
  } else {

    // new buffer
    return NewBuffer(env, src);
  }
}

jobject ComparatorJniCallback::ReuseBuffer(
    JNIEnv* env, const Slice& src, jobject jreuse_buffer) const {
  // we can reuse the buffer
  if (m_options->direct_buffer) {
    // copy into direct buffer
    void* buf = env->GetDirectBufferAddress(jreuse_buffer);
    if (buf == nullptr) {
      // either memory region is undefined, given object is not a direct java.nio.Buffer, or JNI access to direct buffers is not supported by this virtual machine.
      ROCKSDB_NAMESPACE::RocksDBExceptionJni::ThrowNew(
          env, "Unable to get Direct Buffer Address");
      return nullptr;
    }
    memcpy(buf, src.data(), src.size());
  } else {
    // copy into non-direct buffer
    const jbyteArray jarray = ByteBufferJni::array(env, jreuse_buffer,
        m_jbytebuffer_clazz);
    if (jarray == nullptr) {
      // exception occurred
      return nullptr;
    }
    env->SetByteArrayRegion(jarray, 0, static_cast<jsize>(src.size()),
        const_cast<jbyte*>(reinterpret_cast<const jbyte*>(src.data())));
    if (env->ExceptionCheck()) {
      // exception occurred
      env->DeleteLocalRef(jarray);
      return nullptr;
    }
    env->DeleteLocalRef(jarray);
  }
  return jreuse_buffer;
}

jobject ComparatorJniCallback::NewBuffer(JNIEnv* env, const Slice& src) const {
  // we need a new buffer
  jobject jbuf = ByteBufferJni::constructWith(env, m_options->direct_buffer,
      src.data(), src.size(), m_jbytebuffer_clazz);
  if (jbuf == nullptr) {
    // exception occurred
    return nullptr;
  }
  return jbuf;
}

void ComparatorJniCallback::DeleteBuffer(JNIEnv* env, jobject jbuffer) const {
  env->DeleteLocalRef(jbuffer);
}

}  // namespace ROCKSDB_NAMESPACE