Codebase list mozc / 7306a9d client / client_quality_test_main.cc
7306a9d

Tree @7306a9d (Download .tar.gz)

client_quality_test_main.cc @7306a9draw · history · blame

// Copyright 2010-2012, Google Inc.
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
//     * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
//     * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include <iostream>
#include <map>
#include <numeric>  // accumulate
#include <string>
#include <vector>

#include "base/base.h"
#include "base/file_stream.h"
#include "base/flags.h"
#include "base/multifile.h"
#include "base/util.h"
#include "client/client.h"
#include "evaluation/scorer.h"
#include "session/commands.pb.h"

// Test data automatically generated by gen_client_quality_test_data.py
// TestCase test_cases[] is defined.
#include "client/client_quality_test_data.h"

// Server program handed to the client via set_server_program(); when empty,
// main() does not call set_server_program() at all.
DEFINE_string(server_path, "", "specify server path");
// File that receives the per-source mean scores; when empty, main() writes
// them to stdout instead.
DEFINE_string(log_path, "", "specify log output file path");
// Upper bound on the number of scores recorded per test source; main() skips
// the remaining test cases of a source once this many scores are stored.
DEFINE_int32(max_case_for_source, 500,
             "specify max test case number for each test sources");

namespace mozc {
bool IsValidSourceSentence(const string &str) {
  // TODO(noriyukit) Treat alphabets by changing to Eisu-mode
  if (Util::ContainsScriptType(str, Util::ALPHABET)) {
    LOG(WARNING) << "contains ALPHABET: " << str;
    return false;
  }

  // Source should not contain kanji
  if (Util::ContainsScriptType(str, Util::KANJI)) {
    LOG(WARNING) << "contains KANJI: " << str;
    return false;
  }

  // Source should not contain katakana
  string tmp, tmp2;
  Util::StringReplace(str, "\xE3\x83\xBC", "", true, &tmp);  // "ー" -> ""
  Util::StringReplace(tmp, "\xE3\x83\xBB", "", true, &tmp2);  // "・" -> ""
  if (Util::ContainsScriptType(tmp2, Util::KATAKANA)) {
    LOG(WARNING) << "contain KATAKANA: " << str;
    return false;
  }
  return true;
}

bool GenerateKeySequenceFrom(const string& hiragana_sentence,
                             vector<commands::KeyEvent>* keys) {
  CHECK(keys);
  keys->clear();

  string tmp, input;
  Util::HiraganaToRomanji(hiragana_sentence, &tmp);
  Util::FullWidthToHalfWidth(tmp, &input);

  const char* begin = input.c_str();
  const char* end = begin + input.size();
  while (begin < end) {
    size_t mblen = 0;
    const char32 ucs4 = Util::UTF8ToUCS4(begin, end, &mblen);
    CHECK_GT(mblen, 0);
    begin += mblen;

    // TODO(noriyukit) Improve key sequence generation; currently, a few ucs4
    // codes, like FF5E and 300E, cannot be handled.
    commands::KeyEvent key;
    if (ucs4 >= 0x20 && ucs4 <= 0x7F) {
      key.set_key_code(static_cast<int>(ucs4));
    } else if (ucs4 == 0x3001 || ucs4 == 0xFF64) {
      key.set_key_code(0x002C);  // Full-width comma -> Half-width comma
    } else if (ucs4 == 0x3002 || ucs4 == 0xFF0E || ucs4 == 0xFF61) {
      key.set_key_code(0x002E);  // Full-width period -> Half-width period
    } else if (ucs4 == 0x2212 || ucs4 == 0x2015) {
      key.set_key_code(0x002D);  // "−" -> "-"
    } else if (ucs4 == 0x300C || ucs4 == 0xff62) {
      key.set_key_code(0x005B);  // "「" -> "["
    } else if (ucs4 == 0x300D || ucs4 == 0xff63) {
      key.set_key_code(0x005D);  // "」" -> "]"
    } else if (ucs4 == 0x30FB || ucs4 == 0xFF65) {
      key.set_key_code(0x002F);  // "・" -> "/"  "・" -> "/"
    } else {
      LOG(WARNING) << "Unexpected character: " << hex << ucs4
                   << ": in " << input << " (" << hiragana_sentence << ")";
      return false;
    }
    keys->push_back(key);
  }

  // Conversion key
  {
    commands::KeyEvent key;
    key.set_special_key(commands::KeyEvent::SPACE);
    keys->push_back(key);
  }
  return true;
}

// Concatenates the values of all preedit segments of |output| into |*str|.
//
// |str| must not be NULL.  When |output| carries no preedit, a warning is
// logged, |*str| is left untouched and false is returned; otherwise |*str| is
// overwritten with the concatenation and true is returned.
bool GetPreedit(const commands::Output &output, string* str) {
  CHECK(str);

  if (output.has_preedit()) {
    str->clear();
    const int segment_count = output.preedit().segment_size();
    for (int index = 0; index < segment_count; ++index) {
      *str += output.preedit().segment(index).value();
    }
    return true;
  }

  LOG(WARNING) << "No result";
  return false;
}

// Types |hiragana_sentence| into the session behind |client|, requests
// conversion, and stores the BLEU score of the resulting preedit against
// |expected_result| in |*score|.
//
// Both the preedit and |expected_result| go through
// Scorer::NormalizeForEvaluate() before scoring.
//
// |client| and |score| must not be NULL.  Returns false, leaving |*score|
// untouched, when no key sequence can be generated for the sentence or when
// the server response carries no (or an empty) preedit.
//
// Once keys have been sent, the session is always reverted before returning,
// including on the failure path, so that composition state from one test case
// is not carried into the next one.
bool CalculateBLEU(client::Client* client,
                   const string& hiragana_sentence,
                   const string& expected_result, double* score) {
  CHECK(client);
  CHECK(score);

  // Prepare key events
  vector<commands::KeyEvent> keys;
  if (!GenerateKeySequenceFrom(hiragana_sentence, &keys)) {
    LOG(WARNING) << "Failed to generated key events from: "
               << hiragana_sentence;
    return false;
  }

  // Must send ON first
  commands::Output output;
  {
    commands::KeyEvent key;
    key.set_special_key(commands::KeyEvent::ON);
    client->SendKey(key, &output);
  }

  // Send keys
  for (size_t i = 0; i < keys.size(); ++i) {
    client->SendKey(keys[i], &output);
  }
  VLOG(2) << "Server response: " << output.Utf8DebugString();

  // Capture the conversion result before the session is reverted.
  string preedit;
  const bool has_preedit = GetPreedit(output, &preedit) && !preedit.empty();

  // Revert session to prevent server from learning this conversion.  This is
  // done unconditionally: skipping it when no preedit came back would leave
  // whatever was typed so far in the session for the next test case.  A
  // separate Output is used so that |output| keeps the conversion response.
  {
    commands::SessionCommand command;
    command.set_type(commands::SessionCommand::REVERT);
    commands::Output revert_output;
    client->SendCommand(command, &revert_output);
  }

  if (!has_preedit) {
    LOG(WARNING) << "Could not get output";
    return false;
  }

  // Calculate score
  string expected_normalized;
  Scorer::NormalizeForEvaluate(expected_result, &expected_normalized);
  vector<string> goldens;
  goldens.push_back(expected_normalized);
  string preedit_normalized;
  Scorer::NormalizeForEvaluate(preedit, &preedit_normalized);

  *score = Scorer::BLEUScore(goldens, preedit_normalized);

  VLOG(1) << hiragana_sentence << endl
          << "   score: " << (*score) << endl
          << " preedit: " << preedit_normalized << endl
          << "expected: " << expected_normalized;

  return true;
}

// Returns the arithmetic mean of |scores|.  |scores| must not be empty; an
// empty vector trips the CHECK below.
double CalculateMean(const vector<double>& scores) {
  CHECK(!scores.empty());
  // Sum front to back starting from 0.0.
  double total = 0.0;
  for (vector<double>::const_iterator it = scores.begin();
       it != scores.end(); ++it) {
    total = total + *it;
  }
  return total / static_cast<double>(scores.size());
}
}  // namespace mozc


int main(int argc, char* argv[]) {
  InitGoogle(argv[0], &argc, &argv, true);

  mozc::client::Client client;
  if (!FLAGS_server_path.empty()) {
    client.set_server_program(FLAGS_server_path);
  }

  CHECK(client.IsValidRunLevel()) << "IsValidRunLevel failed";
  CHECK(client.EnsureSession()) << "EnsureSession failed";
  CHECK(client.NoOperation()) << "Server is not respoinding";

  map<string, vector<double> > scores;    // Results to be averaged

  for (mozc::TestCase* test_case = mozc::test_cases; test_case->source != NULL;
       ++test_case) {
    const string &source = test_case->source;
    const string &hiragana_sentence = test_case->hiragana_sentence;
    const string &expected_result = test_case->expected_result;

    if (scores.find(source) == scores.end()) {
      scores[source] = vector<double>();
    }
    if (scores[source].size() >= FLAGS_max_case_for_source) {
      continue;
    }

    VLOG(1) << "Processing " << hiragana_sentence;
    if (!mozc::IsValidSourceSentence(hiragana_sentence)) {
      LOG(WARNING) << "Invalid test case: " << endl
                   << "    source: " << source << endl
                   << "  hiragana: " << hiragana_sentence << endl
                   << "  expected: " << expected_result;
      continue;
    }

    double score;
    if (!mozc::CalculateBLEU(&client, hiragana_sentence,
                             expected_result, &score)) {
      LOG(WARNING) << "Failed to calculate BLEU score: " << endl
                   << "    source: " << source << endl
                   << "  hiragana: " << hiragana_sentence << endl
                   << "  expected: " << expected_result;
      continue;
    }
    scores[source].push_back(score);
  }

  ostream *ofs = &cout;
  if (!FLAGS_log_path.empty()) {
    ofs = new mozc::OutputFileStream(FLAGS_log_path.c_str());
  }

  // Average the scores
  for (map<string, vector<double> >::iterator it = scores.begin();
       it != scores.end(); ++it) {
    const double mean = mozc::CalculateMean(it->second);
    (*ofs) << it->first << " : " << mean << endl;
  }
  if (ofs != &cout) {
    delete ofs;
  }

  return 0;
}