// Copyright 2010-2011, Google Inc.
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "converter/nbest_generator.h"
#include <string>
#include "base/base.h"
#include "base/singleton.h"
#include "converter/candidate_filter.h"
#include "converter/connector_interface.h"
#include "converter/lattice.h"
#include "converter/segmenter_interface.h"
#include "converter/segmenter.h"
#include "converter/segments.h"
#include "dictionary/pos_matcher.h"
namespace mozc {
const int kFreeListSize = 512;
const int kCostDiff = 3453;
NBestGenerator::NBestGenerator()
: freelist_(kFreeListSize), filter_(NULL),
begin_node_(NULL), end_node_(NULL),
connector_(ConnectorFactory::GetConnector()),
segmenter_(Singleton<Segmenter>::get()),
lattice_(NULL),
viterbi_result_checked_(false),
is_prediction_(false) {}
NBestGenerator::NBestGenerator(const SegmenterInterface *segmenter)
: freelist_(kFreeListSize), filter_(NULL),
begin_node_(NULL), end_node_(NULL),
connector_(ConnectorFactory::GetConnector()),
segmenter_(segmenter),
lattice_(NULL),
viterbi_result_checked_(false),
is_prediction_(false) {}
NBestGenerator::~NBestGenerator() {}
void NBestGenerator::Init(const Node *begin_node, const Node *end_node,
const Lattice *lattice,
bool is_prediction) {
Reset();
begin_node_ = begin_node;
end_node_ = end_node;
lattice_ = lattice;
is_prediction_ = is_prediction;
if (lattice_ == NULL || !lattice_->has_lattice()) {
LOG(ERROR) << "lattice is not available";
return;
}
for (Node *node = lattice_->begin_nodes(end_node_->begin_pos);
node != NULL; node = node->bnext) {
if (node == end_node_ ||
(node->lid != end_node_->lid &&
node->cost - end_node_->cost <= kCostDiff &&
node->prev != end_node_->prev)) {
QueueElement *eos = freelist_.Alloc();
DCHECK(eos);
eos->node = node;
eos->next = NULL;
eos->fx = node->cost;
eos->gx = 0;
eos->structure_gx = 0;
eos->w_gx = 0;
agenda_->push(eos);
}
}
}
void NBestGenerator::Reset() {
agenda_.reset(new Agenda);
filter_.reset(new CandidateFilter);
freelist_.Free();
viterbi_result_checked_ = false;
is_prediction_ = false;
}
void NBestGenerator::MakeCandidate(Segment::Candidate *candidate,
int32 cost, int32 structure_cost,
int32 wcost,
const vector<const Node *> nodes) const {
CHECK(!nodes.empty());
bool has_constrained_node = false;
bool is_functional = false;
candidate->Init();
candidate->lid = nodes.front()->lid;
candidate->rid = nodes.back()->rid;
candidate->cost = cost;
candidate->structure_cost = structure_cost;
candidate->wcost = wcost;
for (size_t i = 0; i < nodes.size(); ++i) {
const Node *node = nodes[i];
DCHECK(node != NULL);
if (node->constrained_prev != NULL ||
(node->next != NULL &&
node->next->constrained_prev == node)) {
has_constrained_node = true;
}
if (!is_functional && !POSMatcher::IsFunctional(node->lid)) {
candidate->content_value += node->value;
candidate->content_key += node->key;
} else {
is_functional = true;
}
candidate->key += node->key;
candidate->value += node->value;
if (node->attributes & Node::SPELLING_CORRECTION) {
candidate->attributes |= Segment::Candidate::SPELLING_CORRECTION;
}
if (node->attributes & Node::NO_VARIANTS_EXPANSION) {
candidate->attributes |= Segment::Candidate::NO_VARIANTS_EXPANSION;
}
if (node->attributes & Node::USER_DICTIONARY) {
candidate->attributes |= Segment::Candidate::USER_DICTIONARY;
}
}
if (candidate->content_value.empty() ||
candidate->content_key.empty()) {
candidate->content_value = candidate->value;
candidate->content_key = candidate->key;
}
// If result has constrained_node, set CONTEXT_SENSITIVE.
// If a node has constrained node, the node is generated by
// a) compound node and resegmented via personal name resegmentation
// b) compound-based reranking.
if (has_constrained_node) {
candidate->attributes |= Segment::Candidate::CONTEXT_SENSITIVE;
}
}
bool NBestGenerator::Next(Segment::Candidate *candidate,
Segments::RequestType request_type) {
if (lattice_ == NULL || !lattice_->has_lattice()) {
LOG(ERROR) << "Must create lattice in advance";
return false;
}
// |cost| and |structure_cost| are calculated as follows:
//
// Example:
// |left_node| => |node1| => |node2| => |node3| => |right_node|.
// |node1| .. |node2| consists of a candidate.
//
// cost = (left_node->cost - begin_node_->cost) +
// trans(left_node, node1) + node1->wcost +
// trans(node1, node2) + node2->wcost +
// trans(node2, node3) + node3->wcost +
// trans(node3, rigt_node) +
// (right_node->cost - end_node_->cost)
// structure_cost = trans(node1, node2) + trans(node2, node3);
// wcost = node1->wcost +
// trans(node1, node2) + node2->wcost +
// trans(node2, node3) + node3->wcost
//
// Here (left_node->cost - begin_node_->cost) and
// (right_node->cost - end_node->cost) act as an approximation
// of marginalized costs of the candidate |node1| .. |node3|.
// "marginalized cost" means that how likely the left_node or right_node
// are selected by taking the all paths encoded in the lattice.
// These approximated costs are exactly 0 when taking Viterbi-best
// path.
// Insert Viterbi best result here to make sure that
// the top result is Viterbi best result.
if (!viterbi_result_checked_) {
vector<const Node *> nodes;
int total_wcost = 0;
for (const Node *node = begin_node_->next;
node != end_node_; node = node->next) {
nodes.push_back(node);
if (node != begin_node_->next) {
total_wcost += node->wcost;
}
}
DCHECK(!nodes.empty());
const int cost = end_node_->cost -
begin_node_->cost - end_node_->wcost;
const int structure_cost = end_node_->prev->cost -
begin_node_->next->cost - total_wcost;
const int wcost = end_node_->prev->cost -
begin_node_->next->cost + begin_node_->next->wcost;
MakeCandidate(candidate, cost, structure_cost, wcost, nodes);
if (request_type == Segments::SUGGESTION) {
candidate->attributes |= Segment::Candidate::REALTIME_CONVERSION;
}
// Use CandiadteFilter so that filter is initialized with the
// Viterbi-best path.
viterbi_result_checked_ = true;
switch (filter_->FilterCandidate(candidate, nodes)) {
case CandidateFilter::GOOD_CANDIDATE:
return true;
case CandidateFilter::STOP_ENUMERATION:
return false;
// Viterbi best result was tried to be inserted but reverted.
case CandidateFilter::BAD_CANDIDATE:
default:
// do nothing
break;
}
}
const int KMaxTrial = 500;
int num_trials = 0;
while (!agenda_->empty()) {
const QueueElement *top = agenda_->top();
DCHECK(top);
agenda_->pop();
const Node *rnode = top->node;
CHECK(rnode);
if (num_trials++ > KMaxTrial) { // too many trials
VLOG(2) << "too many trials: " << num_trials;
return false;
}
// reached to the goal.
if (rnode->end_pos == begin_node_->end_pos) {
vector<const Node *> nodes;
for (const QueueElement *elm = top->next;
elm->next != NULL; elm = elm->next) {
nodes.push_back(elm->node);
}
CHECK(!nodes.empty());
MakeCandidate(candidate, top->gx, top->structure_gx, top->w_gx, nodes);
switch (filter_->FilterCandidate(candidate, nodes)) {
case CandidateFilter::GOOD_CANDIDATE:
return true;
case CandidateFilter::STOP_ENUMERATION:
return false;
case CandidateFilter::BAD_CANDIDATE:
default:
break;
// do nothing
}
} else {
const QueueElement *best_left_elm = NULL;
const bool is_right_edge = rnode->begin_pos == end_node_->begin_pos;
const bool is_left_edge = rnode->begin_pos == begin_node_->end_pos;
for (Node *lnode = lattice_->end_nodes(rnode->begin_pos);
lnode != NULL; lnode = lnode->enext) {
// is_edge is true if current lnode/rnode has same boundary as
// begin/end node regardless of its value.
DCHECK(!(is_right_edge && is_left_edge));
const bool is_edge = (is_right_edge || is_left_edge);
// is_boundary is true if there is a grammer-based boundary
// between lnode and rnode
const bool is_boundary = (lnode->node_type == Node::HIS_NODE ||
segmenter_->IsBoundary(lnode, rnode,
is_prediction_));
// is_valid_boudnary is true if the word connection from
// lnode to rnode has a gramatically correct relation.
const bool is_valid_boundary =
(lnode->node_type == Node::CON_NODE ||
rnode->node_type == Node::CON_NODE ||
(rnode->attributes & Node::WEAK_CONNECTED) ||
(is_edge && is_boundary) || // on the edge, have a boudnary.
(!is_boundary && !is_edge)); // not on the edge, not the case.
// is_valid_cost is true if the left node is valid
// in terms of cost. if left_node is left edge, there
// is a cost-based constraint.
const bool is_valid_cost =
(!is_left_edge ||
(is_left_edge && (begin_node_->cost - lnode->cost) <= kCostDiff));
// is_invalid_position is true if the lnode's location is invalid
// 1. |<-- begin_node_-->|
// |<--lnode-->| <== exceeds begin_node.
// 2. |<-- begin_node_-->|
// |<--lnode-->| <== overlapped.
const bool is_valid_position =
!(lnode->end_pos < begin_node_->end_pos || // case(1)
(lnode->begin_pos < begin_node_->end_pos && // case(2)
begin_node_->end_pos < lnode->end_pos));
// can_expand_more is true if we can expand candidates from
// |rnode| to |lnode|.
const bool can_expand_more =
(is_valid_boundary && is_valid_cost && is_valid_position);
if (can_expand_more) {
const int transition_cost = GetTransitionCost(lnode, rnode);
// How likely the costs get increased after expanding rnode.
int cost_diff = 0;
int structure_cost_diff = 0;
int wcost_diff = 0;
if (is_right_edge) {
// use |rnode->cost - end_node_->cost| is an approximation
// of marginalized word cost.
cost_diff = transition_cost + (rnode->cost - end_node_->cost);
structure_cost_diff = 0;
wcost_diff = 0;
} else if (is_left_edge) {
// use |lnode->cost - begin_node_->cost| is an approximation
// of marginalized word cost.
cost_diff = (lnode->cost - begin_node_->cost) +
transition_cost + rnode->wcost;
structure_cost_diff = 0;
wcost_diff = rnode->wcost;
} else {
// use rnode->wcost.
cost_diff = transition_cost + rnode->wcost;
structure_cost_diff = transition_cost;
wcost_diff = transition_cost + rnode->wcost;
}
if (rnode->attributes & Node::WEAK_CONNECTED) {
const int kWeakConnectedPenalty = 3453; // log prob of 1/1000
cost_diff += kWeakConnectedPenalty;
structure_cost_diff += kWeakConnectedPenalty / 2;
wcost_diff += kWeakConnectedPenalty / 2;
}
QueueElement *elm = freelist_.Alloc();
DCHECK(elm);
elm->node = lnode;
elm->gx = cost_diff + top->gx;
elm->structure_gx = structure_cost_diff + top->structure_gx;
elm->w_gx = wcost_diff + top->w_gx;
// |lnode->cost| is heuristics function of A* search, h(x).
// After Viterbi search, we already know an exact value of h(x).
elm->fx = lnode->cost + elm->gx;
elm->next = top;
if (is_left_edge) {
// We only need to only 1 left node here.
// Even if expand all left nodes, all the |value| part should
// be identical. Here, we simply user the best left edge node.
// This hack reduces the number of redundant calls of pop().
if (best_left_elm == NULL || best_left_elm->fx > elm->fx) {
best_left_elm = elm;
}
} else {
agenda_->push(elm);
}
}
}
if (best_left_elm != NULL) {
agenda_->push(best_left_elm);
}
}
}
return false;
}
int NBestGenerator::GetTransitionCost(const Node *lnode,
const Node *rnode) const {
const int kInvalidPenaltyCost = 100000;
if (rnode->constrained_prev != NULL && lnode != rnode->constrained_prev) {
return kInvalidPenaltyCost;
}
return connector_->GetTransitionCost(lnode->rid, rnode->lid);
}
} // namespace mozc