Skip to content

Commit

Permalink
perf(poet): optimize for performance in making sentences (~40% faster)
Browse files Browse the repository at this point in the history
  • Loading branch information
lotem committed Jul 21, 2020
1 parent 44dd002 commit 0853465
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 121 deletions.
12 changes: 7 additions & 5 deletions src/rime/gear/contextual_translation.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <algorithm>
#include <iterator>
#include <rime/gear/contextual_translation.h>
#include <rime/gear/grammar.h>
#include <rime/gear/translator_commons.h>

namespace rime {
Expand Down Expand Up @@ -37,12 +38,13 @@ bool ContextualTranslation::Replenish() {
}

an<Phrase> ContextualTranslation::Evaluate(an<Phrase> phrase) {
auto sentence = New<Sentence>(phrase->language());
sentence->Offset(phrase->start());
bool is_rear = phrase->end() == input_.length();
sentence->Extend(phrase->entry(), phrase->end(), is_rear, preceding_text_,
grammar_);
phrase->set_weight(sentence->weight());
double weight = Grammar::Evaluate(preceding_text_,
phrase->text(),
phrase->weight(),
is_rear,
grammar_);
phrase->set_weight(weight);
DLOG(INFO) << "contextual suggestion: " << phrase->text()
<< " weight: " << phrase->weight();
return phrase;
Expand Down
8 changes: 4 additions & 4 deletions src/rime/gear/grammar.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

#include <rime/common.h>
#include <rime/component.h>
#include <rime/dict/vocabulary.h>

namespace rime {

Expand All @@ -17,12 +16,13 @@ class Grammar : public Class<Grammar, Config*> {
bool is_rear) = 0;

inline static double Evaluate(const string& context,
const DictEntry& entry,
const string& entry_text,
double entry_weight,
bool is_rear,
Grammar* grammar) {
const double kPenalty = -18.420680743952367; // log(1e-8)
return entry.weight +
(grammar ? grammar->Query(context, entry.text, is_rear) : kPenalty);
return entry_weight +
(grammar ? grammar->Query(context, entry_text, is_rear) : kPenalty);
}
};

Expand Down
234 changes: 152 additions & 82 deletions src/rime/gear/poet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,64 @@

namespace rime {

// internal data structure used during the sentence making process.
// the output line of the algorithm is transformed to an<Sentence>.
struct Line {
// be sure the pointer to predecessor Line object is stable. it works since
// pointer to values stored in std::map and std::unordered_map are stable.
const Line* predecessor;
// as long as the word graph lives, pointers to entries are valid.
const DictEntry* entry;
size_t end_pos;
double weight;

static const Line kEmpty;

bool empty() const {
return !predecessor && !entry;
}

string last_word() const {
return entry ? entry->text : string();
}

struct Components {
vector<const Line*> lines;

Components(const Line* line) {
for (const Line* cursor = line;
!cursor->empty();
cursor = cursor->predecessor) {
lines.push_back(cursor);
}
}

decltype(lines.crbegin()) begin() const { return lines.crbegin(); }
decltype(lines.crend()) end() const { return lines.crend(); }
};

Components components() const { return Components(this); }

string context() const {
// look back 2 words
return empty() ? string() :
!predecessor || predecessor->empty() ? last_word() :
predecessor->last_word() + last_word();
}

vector<size_t> word_lengths() const {
vector<size_t> lengths;
size_t last_end_pos = 0;
for (const auto* c : components()) {
lengths.push_back(c->end_pos - last_end_pos);
last_end_pos = c->end_pos;
}
return lengths;
}
};

const Line Line::kEmpty{nullptr, nullptr, 0, 0.0};

inline static Grammar* create_grammar(Config* config) {
if (auto* grammar = Grammar::Require("grammar")) {
return grammar->Create(config);
Expand All @@ -30,102 +88,103 @@ Poet::Poet(const Language* language, Config* config, Compare compare)

Poet::~Poet() {}

bool Poet::LeftAssociateCompare(const Sentence& one, const Sentence& other) {
return one.weight() < other.weight() || ( // left associate if even
one.weight() == other.weight() && (
one.size() > other.size() || ( // less components is more favorable
one.size() == other.size() &&
std::lexicographical_compare(one.syllable_lengths().begin(),
one.syllable_lengths().end(),
other.syllable_lengths().begin(),
other.syllable_lengths().end()))));
bool Poet::CompareWeight(const Line& one, const Line& other) {
return one.weight < other.weight;
}

// returns true if one is less than other.
bool Poet::LeftAssociateCompare(const Line& one, const Line& other) {
if (one.weight < other.weight) return true;
if (one.weight == other.weight) {
auto one_word_lens = one.word_lengths();
auto other_word_lens = other.word_lengths();
// less words is more favorable
if (one_word_lens.size() > other_word_lens.size()) return true;
if (one_word_lens.size() == other_word_lens.size()) {
return std::lexicographical_compare(
one_word_lens.begin(), one_word_lens.end(),
other_word_lens.begin(), other_word_lens.end());
}
}
return false;
}

// keep the best sentence candidate per last phrase
using SentenceCandidates = hash_map<string, of<Sentence>>;
// keep the best line candidate per last phrase
using LineCandidates = hash_map<string, Line>;

template <int N>
static vector<of<Sentence>> find_top_candidates(
const SentenceCandidates& candidates, Poet::Compare compare) {
vector<of<Sentence>> top;
static vector<const Line*> find_top_candidates(
const LineCandidates& candidates, Poet::Compare compare) {
vector<const Line*> top;
top.reserve(N + 1);
for (const auto& candidate : candidates) {
auto pos = std::upper_bound(
top.begin(), top.end(), candidate.second,
[&](const an<Sentence>& a, const an<Sentence>& b) {
return !compare(*a, *b); // desc
});
top.begin(), top.end(), &candidate.second,
[&](const Line* a, const Line* b) { return compare(*b, *a); }); // desc
if (pos - top.begin() >= N) continue;
top.insert(pos, candidate.second);
top.insert(pos, &candidate.second);
if (top.size() > N) top.pop_back();
}
return top;
}

static an<Sentence> find_best_sentence(const SentenceCandidates& candidates,
Poet::Compare compare) {
an<Sentence> best = nullptr;
for (const auto& candidate : candidates) {
if (!best || compare(*best, *candidate.second)) {
best = candidate.second;
}
}
return best;
}

using UpdateSetenceCandidate = function<void (const an<Sentence>& candidate)>;
using UpdateLineCandidate = function<void (const Line& candidate)>;

struct BeamSearch {
using State = SentenceCandidates;
using State = LineCandidates;

static constexpr int kMaxSentenceCandidates = 7;
static constexpr int kMaxLineCandidates = 7;

static void Initiate(State& initial_state, const Language* language) {
initial_state.emplace("", New<Sentence>(language));
static void Initiate(State& initial_state) {
initial_state.emplace("", Line::kEmpty);
}

static void ForEachCandidate(const State& state,
Poet::Compare compare,
UpdateSetenceCandidate update) {
UpdateLineCandidate update) {
auto top_candidates =
find_top_candidates<kMaxSentenceCandidates>(state, compare);
for (const auto& candidate : top_candidates) {
update(candidate);
find_top_candidates<kMaxLineCandidates>(state, compare);
for (const auto* candidate : top_candidates) {
update(*candidate);
}
}

static an<Sentence>& BestSentenceToUpdate(State& state,
const an<Sentence>& new_sentence) {
const auto& key = new_sentence->components().back().text;
static Line& BestLineToUpdate(State& state, const Line& new_line) {
const auto& key = new_line.last_word();
return state[key];
}

static an<Sentence> BestSentence(const State& final_state,
Poet::Compare compare) {
return find_best_sentence(final_state, compare);
static const Line& BestLineInState(const State& final_state,
Poet::Compare compare) {
const Line* best = nullptr;
for (const auto& candidate : final_state) {
if (!best || compare(*best, candidate.second)) {
best = &candidate.second;
}
}
return best ? *best : Line::kEmpty;
}
};

struct DynamicProgramming {
using State = an<Sentence>;
using State = Line;

static void Initiate(State& initial_state, const Language* language) {
initial_state = New<Sentence>(language);
static void Initiate(State& initial_state) {
initial_state = Line::kEmpty;
}

static void ForEachCandidate(const State& state,
Poet::Compare compare,
UpdateSetenceCandidate update) {
UpdateLineCandidate update) {
update(state);
}

static an<Sentence>& BestSentenceToUpdate(State& state,
const an<Sentence>& new_sentence) {
static Line& BestLineToUpdate(State& state, const Line& new_line) {
return state;
}

static an<Sentence> BestSentence(const State& final_state,
Poet::Compare compare) {
static const Line& BestLineInState(const State& final_state,
Poet::Compare compare) {
return final_state;
}
};
Expand All @@ -134,47 +193,58 @@ template <class Strategy>
an<Sentence> Poet::MakeSentenceWithStrategy(const WordGraph& graph,
size_t total_length,
const string& preceding_text) {
map<int, typename Strategy::State> sentences;
Strategy::Initiate(sentences[0], language_);
for (const auto& w : graph) {
size_t start_pos = w.first;
if (sentences.find(start_pos) == sentences.end())
map<int, typename Strategy::State> states;
Strategy::Initiate(states[0]);
for (const auto& sv : graph) {
size_t start_pos = sv.first;
if (states.find(start_pos) == states.end())
continue;
DLOG(INFO) << "start pos: " << start_pos;
const auto& source(sentences[start_pos]);
Strategy::ForEachCandidate(
source, compare_,
[&](const an<Sentence>& candidate) {
for (const auto& x : w.second) {
size_t end_pos = x.first;
const auto& source_state = states[start_pos];
const auto update =
[this, &states, &sv, start_pos, total_length, &preceding_text]
(const Line& candidate) {
for (const auto& ev : sv.second) {
size_t end_pos = ev.first;
if (start_pos == 0 && end_pos == total_length)
continue; // exclude single words from the result
continue; // exclude single word from the result
DLOG(INFO) << "end pos: " << end_pos;
bool is_rear = end_pos == total_length;
auto& target(sentences[end_pos]);
auto& target_state = states[end_pos];
// extend candidates with dict entries on a valid edge.
const DictEntryList& entries(x.second);
const DictEntryList& entries = ev.second;
for (const auto& entry : entries) {
auto new_sentence = New<Sentence>(*candidate);
new_sentence->Extend(
*entry, end_pos, is_rear, preceding_text, grammar_.get());
auto& best_sentence =
Strategy::BestSentenceToUpdate(target, new_sentence);
if (!best_sentence || compare_(*best_sentence, *new_sentence)) {
DLOG(INFO) << "updated sentences " << end_pos << ") with "
<< new_sentence->text() << " weight: "
<< new_sentence->weight();
best_sentence = std::move(new_sentence);
const string& context =
candidate.empty() ? preceding_text : candidate.context();
double weight = candidate.weight +
Grammar::Evaluate(context,
entry->text,
entry->weight,
is_rear,
grammar_.get());
Line new_line{&candidate, entry.get(), end_pos, weight};
Line& best = Strategy::BestLineToUpdate(target_state, new_line);
if (best.empty() || compare_(best, new_line)) {
DLOG(INFO) << "updated line ending at " << end_pos
<< " with text: ..." << new_line.last_word()
<< " weight: " << new_line.weight;
best = new_line;
}
}
}
});
};
Strategy::ForEachCandidate(source_state, compare_, update);
}
auto found = sentences.find(total_length);
if (found == sentences.end())
auto found = states.find(total_length);
if (found == states.end() || found->second.empty())
return nullptr;
else
return Strategy::BestSentence(found->second, compare_);
const Line& best = Strategy::BestLineInState(found->second, compare_);
auto sentence = New<Sentence>(language_);
for (const auto* c : best.components()) {
if (!c->entry) continue;
sentence->Extend(*c->entry, c->end_pos, c->weight);
}
return sentence;
}

an<Sentence> Poet::MakeSentence(const WordGraph& graph,
Expand Down
11 changes: 5 additions & 6 deletions src/rime/gear/poet.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,15 @@ using WordGraph = map<int, UserDictEntryCollector>;

class Grammar;
class Language;
struct Line;

class Poet {
public:
// sentence "less", used to compare sentences of the same input range.
using Compare = function<bool (const Sentence&, const Sentence&)>;
// Line "less", used to compare composed line of the same input range.
using Compare = function<bool (const Line&, const Line&)>;

static bool CompareWeight(const Sentence& one, const Sentence& other) {
return one.weight() < other.weight();
}
static bool LeftAssociateCompare(const Sentence& one, const Sentence& other);
static bool CompareWeight(const Line& one, const Line& other);
static bool LeftAssociateCompare(const Line& one, const Line& other);

Poet(const Language* language, Config* config,
Compare compare = CompareWeight);
Expand Down
Loading

0 comments on commit 0853465

Please sign in to comment.