// Copyright 2010-2012, Google Inc.
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "converter/nbest_generator.h"
#include <algorithm>
#include <string>
#include "base/base.h"
#include "base/logging.h"
#include "converter/candidate_filter.h"
#include "converter/connector_interface.h"
#include "converter/lattice.h"
#include "converter/segmenter_interface.h"
#include "converter/segments.h"
#include "dictionary/pos_matcher.h"
namespace mozc {
namespace {
// Passed to the |freelist_| constructor (presumably the FreeList chunk
// size -- confirm in FreeList) and used as the capacity reserved for
// |agenda_| in the NBestGenerator constructor.
const int kFreeListSize = 512;
// Maximum cost gap for admitting alternative edge nodes (subject to further
// conditions at each use site): Reset() seeds the agenda with right-edge
// nodes whose cost is at most kCostDiff above |end_node_|'s, and Next() only
// expands to left-edge nodes whose cost is at most kCostDiff below
// |begin_node_|'s. The value equals kWeakConnectedPenalty in Next(), which
// is annotated there as the log prob of 1/1000.
const int kCostDiff = 3453;
}  // namespace
using converter::CandidateFilter;
// A partial path in the right-to-left A* search performed by Next().
// Elements are allocated from |freelist_|. Following |next| walks the path
// from its leftmost node towards the right end of the span.
struct NBestGenerator::QueueElement {
  // Leftmost node of this partial path.
  const Node *node;
  // Element holding the node immediately to the right of |node|; NULL for
  // the seed elements created in Reset().
  const QueueElement *next;
  int32 fx;  // f(x) = h(x) + g(x): cost function for A* search
  int32 gx;  // g(x)
  // transition cost part of g(x).
  // Do not take the transition costs to edge nodes.
  int32 structure_gx;
  // Accumulated word costs plus internal transition costs of the path,
  // i.e. the |wcost| described in Next().
  int32 w_gx;
};
// Heap ordering used by Agenda. With the std heap algorithms, an element
// for which this comparator is never true against the others sits at the
// front, so the element with the smallest fx (most promising path) is on top.
struct NBestGenerator::QueueElementComparator {
  bool operator()(const NBestGenerator::QueueElement *lhs,
                  const NBestGenerator::QueueElement *rhs) const {
    return rhs->fx < lhs->fx;
  }
};
// Inserts |element| into the agenda, keeping |priority_queue_| a heap under
// QueueElementComparator (smallest fx at the front).
inline void NBestGenerator::Agenda::Push(
    const NBestGenerator::QueueElement *element) {
  // Append at the back, then sift the new entry up into heap position.
  priority_queue_.push_back(element);
  QueueElementComparator by_fx;
  std::push_heap(priority_queue_.begin(), priority_queue_.end(), by_fx);
}
inline void NBestGenerator::Agenda::Pop() {
DCHECK(!priority_queue_.empty());
pop_heap(priority_queue_.begin(), priority_queue_.end(),
QueueElementComparator());
priority_queue_.pop_back();
}
// Constructs a generator over |lattice|.
//
// |suppression_dic| and |pos_matcher| are also handed to the internal
// CandidateFilter; |pos_matcher| is additionally dereferenced by
// MakeCandidate(), |segmenter| by Next() and |connector| by
// GetTransitionCost(), so all of them are required to be non-NULL.
// A NULL or empty |lattice| is tolerated here (only logged); Reset() and
// Next() re-check it before use.
// |is_prediction| is forwarded to SegmenterInterface::IsBoundary().
NBestGenerator::NBestGenerator(const SuppressionDictionary *suppression_dic,
                               const SegmenterInterface *segmenter,
                               const ConnectorInterface *connector,
                               const POSMatcher *pos_matcher,
                               const Lattice *lattice,
                               bool is_prediction)
    : suppression_dictionary_(suppression_dic),
      segmenter_(segmenter), connector_(connector), pos_matcher_(pos_matcher),
      lattice_(lattice),
      begin_node_(NULL), end_node_(NULL),
      freelist_(kFreeListSize),
      filter_(new CandidateFilter(suppression_dic, pos_matcher)),
      viterbi_result_checked_(false),
      is_prediction_(is_prediction) {
  DCHECK(suppression_dictionary_);
  DCHECK(segmenter);
  DCHECK(connector);
  // Dereferenced unconditionally in MakeCandidate().
  DCHECK(pos_matcher);
  if (lattice_ == NULL || !lattice_->has_lattice()) {
    LOG(ERROR) << "lattice is not available";
    return;
  }
  agenda_.Reserve(kFreeListSize);
}
// Nothing to release explicitly; members clean up after themselves.
// NOTE(review): |filter_| is created with new in the constructor, so it is
// presumably a scoped_ptr -- verify in the header.
NBestGenerator::~NBestGenerator() {}
void NBestGenerator::Reset(const Node *begin_node, const Node *end_node) {
agenda_.Clear();
freelist_.Free();
filter_->Reset();
viterbi_result_checked_ = false;
begin_node_ = begin_node;
end_node_ = end_node;
for (Node *node = lattice_->begin_nodes(end_node_->begin_pos);
node != NULL; node = node->bnext) {
if (node == end_node_ ||
(node->lid != end_node_->lid &&
node->cost - end_node_->cost <= kCostDiff &&
node->prev != end_node_->prev)) {
QueueElement *eos = freelist_.Alloc();
DCHECK(eos);
eos->node = node;
eos->next = NULL;
eos->fx = node->cost;
eos->gx = 0;
eos->structure_gx = 0;
eos->w_gx = 0;
agenda_.Push(eos);
}
}
}
// Builds |candidate| from |nodes| (ordered left to right) and the given
// costs. |nodes| must not be empty.
//
// - key/value: concatenation of every node's key/value.
// - content_key/content_value: concatenation over the leading run of nodes
//   up to (excluding) the first node whose lid |pos_matcher_| reports as
//   functional; if either ends up empty, both fall back to key/value.
// - lid/rid come from the first/last node.
// - CONTEXT_SENSITIVE is set if any node has a constrained_prev or is the
//   constrained_prev of its next node; SPELLING_CORRECTION,
//   NO_VARIANTS_EXPANSION and USER_DICTIONARY are copied from node
//   attributes.
void NBestGenerator::MakeCandidate(Segment::Candidate *candidate,
                                   int32 cost, int32 structure_cost,
                                   int32 wcost,
                                   const vector<const Node *> &nodes) const {
  CHECK(!nodes.empty());

  candidate->Init();
  candidate->lid = nodes.front()->lid;
  candidate->rid = nodes.back()->rid;
  candidate->cost = cost;
  candidate->structure_cost = structure_cost;
  candidate->wcost = wcost;

  // True while we are still inside the leading non-functional run.
  bool in_content_prefix = true;
  for (vector<const Node *>::const_iterator it = nodes.begin();
       it != nodes.end(); ++it) {
    const Node *current = *it;
    DCHECK(current != NULL);

    if (in_content_prefix) {
      if (pos_matcher_->IsFunctional(current->lid)) {
        in_content_prefix = false;
      } else {
        candidate->content_value += current->value;
        candidate->content_key += current->key;
      }
    }

    candidate->key += current->key;
    candidate->value += current->value;

    // A constrained node comes from (a) a compound node resegmented via
    // personal name resegmentation or (b) compound-based reranking.
    const bool touches_constraint =
        current->constrained_prev != NULL ||
        (current->next != NULL && current->next->constrained_prev == current);
    if (touches_constraint) {
      candidate->attributes |= Segment::Candidate::CONTEXT_SENSITIVE;
    }
    if (current->attributes & Node::SPELLING_CORRECTION) {
      candidate->attributes |= Segment::Candidate::SPELLING_CORRECTION;
    }
    if (current->attributes & Node::NO_VARIANTS_EXPANSION) {
      candidate->attributes |= Segment::Candidate::NO_VARIANTS_EXPANSION;
    }
    if (current->attributes & Node::USER_DICTIONARY) {
      candidate->attributes |= Segment::Candidate::USER_DICTIONARY;
    }
  }

  const bool lacks_content =
      candidate->content_value.empty() || candidate->content_key.empty();
  if (lacks_content) {
    candidate->content_value = candidate->value;
    candidate->content_key = candidate->key;
  }
}
// Produces the next candidate for the span configured by Reset().
//
// On the first call after Reset(), the Viterbi-best path is offered via
// InsertTopResult() so that |filter_| is primed with it. Afterwards, partial
// paths are popped from |agenda_| in increasing order of fx and expanded
// right to left (A* search) until a complete path is accepted by |filter_|.
//
// Returns true when |candidate| has been filled with an accepted candidate.
// Returns false when the lattice is unavailable, when the filter answers
// STOP_ENUMERATION, when the agenda is exhausted, or once this call has
// popped more than KMaxTrial elements.
bool NBestGenerator::Next(Segment::Candidate *candidate,
                          Segments::RequestType request_type) {
  DCHECK(begin_node_);
  DCHECK(end_node_);
  DCHECK(candidate);

  if (lattice_ == NULL || !lattice_->has_lattice()) {
    LOG(ERROR) << "Must create lattice in advance";
    return false;
  }

  // |cost|, |structure_cost| and |wcost| are calculated as follows:
  //
  // Example:
  // |left_node| => |node1| => |node2| => |node3| => |right_node|.
  // |node1| .. |node3| form a candidate.
  //
  // cost = (left_node->cost - begin_node_->cost) +
  //        trans(left_node, node1) + node1->wcost +
  //        trans(node1, node2) + node2->wcost +
  //        trans(node2, node3) + node3->wcost +
  //        trans(node3, right_node) +
  //        (right_node->cost - end_node_->cost)
  // structure_cost = trans(node1, node2) + trans(node2, node3);
  // wcost = node1->wcost +
  //         trans(node1, node2) + node2->wcost +
  //         trans(node2, node3) + node3->wcost
  //
  // Here (left_node->cost - begin_node_->cost) and
  // (right_node->cost - end_node_->cost) act as an approximation
  // of the marginalized costs of the candidate |node1| .. |node3|.
  // "Marginalized cost" expresses how likely left_node or right_node
  // is to be selected when taking all the paths encoded in the lattice
  // into account. These approximated costs are exactly 0 on the
  // Viterbi-best path.

  // Insert the Viterbi-best result here to make sure that
  // the top result is the Viterbi-best result.
  if (!viterbi_result_checked_) {
    // Use CandidateFilter so that the filter is initialized with the
    // Viterbi-best path.
    switch (InsertTopResult(candidate, request_type)) {
      case CandidateFilter::GOOD_CANDIDATE:
        return true;
      case CandidateFilter::STOP_ENUMERATION:
        return false;
      // The Viterbi-best result was offered but rejected by the filter.
      case CandidateFilter::BAD_CANDIDATE:
      default:
        // do nothing
        break;
    }
  }

  // Upper bound on agenda pops per call, to keep enumeration cheap.
  const int KMaxTrial = 500;
  int num_trials = 0;

  while (!agenda_.IsEmpty()) {
    // |top| is the cheapest partial path; |rnode| is its leftmost node.
    const QueueElement *top = agenda_.Top();
    DCHECK(top);
    agenda_.Pop();
    const Node *rnode = top->node;
    CHECK(rnode);

    if (num_trials++ > KMaxTrial) {  // too many trials
      VLOG(2) << "too many trials: " << num_trials;
      return false;
    }

    // Reached the goal: |rnode| ends exactly where |begin_node_| ends, so it
    // is the left-edge node outside the span.
    if (rnode->end_pos == begin_node_->end_pos) {
      // Collect the span's nodes left to right, skipping |top| (left edge)
      // and the last element (the right-edge seed from Reset()).
      nodes_.clear();
      for (const QueueElement *elm = top->next;
           elm->next != NULL; elm = elm->next) {
        nodes_.push_back(elm->node);
      }
      CHECK(!nodes_.empty());

      MakeCandidate(candidate, top->gx, top->structure_gx, top->w_gx, nodes_);
      int filter_result = filter_->FilterCandidate(candidate, nodes_);
      nodes_.clear();
      switch (filter_result) {
        case CandidateFilter::GOOD_CANDIDATE:
          return true;
        case CandidateFilter::STOP_ENUMERATION:
          return false;
        case CandidateFilter::BAD_CANDIDATE:
        default:
          break;
          // do nothing
      }
    } else {
      // Expand |rnode| to every admissible node ending where it begins.
      const QueueElement *best_left_elm = NULL;
      const bool is_right_edge = rnode->begin_pos == end_node_->begin_pos;
      const bool is_left_edge = rnode->begin_pos == begin_node_->end_pos;
      DCHECK(!(is_right_edge && is_left_edge));

      // is_edge is true if the current lnode/rnode boundary coincides with
      // the begin/end node boundary, regardless of the node values.
      const bool is_edge = (is_right_edge || is_left_edge);

      for (Node *lnode = lattice_->end_nodes(rnode->begin_pos);
           lnode != NULL; lnode = lnode->enext) {
        // is_valid_position is false if lnode's location is invalid:
        // 1. |<-- begin_node_-->|
        //                 |<--lnode-->|   <== overlapped.
        //
        // 2. |<-- begin_node_-->|
        //      |<--lnode-->|              <== exceeds begin_node.
        // Case 2 cannot happen because |rnode| is always immediately to the
        // right of |lnode|, so excluding case 1 is sufficient:
        // 2'. |<-- begin_node_-->|
        //       |<--lnode-->||<--rnode-->|
        const bool is_valid_position =
            !((lnode->begin_pos < begin_node_->end_pos &&
               begin_node_->end_pos < lnode->end_pos));
        if (!is_valid_position) {
          continue;
        }

        // is_valid_cost is true if the left node is acceptable in terms of
        // cost. Only a left-edge lnode is subject to a cost-based
        // constraint (at most kCostDiff below begin_node_'s cost).
        const bool is_valid_cost =
            (!is_left_edge ||
             (is_left_edge && (begin_node_->cost - lnode->cost) <= kCostDiff));
        if (!is_valid_cost) {
          continue;
        }

        // Boundary constraints are skipped when either side is a CON_NODE or
        // |rnode| is WEAK_CONNECTED.
        if (!(rnode->node_type == Node::CON_NODE ||
              (rnode->attributes & Node::WEAK_CONNECTED) ||
              lnode->node_type == Node::CON_NODE)) {
          // is_boundary is true if there is a grammar-based boundary
          // between lnode and rnode.
          const bool is_boundary = (lnode->node_type == Node::HIS_NODE ||
                                    segmenter_->IsBoundary(lnode, rnode,
                                                           is_prediction_));
          if (is_edge != is_boundary) {
            // On the edge there must be a boundary;
            // off the edge there must not be one.
            continue;
          }
        }

        // We can expand candidates from |rnode| to |lnode|.
        const int transition_cost = GetTransitionCost(lnode, rnode);

        // How much the costs increase by expanding rnode.
        int cost_diff = 0;
        int structure_cost_diff = 0;
        int wcost_diff = 0;

        if (is_right_edge) {
          // Use |rnode->cost - end_node_->cost| as an approximation
          // of the marginalized word cost.
          cost_diff = transition_cost + (rnode->cost - end_node_->cost);
          structure_cost_diff = 0;
          wcost_diff = 0;
        } else if (is_left_edge) {
          // Use |lnode->cost - begin_node_->cost| as an approximation
          // of the marginalized word cost.
          cost_diff = (lnode->cost - begin_node_->cost) +
              transition_cost + rnode->wcost;
          structure_cost_diff = 0;
          wcost_diff = rnode->wcost;
        } else {
          // Use rnode->wcost.
          cost_diff = transition_cost + rnode->wcost;
          structure_cost_diff = transition_cost;
          wcost_diff = transition_cost + rnode->wcost;
        }

        if (rnode->attributes & Node::WEAK_CONNECTED) {
          const int kWeakConnectedPenalty = 3453;   // log prob of 1/1000
          cost_diff += kWeakConnectedPenalty;
          structure_cost_diff += kWeakConnectedPenalty / 2;
          wcost_diff += kWeakConnectedPenalty / 2;
        }

        QueueElement *elm = freelist_.Alloc();
        DCHECK(elm);

        elm->node = lnode;
        elm->gx = cost_diff + top->gx;
        elm->structure_gx = structure_cost_diff + top->structure_gx;
        elm->w_gx = wcost_diff + top->w_gx;

        // |lnode->cost| is the heuristic function of the A* search, h(x).
        // After the Viterbi search, we already know the exact value of h(x).
        elm->fx = lnode->cost + elm->gx;
        elm->next = top;

        if (is_left_edge) {
          // We only need one left node here.
          // Even if all left nodes were expanded, the |value| part would be
          // identical. Here, we simply use the best left-edge node.
          // This hack reduces the number of redundant calls of pop().
          if (best_left_elm == NULL || best_left_elm->fx > elm->fx) {
            best_left_elm = elm;
          }
        } else {
          agenda_.Push(elm);
        }
      }

      if (best_left_elm != NULL) {
        agenda_.Push(best_left_elm);
      }
    }
  }

  return false;
}
// Fills |candidate| with the Viterbi-best path between |begin_node_| and
// |end_node_| and runs it through |filter_|, which also primes the filter
// with this path. The path consists of the nodes reached by following
// |next| from begin_node_->next up to, but excluding, end_node_.
//
// The costs follow the definitions documented in Next(), with the
// marginalized-cost terms being zero on the Viterbi path. They are derived
// from the cumulative |cost| of the path nodes. This assumes Node::cost is
// the Viterbi cumulative cost up to and including the node, which is the
// same assumption Next() relies on for its A* heuristic:
//   cost           = end_node_->cost - begin_node_->cost - end_node_->wcost
//   structure_cost = sum of transitions between consecutive path nodes
//   wcost          = word costs of all path nodes + those transitions
// If |request_type| is SUGGESTION, REALTIME_CONVERSION is set on the
// candidate.
//
// Marks the Viterbi result as checked, so it is attempted only once per
// Reset(). Returns the CandidateFilter result. Returns BAD_CANDIDATE without
// touching |candidate| if there are no nodes between begin_node_ and
// end_node_.
int NBestGenerator::InsertTopResult(Segment::Candidate *candidate,
                                    Segments::RequestType request_type) {
  nodes_.clear();
  // Sum of word costs of every path node except the first one.
  int total_wcost = 0;
  for (const Node *node = begin_node_->next;
       node != end_node_; node = node->next) {
    nodes_.push_back(node);
    if (node != begin_node_->next) {
      total_wcost += node->wcost;
    }
  }

  if (nodes_.empty()) {
    // begin_node_ is directly followed by end_node_. The cost expressions
    // below would read the wrong nodes, and MakeCandidate() would CHECK-fail
    // on an empty node list, so reject instead.
    LOG(ERROR) << "Viterbi path between begin_node and end_node is empty";
    viterbi_result_checked_ = true;
    return CandidateFilter::BAD_CANDIDATE;
  }

  const int cost = end_node_->cost -
      begin_node_->cost - end_node_->wcost;
  const int structure_cost = end_node_->prev->cost -
      begin_node_->next->cost - total_wcost;
  const int wcost = end_node_->prev->cost -
      begin_node_->next->cost + begin_node_->next->wcost;

  MakeCandidate(candidate, cost, structure_cost, wcost, nodes_);

  if (request_type == Segments::SUGGESTION) {
    candidate->attributes |= Segment::Candidate::REALTIME_CONVERSION;
  }

  viterbi_result_checked_ = true;
  int result = filter_->FilterCandidate(candidate, nodes_);
  nodes_.clear();
  return result;
}
int NBestGenerator::GetTransitionCost(const Node *lnode,
const Node *rnode) const {
const int kInvalidPenaltyCost = 100000;
if (rnode->constrained_prev != NULL && lnode != rnode->constrained_prev) {
return kInvalidPenaltyCost;
}
return connector_->GetTransitionCost(lnode->rid, rnode->lid);
}
} // namespace mozc