#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"  // from @com_google_absl
#include "absl/status/status.h"  // from @com_google_absl
#include "absl/strings/str_format.h"  // from @com_google_absl
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/base_task_api.h"
#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
#include "tensorflow_lite_support/cc/task/processor/proto/embedding.pb.h"
#include "tensorflow_lite_support/cc/task/text/proto/retrieval.pb.h"

namespace tflite {
namespace task {
namespace text {

// Forward declarations of the output and input types that
// UniversalSentenceEncoderQA plugs into BaseTaskApi; only incomplete types are
// needed at this point.
namespace internal {
struct QAOutput;
struct QAInput;
}  // namespace internal

// Universal Sentence Encoder (USE) Question Answerer. The model uses USE as its
// backbone to answer a question.
// The API expects a USE QA model with the following input and output tensor
// names:
//                    Metadata tensor name   |   Model tensor name
// Input tensors:                            |
//   - Query text         "inp_text"         | "ParseExample/ParseExampleV2:1"
//   - Response context   "res_context"      | "ParseExample/ParseExampleV2:2"
//   - Response text      "res_text"         | "ParseExample/ParseExampleV2:3"
// Output tensors:                           |
//   - Query encoding     "query_encoding"   | "Final/EncodeQuery/mul"
//   - Response encoding  "response_encoding"| "Final/EncodeResult/mul"
// Tensors are matched by first checking the metadata tensor name and then the
// model tensor name. If no matching tensor name is found, the first three
// input tensors are used for the query text, response context, and response
// text, respectively; the first two output tensors are used for the query
// encoding and response encoding, respectively. Other input or output tensors
// are ignored by `UniversalSentenceEncoderQA`.

class UniversalSentenceEncoderQA
    : public core::BaseTaskApi<internal::QAOutput, const internal::QAInput&> {
 public:
  using BaseTaskApi::BaseTaskApi;
  using FeatureVector = ::tflite::task::processor::FeatureVector;

  // Hard-coded size of the final embedding.
  // TODO(b/198995952): add support to parameterize.
  static constexpr int kFinalEmbeddingSize = 100;

  // Creates a UniversalSentenceEncoderQA from `options`.
  //
  // `resolver` supplies the op implementations used to run the model; it
  // defaults to the TFLite builtin op resolver. Returns an error status if the
  // task cannot be created.
  static tflite::support::StatusOr<std::unique_ptr<UniversalSentenceEncoderQA>>
  CreateFromOption(
      const tflite::task::text::RetrievalOptions& options,
      std::unique_ptr<tflite::OpResolver> resolver =
          std::make_unique<tflite::ops::builtin::BuiltinOpResolver>());

  // Retrieves output from the input by running the TFLite engine.
  // Returns an error if either query_text or responses is empty.
  tflite::support::StatusOr<RetrievalOutput> Retrieve(
      const RetrievalInput& input);

  // Encodes a query from the text.
  // Returns an error if the query text is empty.
  tflite::support::StatusOr<FeatureVector> EncodeQuery(
      absl::string_view query_text);

  // Encodes a response from the text and/or context.
  // Returns an error if both text and context are empty.
  tflite::support::StatusOr<FeatureVector> EncodeResponse(
      absl::string_view response_text, absl::string_view response_context);

  // Calculates the similarity between two encoded vectors (which must have the
  // same size).
  static tflite::support::StatusOr<float> Similarity(const FeatureVector& a,
                                                     const FeatureVector& b);

  // Gets the indices of the top k response scores of `output`, in descending
  // score order. If k == 0, all responses are ranked.
  static std::vector<size_t> Top(const RetrievalOutput& output, size_t k = 0);

 private:
  // Initializes the task from `options`.
  // NOTE(review): presumably stores `options` in `options_` and resolves the
  // tensor indices below; confirm against the implementation.
  absl::Status Init(std::unique_ptr<RetrievalOptions> options);

  // BaseTaskApi hook: populates `input_tensors` from `input`.
  absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
                          const internal::QAInput& input) override;

  // BaseTaskApi hook: builds a QAOutput from `output_tensors`.
  tflite::support::StatusOr<internal::QAOutput> Postprocess(
      const std::vector<const TfLiteTensor*>& output_tensors,
      const internal::QAInput& input) override;

  // Runs the model on one query text and one response (text and context).
  internal::QAOutput Run(absl::string_view query_text,
                         absl::string_view response_text,
                         absl::string_view response_context);

  // Options this task was initialized with (presumably set by Init).
  std::unique_ptr<tflite::task::text::RetrievalOptions> options_;

  // The input tensor indices corresponding to the query text tensor, the
  // response context tensor, and the response text tensor, respectively.
  std::vector<int> input_indices_;

  // The output tensor indices corresponding to the query encoding tensor and
  // the response encoding tensor, respectively.
  std::vector<int> output_indices_;
};
}  // namespace text
}  // namespace task
}  // namespace tflite