Skip to content

Commit

Permalink
feat(poet): find best sentence candidates
Browse files Browse the repository at this point in the history
  • Loading branch information
lotem committed Apr 20, 2019
1 parent 9934788 commit b3f4005
Showing 1 changed file with 61 additions and 23 deletions.
84 changes: 61 additions & 23 deletions src/rime/gear/poet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,43 +41,81 @@ bool Poet::LeftAssociateCompare(const Sentence& one, const Sentence& other) {
other.syllable_lengths().end()))));
}

// keep the best sentence candidate per last phrase
using SentenceCandidates = hash_map<string, of<Sentence>>;

static vector<of<Sentence>> top_candidates(const SentenceCandidates& candidates,
size_t n,
Poet::Compare& compare) {
vector<of<Sentence>> top;
top.reserve(n + 1);
for (const auto& candidate : candidates) {
auto pos = std::upper_bound(
top.begin(), top.end(), candidate.second,
[&](const an<Sentence>& a, const an<Sentence>& b) {
return !compare(*a, *b); // desc
});
if (pos - top.begin() >= n) continue;
top.insert(pos, candidate.second);
if (top.size() > n) top.pop_back();
}
return top;
}

an<Sentence> find_best_sentence(const SentenceCandidates& candidates,
Poet::Compare& compare) {
an<Sentence> best = nullptr;
for (const auto& candidate : candidates) {
if (!best || compare(*best, *candidate.second)) {
best = candidate.second;
}
}
return best;
}

constexpr int kMaxSentenceCandidates = 7;

an<Sentence> Poet::MakeSentence(const WordGraph& graph,
size_t total_length,
const string& preceding_text) {
// TODO: save more intermediate sentence candidates
map<int, an<Sentence>> sentences;
sentences[0] = New<Sentence>(language_);
// dynamic programming
map<int, SentenceCandidates> sentences;
sentences[0].emplace("", New<Sentence>(language_));
for (const auto& w : graph) {
size_t start_pos = w.first;
DLOG(INFO) << "start pos: " << start_pos;
if (sentences.find(start_pos) == sentences.end())
continue;
for (const auto& x : w.second) {
size_t end_pos = x.first;
if (start_pos == 0 && end_pos == total_length)
continue; // exclude single words from the result
DLOG(INFO) << "end pos: " << end_pos;
bool is_rear = end_pos == total_length;
const DictEntryList& entries(x.second);
for (const auto& entry : entries) {
auto new_sentence = New<Sentence>(*sentences[start_pos]);
new_sentence->Extend(
*entry, end_pos, is_rear, preceding_text, grammar_.get());
if (sentences.find(end_pos) == sentences.end() ||
compare_(*sentences[end_pos], *new_sentence)) {
DLOG(INFO) << "updated sentences " << end_pos << ") with "
<< new_sentence->text() << " weight: "
<< new_sentence->weight();
sentences[end_pos] = std::move(new_sentence);
DLOG(INFO) << "start pos: " << start_pos;
auto top = top_candidates(
sentences[start_pos], kMaxSentenceCandidates, compare_);
for (const auto& candidate : top) {
for (const auto& x : w.second) {
size_t end_pos = x.first;
if (start_pos == 0 && end_pos == total_length)
continue; // exclude single words from the result
DLOG(INFO) << "end pos: " << end_pos;
bool is_rear = end_pos == total_length;
auto& target(sentences[end_pos]);
const DictEntryList& entries(x.second);
for (const auto& entry : entries) {
auto new_sentence = New<Sentence>(*candidate);
new_sentence->Extend(
*entry, end_pos, is_rear, preceding_text, grammar_.get());
const auto& key = new_sentence->components().back().text;
auto& best_sentence = target[key];
if (!best_sentence || compare_(*best_sentence, *new_sentence)) {
DLOG(INFO) << "updated sentences " << end_pos << ") with "
<< new_sentence->text() << " weight: "
<< new_sentence->weight();
best_sentence = std::move(new_sentence);
}
}
}
}
}
if (sentences.find(total_length) == sentences.end())
return nullptr;
else
return sentences[total_length];
return find_best_sentence(sentences[total_length], compare_);
}

} // namespace rime

0 comments on commit b3f4005

Please sign in to comment.