Skip to content

Commit

Permalink
feat(translator): contextual suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
lotem authored Apr 18, 2019
2 parents 97220ce + 12a7501 commit 9934788
Show file tree
Hide file tree
Showing 13 changed files with 270 additions and 75 deletions.
3 changes: 3 additions & 0 deletions src/rime/commit_history.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ class CommitHistory : public list<CommitRecord> {
void Push(const KeyEvent& key_event);
void Push(const Composition& composition, const string& input);
string repr() const;
// Returns the text of the most recent commit record, or an empty string
// when the history holds no records.
string latest_text() const {
  if (empty()) {
    return string();
  }
  return back().text;
}
};

} // Namespace rime
Expand Down
13 changes: 13 additions & 0 deletions src/rime/composition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// 2011-06-19 GONG Chen <[email protected]>
//
#include <boost/algorithm/string.hpp>
#include <boost/range/adaptor/reversed.hpp>
#include <rime/candidate.h>
#include <rime/composition.h>
#include <rime/menu.h>
Expand Down Expand Up @@ -167,4 +168,16 @@ string Composition::GetDebugText() const {
return result;
}

// Walks the segments from last to first and returns the text of the selected
// candidate of the first one that ends at or before |pos|. Segments that end
// after |pos| or that have no selected candidate are skipped. Returns an
// empty string when nothing qualifies (including an empty composition).
string Composition::GetTextBefore(size_t pos) const {
  for (const auto& segment : boost::adaptors::reverse(*this)) {
    if (segment.end > pos) {
      continue;  // segment reaches past the requested position
    }
    auto selected = segment.GetSelectedCandidate();
    if (selected) {
      return selected->text();
    }
  }
  return string();
}

} // namespace rime
2 changes: 2 additions & 0 deletions src/rime/composition.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class Composition : public Segmentation {
string GetCommitText() const;
string GetScriptText() const;
string GetDebugText() const;
// Returns the text of the selected candidate of the last segment that ends
// at or before the given position. Segments without a selected candidate are
// skipped; an empty string is returned when no such segment exists.
string GetTextBefore(size_t pos) const;
};

} // namespace rime
Expand Down
60 changes: 60 additions & 0 deletions src/rime/gear/contextual_translation.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#include <algorithm>
#include <iterator>
#include <rime/gear/contextual_translation.h>
#include <rime/gear/translator_commons.h>

namespace rime {

// Upper bound on the number of candidates (cached + queued) examined per
// Replenish() call. Unsigned so that it compares against container sizes
// without a signed/unsigned mismatch.
constexpr size_t kContextualSearchLimit = 32;

bool ContextualTranslation::Replenish() {
vector<of<Phrase>> queue;
size_t end_pos = 0;
while (!translation_->exhausted() &&
cache_.size() + queue.size() < kContextualSearchLimit) {
auto cand = translation_->Peek();
DLOG(INFO) << cand->text() << " cache/queue: "
<< cache_.size() << "/" << queue.size();
if (cand->type() == "phrase" || cand->type() == "table") {
if (end_pos != cand->end()) {
end_pos = cand->end();
AppendToCache(queue);
}
queue.push_back(Evaluate(As<Phrase>(cand)));
} else {
AppendToCache(queue);
cache_.push_back(cand);
}
if (!translation_->Next()) {
break;
}
}
AppendToCache(queue);
return !cache_.empty();
}

// Re-weights |phrase| in context: a scratch Sentence positioned at the
// phrase's start is extended with the phrase's entry, given preceding_text_
// and grammar_, and the resulting sentence weight is written back to the
// phrase. Returns the same phrase object.
an<Phrase> ContextualTranslation::Evaluate(an<Phrase> phrase) {
  // True when the phrase reaches the end of the input string.
  const bool ends_at_input_end = phrase->end() == input_.length();
  auto scratch = New<Sentence>(phrase->language());
  scratch->Offset(phrase->start());
  scratch->Extend(phrase->entry(),
                  phrase->end(),
                  ends_at_input_end,
                  preceding_text_,
                  grammar_);
  phrase->set_weight(scratch->weight());
  DLOG(INFO) << "contextual suggestion: " << phrase->text()
             << " weight: " << phrase->weight();
  return phrase;
}

static bool compare_by_weight_desc(const an<Phrase>& a, const an<Phrase>& b) {
return a->weight() > b->weight();
}

// Sorts the queued phrases by descending weight, appends them to cache_ and
// empties the queue. Does nothing when the queue is empty.
void ContextualTranslation::AppendToCache(vector<of<Phrase>>& queue) {
  if (queue.empty()) return;
  DLOG(INFO) << "appending to cache " << queue.size() << " candidates.";
  // Stable sort: phrases with equal weight keep the order in which the
  // underlying translation produced them, so the result is deterministic.
  std::stable_sort(queue.begin(), queue.end(), compare_by_weight_desc);
  std::copy(queue.begin(), queue.end(), std::back_inserter(cache_));
  queue.clear();
}

} // namespace rime
38 changes: 38 additions & 0 deletions src/rime/gear/contextual_translation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//
// Copyright RIME Developers
// Distributed under the BSD License
//

#ifndef RIME_CONTEXTUAL_TRANSLATION_H_
#define RIME_CONTEXTUAL_TRANSLATION_H_

#include <utility>
#include <rime/common.h>
#include <rime/translation.h>

namespace rime {

class Candidate;
class Grammar;
class Phrase;

// A translation wrapper that re-weights phrase candidates of the wrapped
// translation using the text preceding the current input and a grammar.
class ContextualTranslation : public PrefetchTranslation {
 public:
  // translation:    the translation whose candidates are re-ranked.
  // input:          the input string; its length is used to tell whether a
  //                 phrase reaches the end of the input.
  // preceding_text: text preceding the input, passed on when weighting.
  // grammar:        stored as a raw pointer and not deleted by this class.
  //                 NOTE(review): presumably must outlive this object;
  //                 confirm against callers.
  ContextualTranslation(an<Translation> translation,
                        string input,
                        string preceding_text,
                        Grammar* grammar)
      : PrefetchTranslation(std::move(translation)),
        input_(std::move(input)),
        preceding_text_(std::move(preceding_text)),
        grammar_(grammar) {}

 protected:
  // Refills the candidate cache; returns true if candidates are available.
  bool Replenish() override;

 private:
  // Updates the weight of |phrase| in context and returns it.
  an<Phrase> Evaluate(an<Phrase> phrase);
  // Sorts |queue| by weight, moves its content to the cache and clears it.
  void AppendToCache(vector<of<Phrase>>& queue);

  string input_;
  string preceding_text_;
  Grammar* grammar_;
};

}  // namespace rime

#endif  // RIME_CONTEXTUAL_TRANSLATION_H_
36 changes: 26 additions & 10 deletions src/rime/gear/poet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
//
// 2011-10-06 GONG Chen <[email protected]>
//
#include <algorithm>
#include <functional>
#include <rime/candidate.h>
#include <rime/config.h>
#include <rime/dict/vocabulary.h>
Expand All @@ -21,14 +23,27 @@ inline static Grammar* create_grammar(Config* config) {
return nullptr;
}

// Constructs a poet for |language|. The grammar is created from |config| by
// create_grammar(); |compare| is the sentence "less" predicate later used by
// MakeSentence() to decide whether a new sentence replaces the current best
// one for the same end position.
Poet::Poet(const Language* language, Config* config, Compare compare)
    : language_(language),
      grammar_(create_grammar(config)),
      compare_(compare) {}

// Defined out of line, in this translation unit, rather than in the header.
Poet::~Poet() = default;

// Sentence "less" predicate with tie-breaking. Returns true when |one| ranks
// below |other|:
//   1. lower weight ranks lower;
//   2. at equal weight, the sentence with more components ranks lower
//      (fewer components is more favorable);
//   3. at equal weight and size, the sentence whose syllable lengths compare
//      lexicographically smaller ranks lower (left association).
bool Poet::LeftAssociateCompare(const Sentence& one, const Sentence& other) {
  const auto one_weight = one.weight();
  const auto other_weight = other.weight();
  if (one_weight < other_weight) {
    return true;
  }
  if (!(one_weight == other_weight)) {
    return false;
  }
  // Weights are even: fall back to the number of components.
  const auto one_size = one.size();
  const auto other_size = other.size();
  if (one_size > other_size) {
    return true;
  }
  if (one_size != other_size) {
    return false;
  }
  // Same number of components: compare the syllable length sequences.
  const auto& one_lengths = one.syllable_lengths();
  const auto& other_lengths = other.syllable_lengths();
  return std::lexicographical_compare(one_lengths.begin(),
                                      one_lengths.end(),
                                      other_lengths.begin(),
                                      other_lengths.end());
}

an<Sentence> Poet::MakeSentence(const WordGraph& graph,
size_t total_length) {
size_t total_length,
const string& preceding_text) {
// TODO: save more intermediate sentence candidates
map<int, an<Sentence>> sentences;
sentences[0] = New<Sentence>(language_);
Expand All @@ -43,16 +58,17 @@ an<Sentence> Poet::MakeSentence(const WordGraph& graph,
if (start_pos == 0 && end_pos == total_length)
continue; // exclude single words from the result
DLOG(INFO) << "end pos: " << end_pos;
bool is_rear = end_pos == total_length;
const DictEntryList& entries(x.second);
for (size_t i = 0; i < entries.size(); ++i) {
const auto& entry(entries[i]);
for (const auto& entry : entries) {
auto new_sentence = New<Sentence>(*sentences[start_pos]);
bool is_rear = end_pos == total_length;
new_sentence->Extend(*entry, end_pos, is_rear, grammar_.get());
new_sentence->Extend(
*entry, end_pos, is_rear, preceding_text, grammar_.get());
if (sentences.find(end_pos) == sentences.end() ||
sentences[end_pos]->weight() < new_sentence->weight()) {
DLOG(INFO) << "updated sentences " << end_pos << ") with '"
<< new_sentence->text() << "', " << new_sentence->weight();
compare_(*sentences[end_pos], *new_sentence)) {
DLOG(INFO) << "updated sentences " << end_pos << ") with "
<< new_sentence->text() << " weight: "
<< new_sentence->weight();
sentences[end_pos] = std::move(new_sentence);
}
}
Expand Down
36 changes: 33 additions & 3 deletions src/rime/gear/poet.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
#define RIME_POET_H_

#include <rime/common.h>
#include <rime/translation.h>
#include <rime/dict/user_dictionary.h>
#include <rime/gear/translator_commons.h>
#include <rime/gear/contextual_translation.h>

namespace rime {

Expand All @@ -23,14 +25,42 @@ class Language;

// Composes sentences out of a word graph, optionally weighting them with a
// grammar and the text that precedes the input.
class Poet {
 public:
  // sentence "less", used to compare sentences of the same input range.
  using Compare = function<bool (const Sentence&, const Sentence&)>;

  // Default predicate: a sentence ranks lower when its weight is smaller.
  static bool CompareWeight(const Sentence& one, const Sentence& other) {
    return one.weight() < other.weight();
  }
  // Like CompareWeight, with tie-breaking on component count and syllable
  // lengths when the weights are equal.
  static bool LeftAssociateCompare(const Sentence& one, const Sentence& other);

  // |compare| selects how competing sentences are ranked; it defaults to
  // CompareWeight so existing two-argument call sites keep working.
  Poet(const Language* language, Config* config,
       Compare compare = CompareWeight);
  ~Poet();

  // Builds a sentence from |graph| covering |total_length| of the input.
  // |preceding_text| is passed on when extending sentences so the grammar
  // can take the context into account.
  an<Sentence> MakeSentence(const WordGraph& graph,
                            size_t total_length,
                            const string& preceding_text);

  // Wraps |translation| in a ContextualTranslation when contextual
  // suggestions apply. Returns |translation| unchanged if the translator has
  // contextual suggestions disabled, if no grammar is available, or if the
  // translator reports no preceding text for |start|.
  template <class TranslatorT>
  an<Translation> ContextualWeighted(an<Translation> translation,
                                     const string& input,
                                     size_t start,
                                     TranslatorT* translator) {
    if (!translator->contextual_suggestions() || !grammar_) {
      return translation;
    }
    auto preceding_text = translator->GetPrecedingText(start);
    if (preceding_text.empty()) {
      return translation;
    }
    return New<ContextualTranslation>(
        translation, input, preceding_text, grammar_.get());
  }

 private:
  const Language* language_;
  the<Grammar> grammar_;
  Compare compare_;
};

} // namespace rime
Expand Down
21 changes: 17 additions & 4 deletions src/rime/gear/script_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,11 @@ an<Translation> ScriptTranslator::Query(const string& input,
enable_user_dict ? user_dict_.get() : NULL)) {
return nullptr;
}
return New<DistinctTranslation>(result);
auto deduped = New<DistinctTranslation>(result);
if (contextual_suggestions_) {
return poet_->ContextualWeighted(deduped, input, segment.start, this);
}
return deduped;
}

string ScriptTranslator::FormatPreedit(const string& preedit) {
Expand All @@ -214,6 +218,12 @@ string ScriptTranslator::Spell(const Code& code) {
return result;
}

// Returns the text that precedes input position |start|:
//   - an empty string when contextual suggestions are disabled;
//   - for start > 0, the text the current composition yields before |start|;
//   - for start == 0, the latest text in the commit history.
string ScriptTranslator::GetPrecedingText(size_t start) const {
  if (!contextual_suggestions_) {
    return string();
  }
  if (start > 0) {
    return engine_->context()->composition().GetTextBefore(start);
  }
  return engine_->context()->commit_history().latest_text();
}

bool ScriptTranslator::Memorize(const CommitEntry& commit_entry) {
bool update_elements = false;
// avoid updating single character entries within a phrase which is
Expand Down Expand Up @@ -538,12 +548,15 @@ an<Sentence> ScriptTranslation::MakeSentence(Dictionary* dict,
}
}
}
auto sentence = poet_->MakeSentence(graph, syllable_graph.interpreted_length);
if (sentence) {
if (auto sentence =
poet_->MakeSentence(graph,
syllable_graph.interpreted_length,
translator_->GetPrecedingText(start_))) {
sentence->Offset(start_);
sentence->set_syllabifier(syllabifier_);
return sentence;
}
return sentence;
return nullptr;
}

} // namespace rime
1 change: 1 addition & 0 deletions src/rime/gear/script_translator.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class ScriptTranslator : public Translator,

string FormatPreedit(const string& preedit);
string Spell(const Code& code);
// Returns the text preceding input position |start|, taken from the current
// composition when start > 0 and from the latest commit otherwise. Returns
// an empty string when contextual suggestions are disabled.
string GetPrecedingText(size_t start) const;

// options
int max_homophones() const { return max_homophones_; }
Expand Down
Loading

0 comments on commit 9934788

Please sign in to comment.