diff --git a/src/rime/algo/syllabifier.cc b/src/rime/algo/syllabifier.cc index adb140ccd..b87ec7963 100644 --- a/src/rime/algo/syllabifier.cc +++ b/src/rime/algo/syllabifier.cc @@ -7,10 +7,13 @@ // #include #include -#include #include +#include +#include +#include "syllabifier.h" namespace rime { +using namespace corrector; using Vertex = pair; using VertexQueue = std::priority_queuevertices.find(current_pos) == graph->vertices.end()) graph->vertices.insert(vertex); // preferred spelling type comes first - else + else { +// graph->vertices[current_pos] = std::min(vertex.second, graph->vertices[current_pos]); continue; // discard worse spelling types + } if (current_pos > farthest) farthest = current_pos; @@ -44,7 +49,25 @@ int Syllabifier::BuildSyllableGraph(const string &input, // see where we can go by advancing a syllable vector matches; - prism.CommonPrefixSearch(input.substr(current_pos), &matches); + set match_set; + auto current_input = input.substr(current_pos); + prism.CommonPrefixSearch(current_input, &matches); + for (auto &m : matches) { + match_set.insert(m.value); + } + if (enable_correction_) { + Corrections corrections; + corrector_->ToleranceSearch(prism, current_input, &corrections, 5); + for (const auto &m : corrections) { + for (auto accessor = prism.QuerySpelling(m.first); !accessor.exhausted(); accessor.Next()) { + if (accessor.properties().type == kNormalSpelling) { + matches.push_back({ m.first, m.second.length }); + break; + } + } + } + } + if (!matches.empty()) { auto& end_vertices(graph->edges[current_pos]); for (const auto& m : matches) { @@ -56,7 +79,7 @@ int Syllabifier::BuildSyllableGraph(const string &input, ++end_pos; DLOG(INFO) << "end_pos: " << end_pos; bool matches_input = (current_pos == 0 && end_pos == input.length()); - SpellingMap spellings; + SpellingMap& spellings(end_vertices[end_pos]); SpellingType end_vertex_type = kInvalidSpelling; // when spelling algebra is enabled, // a spelling evaluates to a set of syllables; @@ -64,7 +87,7 @@ int Syllabifier::BuildSyllableGraph(const string &input, SpellingAccessor accessor(prism.QuerySpelling(m.value)); while (!accessor.exhausted()) { SyllableId syllable_id = accessor.syllable_id(); - SpellingProperties props = accessor.properties(); + EdgeProperties props(accessor.properties()); if (strict_spelling_ && matches_input && props.type != kNormalSpelling) { @@ -74,10 +97,19 @@ int Syllabifier::BuildSyllableGraph(const string &input, props.end_pos = end_pos; // add a syllable with properties to the edge's // spelling-to-syllable map - spellings.insert({syllable_id, props}); + if (match_set.find(m.value) == match_set.end()) { + props.is_correction = true; + props.credibility = 0.01; + } + auto it = spellings.find(syllable_id); + if (it == spellings.end()) { + spellings.insert({syllable_id, props}); + } else { + it->second.type = std::min(it->second.type, props.type); + } // let end_vertex_type be the best (smaller) type of spelling // that ends at the vertex - if (end_vertex_type > props.type) { + if (end_vertex_type > props.type && !props.is_correction) { end_vertex_type = props.type; } } @@ -85,9 +117,9 @@ int Syllabifier::BuildSyllableGraph(const string &input, } if (spellings.empty()) { DLOG(INFO) << "not spelt."; + end_vertices.erase(end_pos); continue; } - end_vertices[end_pos].swap(spellings); // find the best common type in a path up to the end vertex // eg. pinyin "shurfa" has vertex type kNormalSpelling at position 3, // kAbbreviation at position 4 and kAbbreviation at position 6 @@ -121,6 +153,10 @@ int Syllabifier::BuildSyllableGraph(const string &input, // when there is a path of more favored type SpellingType edge_type = kInvalidSpelling; for (auto k = j->second.begin(); k != j->second.end(); ) { + if (k->second.is_correction) { + ++k; + continue; // Don't care correction edges + } if (k->second.type > last_type) { j->second.erase(k++); } @@ -245,4 +281,9 @@ void Syllabifier::Transpose(SyllableGraph* graph) { } } +void Syllabifier::EnableCorrection(an corrector) { + enable_correction_ = true; + corrector_ = std::move(corrector); +} + } // namespace rime diff --git a/src/rime/algo/syllabifier.h b/src/rime/algo/syllabifier.h index 505da9a7f..e9cab4630 100644 --- a/src/rime/algo/syllabifier.h +++ b/src/rime/algo/syllabifier.h @@ -15,15 +15,22 @@ namespace rime { class Prism; +class Corrector; using SyllableId = int32_t; -using SpellingMap = map; +struct EdgeProperties : SpellingProperties { + EdgeProperties(SpellingProperties sup): SpellingProperties(sup) {}; + EdgeProperties() = default; + bool is_correction = false; +}; + +using SpellingMap = map; using VertexMap = map; using EndVertexMap = map; using EdgeMap = map; -using SpellingPropertiesList = vector; +using SpellingPropertiesList = vector; using SpellingIndex = map; using SpellingIndices = map; @@ -49,6 +56,7 @@ class Syllabifier { RIME_API int BuildSyllableGraph(const string &input, Prism &prism, SyllableGraph *graph); + RIME_API void EnableCorrection(an corrector); protected: void CheckOverlappedSpellings(SyllableGraph *graph, @@ -58,6 +66,8 @@ class Syllabifier { string delimiters_; bool enable_completion_ = false; bool strict_spelling_ = false; + an corrector_ = nullptr; + bool enable_correction_ = false; }; } // namespace rime diff --git a/src/rime/common.h b/src/rime/common.h index 0a5f47f6e..e2b91c1b6 100644 --- a/src/rime/common.h +++ b/src/rime/common.h @@ -18,6 +18,7 @@ #include #include #include +#include #define BOOST_BIND_NO_PLACEHOLDERS #ifdef BOOST_SIGNALS2 #include @@ -47,6 +48,7 @@ using std::pair; using std::set; using std::string; using std::vector; +using boost::optional; template using hash_map = std::unordered_map; diff --git a/src/rime/dict/corrector.cc b/src/rime/dict/corrector.cc new file mode 100644 index 000000000..421a7ce50 --- /dev/null +++ b/src/rime/dict/corrector.cc @@ -0,0 +1,313 @@ +// +// Copyright RIME Developers +// Distributed under the BSD License +// +// Created by nameoverflow on 2018/11/14. +// + +#include "corrector.h" +#include +#include +#include +#include +#include +#include + +using namespace rime; +using namespace corrector; + +static hash_map> keyboard_map = { + {'1', {'2', 'q', 'w'}}, + {'2', {'1', '3', 'q', 'w', 'e'}}, + {'3', {'2', '4', 'w', 'e', 'r'}}, + {'4', {'3', '5', 'e', 'r', 't'}}, + {'5', {'4', '6', 'r', 't', 'y'}}, + {'6', {'5', '7', 't', 'y', 'u'}}, + {'7', {'6', '8', 'y', 'u', 'i'}}, + {'8', {'7', '9', 'u', 'i', 'o'}}, + {'9', {'8', '0', 'i', 'o', 'p'}}, + {'0', {'9', '-', 'o', 'p', '['}}, + {'-', {'0', '=', 'p', '[', ']'}}, + {'=', {'-', '[', ']', '\\'}}, + {'q', {'w'}}, + {'w', {'q', 'e'}}, + {'e', {'w', 'r'}}, + {'r', {'e', 't'}}, + {'t', {'r', 'y'}}, + {'y', {'t', 'u'}}, + {'u', {'y', 'i'}}, + {'i', {'u', 'o'}}, + {'o', {'i', 'p'}}, + {'p', {'o', '['}}, + {'[', {'p', ']'}}, + {']', {'[', '\\'}}, + {'\\', {']'}}, + {'a', {'s'}}, + {'s', {'a', 'd'}}, + {'d', {'s', 'f'}}, + {'f', {'d', 'g'}}, + {'g', {'f', 'h'}}, + {'h', {'g', 'j'}}, + {'j', {'h', 'k'}}, + {'k', {'j', 'l'}}, + {'l', {'k', ';'}}, + {';', {'l', '\''}}, + {'\'', {';'}}, + {'z', {'x'}}, + {'x', {'z', 'c'}}, + {'c', {'x', 'v'}}, + {'v', {'c', 'b'}}, + {'b', {'v', 'n'}}, + {'n', {'b', 'm'}}, + {'m', {'n', ','}}, + {',', {'m', '.'}}, + {'.', {',', '/'}}, + {'/', {'.'}}, +}; + +void DFSCollect(const string &origin, const string ¤t, size_t ed, Script &result); + +Script SymDeleteCollector::Collect(size_t edit_distance) { + // TODO: specifically for 1 length str + Script script; + + for (auto &v : syllabary_) { + DFSCollect(v, v, edit_distance, script); + } + + return script; +} + +void DFSCollect(const string &origin, const string ¤t, size_t ed, Script &result) { + if (ed <= 0) return; + for (size_t i = 0; i < current.size(); i++) { + string temp = current; + temp.erase(i, 1); + Spelling spelling(origin); + spelling.properties.tips = origin; + result[temp].push_back(spelling); + DFSCollect(origin, temp, ed - 1, result); + } +} + +void EditDistanceCorrector::ToleranceSearch(const Prism &prism, + const string &key, + Corrections *results, + size_t threshold) { + if (key.empty()) + return; + size_t key_len = key.length(); + + vector jump_pos(key_len); + + auto match_next = [&](size_t &node, size_t &point) -> bool { + auto res_val = trie_->traverse(key.c_str(), node, point, point + 1); + if (res_val == -2) return false; + if (res_val >= 0) { + for (auto accessor = QuerySpelling(res_val); !accessor.exhausted(); accessor.Next()) { + auto origin = accessor.properties().tips; + auto current_input = key.substr(0, point); + if (origin == current_input) { + continue; // early termination: this comparision is O(n) + } + auto distance = RestrictedDistance(origin, current_input, threshold); + if (distance <= threshold) { // only trace near words + SyllableId corrected; + if (prism.GetValue(origin, &corrected)) { + results->Alter(corrected, { distance, corrected, point }); + } + } + + } + } + return true; + }; + + // pass through origin key, cache trie nodes + size_t max_match = 0; + for (size_t next_node = 0; max_match < key_len;) { + jump_pos[max_match] = next_node; + if (!match_next(next_node, max_match)) break; + } + + // start at the next position of deleted char + for (size_t del_pos = 0; del_pos <= max_match; del_pos++) { + size_t next_node = jump_pos[del_pos]; + for (size_t key_point = del_pos + 1; key_point < key_len;) { + if (!match_next(next_node, key_point)) break; + } + } +} + + +inline uint8_t SubstCost(char left, char right) { + if (left == right) return 0; + if (keyboard_map[left].find(right) != keyboard_map[left].end()) { + return 1; + } + return 4; +} + +// This nice O(min(m, n)) implementation is from +// https://en.wikibooks.org/wiki/Algorithm_Implementation/Strings/Levenshtein_distance#C++ +Distance EditDistanceCorrector::LevenshteinDistance(const std::string &s1, const std::string &s2) { + // To change the type this function manipulates and returns, change + // the return type and the types of the two variables below. + auto s1len = (size_t)s1.size(); + auto s2len = (size_t)s2.size(); + + auto column_start = (decltype(s1len))1; + + auto column = new decltype(s1len)[s1len + 1]; + std::iota(column + column_start - 1, column + s1len + 1, column_start - 1); + + for (auto x = column_start; x <= s2len; x++) { + column[0] = x; + auto last_diagonal = x - column_start; + for (auto y = column_start; y <= s1len; y++) { + auto old_diagonal = column[y]; + auto possibilities = { + column[y] + 1, + column[y - 1] + 1, + last_diagonal + SubstCost(s1[y - 1], s2[x - 1]) + }; + + column[y] = std::min(possibilities); + last_diagonal = old_diagonal; + } + } + auto result = column[s1len]; + delete[] column; + return result; +} + +// L's distance with transposition allowed +Distance EditDistanceCorrector::RestrictedDistance(const std::string& s1, + const std::string& s2, + Distance threshold) { + auto len1 = s1.size(), len2 = s2.size(); + vector d((len1 + 1) * (len2 + 1)); + + auto index = [len1, len2](size_t i, size_t j) { + return i * (len2 + 1) + j; + }; + + d[0] = 0; + for(size_t i = 1; i <= len1; ++i) d[index(i, 0)] = i * 2; + for(size_t i = 1; i <= len2; ++i) d[index(0, i)] = i * 2; + + for(size_t i = 1; i <= len1; ++i) { + auto min_d = threshold + 1; + for(size_t j = 1; j <= len2; ++j) { + d[index(i, j)] = std::min({ + d[index(i - 1, j)] + 2, + d[index(i, j - 1)] + 2, + d[index(i - 1, j - 1)] + SubstCost(s1[i - 1], s2[j - 1]) + }); + if (i > 1 && j > 1 && s1[i - 2] == s2[j - 1] && s1[i - 1] == s2[j - 2]) { + d[index(i, j)] = std::min(d[index(i, j)], d[index(i - 2, j - 2)] + 2); + } + min_d = std::min(min_d, d[index(i, j)]); + } + // early termination: do not continue if too far + if (min_d > threshold) + return min_d; + } + return (uint8_t)d[index(len1, len2)]; +} +bool EditDistanceCorrector::Build(const Syllabary &syllabary, + const Script *script, + uint32_t dict_file_checksum, + uint32_t schema_file_checksum) { + Syllabary correct_syllabary; + if (script && !script->empty()) { + for (auto &v : *script) { + correct_syllabary.insert(v.first); + } + } else { + correct_syllabary = syllabary; + } + + SymDeleteCollector collector(correct_syllabary); + auto correction_script = collector.Collect((size_t)1); + + return Prism::Build(syllabary, &correction_script, dict_file_checksum, schema_file_checksum); +} +EditDistanceCorrector::EditDistanceCorrector(const string &file_name) : Prism(file_name) {} + +void +NearSearchCorrector::ToleranceSearch(const Prism &prism, + const string &key, + Corrections *results, + size_t threshold) { + if (key.empty()) + return ; + + using record = struct { + size_t node_pos; + size_t idx; + size_t distance; + char ch; + }; + + std::queue queue; + queue.push({ 0, 0, 0, key[0] }); + for (auto subst : keyboard_map[key[0]]) { + queue.push({ 0, 0, 1, subst }); + } + for (; !queue.empty(); queue.pop()) { + auto &rec = queue.front(); + char ch = rec.ch; + char &exchange(const_cast(key.c_str())[rec.idx]); + std::swap(ch, exchange); + auto val = prism.trie().traverse(key.c_str(), rec.node_pos, rec.idx, rec.idx + 1); + std::swap(ch, exchange); + + if (val == -2) continue; + if (val >= 0) { + results->Alter(val, { rec.distance, val, rec.idx }); + } + if (rec.idx < key.size()) { + queue.push({ rec.node_pos, rec.idx, rec.distance, key[rec.idx] }); + if (rec.distance < threshold) { + for (auto subst : keyboard_map[key[rec.idx]]) { + queue.push({ rec.node_pos, rec.idx, rec.distance + 1, subst }); + } + } + } + } +} +void CorrectorComponent::Unified::ToleranceSearch(const Prism &prism, + const string &key, + Corrections *results, + size_t tolerance) { + for (auto &c : contents) { + c->ToleranceSearch(prism, key, results, tolerance); + } +} +CorrectorComponent::CorrectorComponent() + : resolver_(Service::instance().CreateResourceResolver({ "corrector", "build/", ".correction.bin" })) { +} + +Corrector *CorrectorComponent::Create(const Ticket &ticket) noexcept { + // Don't use edit distance based correction for now. +#if 0 + if (!ticket.schema) return nullptr; + Config* config = ticket.schema->config(); + string prism_name; + if (!config->GetString(ticket.name_space + "/prism", &prism_name)) { + config->GetString(ticket.name_space + "/dictionary", &prism_name); + } + + auto file_name = resolver_->ResolvePath(prism_name).string(); + + auto ed_corrector = New(file_name); + if (edCorrector->Load()) { + return Combine(New(), ed_corrector); + } else { + return new NearSearchCorrector(); + } +#endif + return new NearSearchCorrector(); + +} diff --git a/src/rime/dict/corrector.h b/src/rime/dict/corrector.h new file mode 100644 index 000000000..ddaa3ef07 --- /dev/null +++ b/src/rime/dict/corrector.h @@ -0,0 +1,131 @@ +// +// Copyright RIME Developers +// Distributed under the BSD License +// +// Created by nameoverflow on 2018/11/14. +// + +#ifndef RIME_CORRECTOR_H +#define RIME_CORRECTOR_H + +#include +#include +#include +#include +#include +#include + +namespace rime { +struct Ticket; + +class SymDeleteCollector { + public: + explicit SymDeleteCollector(const Syllabary& syllabary): syllabary_(syllabary) {} + + Script Collect(size_t edit_distance); + + private: + const Syllabary& syllabary_; +}; + +namespace corrector { +using Distance = size_t; +struct Correction { + size_t distance; + SyllableId syllable; + size_t length; +}; +class Corrections : public hash_map { + public: + /// Update for better correction + /// \param syllable + /// \param correction + inline void Alter(SyllableId syllable, Correction correction) { + if (find(syllable) == end() || correction.distance < (*this)[syllable].distance) { + (*this)[syllable] = correction; + } + }; +}; +} // namespace corrector + +/** + * The unify interface of correctors + */ +class Corrector : public Class { + public: + virtual ~Corrector() = default; + RIME_API virtual void ToleranceSearch(const Prism &prism, + const string &key, + corrector::Corrections *results, + size_t tolerance) = 0; +}; + +class CorrectorComponent : public Corrector::Component { + public: + CorrectorComponent(); + ~CorrectorComponent() override = default; + Corrector *Create(const Ticket& ticket) noexcept override; + private: + template + static Corrector *Combine(Cs ...args); + + map> correctors_; + the resolver_; + + class Unified : public Corrector { + public: + Unified() = default; + RIME_API void ToleranceSearch(const Prism &prism, + const string &key, + corrector::Corrections *results, + size_t tolerance) override; + template + void Add(Cs ...args) { + contents = { args... }; + } + + private: + vector> contents = {}; + }; +}; + + +class EditDistanceCorrector : public Corrector, + public Prism { + public: + ~EditDistanceCorrector() override = default; + RIME_API explicit EditDistanceCorrector(const string& file_name); + + RIME_API bool Build(const Syllabary& syllabary, + const Script* script = nullptr, + uint32_t dict_file_checksum = 0, + uint32_t schema_file_checksum = 0); + + RIME_API void ToleranceSearch(const Prism &prism, + const string &key, + corrector::Corrections *results, + size_t tolerance) override; + corrector::Distance LevenshteinDistance(const std::string &s1, const std::string &s2); + corrector::Distance RestrictedDistance(const std::string& s1, const std::string& s2, corrector::Distance threshold); +}; + +class NearSearchCorrector : public Corrector { + public: + NearSearchCorrector() = default; + ~NearSearchCorrector() override = default; + RIME_API void ToleranceSearch(const Prism &prism, + const string &key, + corrector::Corrections *results, + size_t tolerance) override; +}; + +template +Corrector *CorrectorComponent::Combine(Cs ...args) { + auto u = new Unified(); + u->Add(args...); + return u; +} + +} // namespace rime + +#endif //RIME_CORRECTOR_H diff --git a/src/rime/dict/dict_compiler.cc b/src/rime/dict/dict_compiler.cc index b622a2d6e..10cd3fc3a 100644 --- a/src/rime/dict/dict_compiler.cc +++ b/src/rime/dict/dict_compiler.cc @@ -4,20 +4,21 @@ // // 2011-11-27 GONG Chen // -#include #include -#include -#include +#include #include #include -#include +#include #include #include +#include #include #include #include -#include #include +#include +#include +#include namespace rime { @@ -212,7 +213,7 @@ bool DictCompiler::BuildPrism(const string &schema_file, Syllabary syllabary; if (!table_->Load() || !table_->GetSyllabary(&syllabary) || syllabary.empty()) return false; - // apply spelling algebra + // apply spelling algebra and prepare corrections (if enabled) Script script; if (!schema_file.empty()) { Config config; @@ -230,6 +231,26 @@ bool DictCompiler::BuildPrism(const string &schema_file, script.clear(); } } + +#if 0 + // build corrector + bool enable_correction = false; // Avoid if initializer to comfort compilers + if (config.GetBool("translator/enable_correction", &enable_correction) && + enable_correction) { + boost::filesystem::path corrector_path(prism_->file_name()); + corrector_path.replace_extension(""); + corrector_path.replace_extension(".correction.bin"); + correction_ = New(RelocateToUserDirectory(prefix_, corrector_path.string())); + if (correction_->Exists()) { + correction_->Remove(); + } + if (!correction_->Build(syllabary, &script, + dict_file_checksum, schema_file_checksum) || + !correction_->Save()) { + return false; + } + } +#endif } if ((options_ & kDump) && !script.empty()) { boost::filesystem::path path(prism_->file_name()); @@ -239,12 +260,13 @@ bool DictCompiler::BuildPrism(const string &schema_file, // build .prism.bin { prism_->Remove(); - if (!prism_->Build(syllabary, script.empty() ? NULL : &script, + if (!prism_->Build(syllabary, script.empty() ? nullptr : &script, dict_file_checksum, schema_file_checksum) || !prism_->Save()) { return false; } } + return true; } diff --git a/src/rime/dict/dict_compiler.h b/src/rime/dict/dict_compiler.h index 0e274970a..b0bbff63d 100644 --- a/src/rime/dict/dict_compiler.h +++ b/src/rime/dict/dict_compiler.h @@ -17,6 +17,7 @@ class Prism; class Table; class ReverseDb; class DictSettings; +class EditDistanceCorrector; class DictCompiler { public: @@ -43,6 +44,7 @@ class DictCompiler { string dict_name_; an prism_; + an correction_; an table_; int options_ = 0; string prefix_; diff --git a/src/rime/dict/dict_module.cc b/src/rime/dict/dict_module.cc index 752c2401e..30492ec3e 100644 --- a/src/rime/dict/dict_module.cc +++ b/src/rime/dict/dict_module.cc @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -32,6 +33,8 @@ static void rime_dict_initialize() { // upgrade userdbs from an old file format (eg. TreeDb) during maintenance. //r.Register("legacy_userdb", ...); + r.Register("corrector", new CorrectorComponent); + r.Register("dictionary", new DictionaryComponent); r.Register("reverse_lookup_dictionary", new ReverseLookupDictionaryComponent); diff --git a/src/rime/dict/dictionary.cc b/src/rime/dict/dictionary.cc index 996d83a54..b0fa9682a 100644 --- a/src/rime/dict/dictionary.cc +++ b/src/rime/dict/dictionary.cc @@ -4,15 +4,15 @@ // // 2011-07-05 GONG Chen // -#include #include +#include #include +#include #include #include #include #include -#include -#include +#include namespace rime { @@ -147,9 +147,9 @@ bool DictEntryIterator::Skip(size_t num_entries) { // Dictionary members Dictionary::Dictionary(const string& name, - const an
& table, - const an& prism) - : name_(name), table_(table), prism_(prism) { + an
table, + an prism) + : name_(name), table_(std::move(table)), prism_(std::move(prism)) { } Dictionary::~Dictionary() { @@ -292,23 +292,22 @@ DictionaryComponent::DictionaryComponent() : prism_resource_resolver_( Service::instance().CreateResourceResolver(kPrismResourceType)), table_resource_resolver_( - Service::instance().CreateResourceResolver(kTableResourceType)) { -} + Service::instance().CreateResourceResolver(kTableResourceType)) {} DictionaryComponent::~DictionaryComponent() { } Dictionary* DictionaryComponent::Create(const Ticket& ticket) { - if (!ticket.schema) return NULL; + if (!ticket.schema) return nullptr; Config* config = ticket.schema->config(); string dict_name; if (!config->GetString(ticket.name_space + "/dictionary", &dict_name)) { LOG(ERROR) << ticket.name_space << "/dictionary not specified in schema '" << ticket.schema->schema_id() << "'."; - return NULL; + return nullptr; } if (dict_name.empty()) { - return NULL; // not requiring static dictionary + return nullptr; // not requiring static dictionary } string prism_name; if (!config->GetString(ticket.name_space + "/prism", &prism_name)) { diff --git a/src/rime/dict/dictionary.h b/src/rime/dict/dictionary.h index 2851cf7b4..f1aaaedb5 100644 --- a/src/rime/dict/dictionary.h +++ b/src/rime/dict/dictionary.h @@ -70,14 +70,15 @@ struct DictEntryCollector : map { class Config; class Schema; +class EditDistanceCorrector; struct SyllableGraph; struct Ticket; class Dictionary : public Class { public: RIME_API Dictionary(const string& name, - const an
& table, - const an& prism); + an
table, + an prism); virtual ~Dictionary(); bool Exists() const; @@ -113,8 +114,8 @@ class ResourceResolver; class DictionaryComponent : public Dictionary::Component { public: DictionaryComponent(); - ~DictionaryComponent(); - Dictionary* Create(const Ticket& ticket); + ~DictionaryComponent() override; + Dictionary* Create(const Ticket& ticket) override; Dictionary* CreateDictionaryWithName(const string& dict_name, const string& prism_name); diff --git a/src/rime/dict/prism.cc b/src/rime/dict/prism.cc index dbd4b10f3..ae84e25f4 100644 --- a/src/rime/dict/prism.cc +++ b/src/rime/dict/prism.cc @@ -122,7 +122,6 @@ bool Prism::Save() { } return ShrinkToFit(); } - bool Prism::Build(const Syllabary& syllabary, const Script* script, uint32_t dict_file_checksum, @@ -240,7 +239,7 @@ bool Prism::HasKey(const string& key) { return value != -1; } -bool Prism::GetValue(const string& key, int* value) { +bool Prism::GetValue(const string& key, int* value) const { int result = trie_->exactMatchSearch(key.c_str()); if (result == -1) { return false; diff --git a/src/rime/dict/prism.h b/src/rime/dict/prism.h index a70d5b6c2..c0a120224 100644 --- a/src/rime/dict/prism.h +++ b/src/rime/dict/prism.h @@ -72,12 +72,12 @@ class Prism : public MappedFile { RIME_API bool Load(); RIME_API bool Save(); RIME_API bool Build(const Syllabary& syllabary, - const Script* script = NULL, + const Script* script = nullptr, uint32_t dict_file_checksum = 0, uint32_t schema_file_checksum = 0); RIME_API bool HasKey(const string& key); - RIME_API bool GetValue(const string& key, int* value); + RIME_API bool GetValue(const string& key, int* value) const; RIME_API void CommonPrefixSearch(const string& key, vector* result); RIME_API void ExpandSearch(const string& key, vector* result, size_t limit); SpellingAccessor QuerySpelling(SyllableId spelling_id); @@ -86,8 +86,9 @@ class Prism : public MappedFile { uint32_t dict_file_checksum() const; uint32_t schema_file_checksum() const; + Darts::DoubleArray& trie() const { return *trie_; } - private: + protected: the trie_; prism::Metadata* metadata_ = nullptr; prism::SpellingMap* spelling_map_ = nullptr; diff --git a/src/rime/gear/gears_module.cc b/src/rime/gear/gears_module.cc index ad175315d..bf0fb38a6 100644 --- a/src/rime/gear/gears_module.cc +++ b/src/rime/gear/gears_module.cc @@ -5,38 +5,37 @@ // 2013-10-17 GONG Chen // -#include #include -#include - #include #include #include #include #include #include +#include #include #include #include +#include +#include #include #include -#include #include -#include #include +#include +#include +#include +#include #include +#include #include #include -#include #include -#include -#include -#include #include #include #include -#include -#include +#include +#include static void rime_gears_initialize() { using namespace rime; diff --git a/src/rime/gear/script_translator.cc b/src/rime/gear/script_translator.cc index 8a168c478..7babde202 100644 --- a/src/rime/gear/script_translator.cc +++ b/src/rime/gear/script_translator.cc @@ -17,8 +17,9 @@ #include #include #include -#include #include +#include +#include #include #include #include @@ -79,6 +80,7 @@ class ScriptSyllabifier : public PhraseSyllabifier { size_t BuildSyllableGraph(Prism& prism); string GetPreeditString(const Phrase& cand) const; string GetOriginalSpelling(const Phrase& cand) const; + bool IsCandidateCorrection(const Phrase& cand) const; const SyllableGraph& syllable_graph() const { return syllable_graph_; } @@ -92,9 +94,11 @@ class ScriptSyllabifier : public PhraseSyllabifier { class ScriptTranslation : public Translation { public: ScriptTranslation(ScriptTranslator* translator, - const string& input, size_t start) + const string& input, size_t start, + bool enable_correction = false) : translator_(translator), start_(start), - syllabifier_(New(translator, input, start)) { + syllabifier_(New(translator, input, start)), + enable_correction_(enable_correction) { set_exhausted(true); } bool Evaluate(Dictionary* dict, UserDictionary* user_dict); @@ -106,6 +110,7 @@ class ScriptTranslation : public Translation { bool IsNormalSpelling() const; an MakeSentence(Dictionary* dict, UserDictionary* user_dict); + void PrepareCandidate(); ScriptTranslator* translator_; size_t start_; @@ -115,9 +120,16 @@ class ScriptTranslation : public Translation { an user_phrase_; an sentence_; + an candidate_ = nullptr; + DictEntryCollector::reverse_iterator phrase_iter_; UserDictEntryCollector::reverse_iterator user_phrase_iter_; size_t user_phrase_index_ = 0; + + size_t max_corrections_ = 4; + size_t correction_count_ = 0; + + bool enable_correction_; }; // ScriptTranslator implementation @@ -132,6 +144,11 @@ ScriptTranslator::ScriptTranslator(const Ticket& ticket) config->GetInt(name_space_ + "/spelling_hints", &spelling_hints_); config->GetBool(name_space_ + "/always_show_comments", &always_show_comments_); + config->GetBool(name_space_ + "/enable_correction", &enable_correction_); + } + if (enable_correction_) { + auto corrector = Corrector::Require("corrector"); + corrector_.reset(corrector->Create(ticket)); } } @@ -150,7 +167,7 @@ an ScriptTranslator::Query(const string& input, !IsUserDictDisabledFor(input); // the translator should survive translations it creates - auto result = New(this, input, segment.start); + auto result = New(this, input, segment.start, enable_correction_); if (!result || !result->Evaluate(dict_.get(), enable_user_dict ? user_dict_.get() : NULL)) { @@ -225,12 +242,55 @@ size_t ScriptSyllabifier::BuildSyllableGraph(Prism& prism) { Syllabifier syllabifier(translator_->delimiters(), translator_->enable_completion(), translator_->strict_spelling()); - size_t consumed = syllabifier.BuildSyllableGraph(input_, + if (translator_->enable_correction()) { + syllabifier.EnableCorrection(translator_->corrector()); + } + auto consumed = (size_t)syllabifier.BuildSyllableGraph(input_, prism, &syllable_graph_); + return consumed; } +bool ScriptSyllabifier::IsCandidateCorrection(const rime::Phrase &cand) const { + std::stack results; + bool result = false; + // Perform DFS on syllable graph to find whether this candidate is a correction + SyllabifyTask task { + cand.code(), + syllable_graph_, + cand.end() - start_, + [&](SyllabifyTask* task, size_t depth, + size_t current_pos, size_t next_pos) { + auto id = cand.code()[depth]; + auto it_s = syllable_graph_.edges.find(current_pos); + // C++ prohibit operator [] of const map + // if (syllable_graph_.edges[current_pos][next_pos][id].type == kCorrection) + if (it_s != syllable_graph_.edges.end()) { + auto it_e = it_s->second.find(next_pos); + if (it_e != it_s->second.end()) { + auto it_type = it_e->second.find(id); + if (it_type != it_e->second.end()) { + results.push(it_type->second.is_correction); + return; + } + } + } + results.push(false); + }, + [&](SyllabifyTask* task, size_t depth) { + results.pop(); + } + }; + if (syllabify_dfs(&task, 0, cand.start() - start_)) { + for (; !results.empty(); results.pop()) { + if (results.top()) + return results.top(); + } + } + return false; +} + string ScriptSyllabifier::GetPreeditString(const Phrase& cand) const { const auto& delimiters = translator_->delimiters(); std::stack lengths; @@ -301,33 +361,50 @@ bool ScriptTranslation::Evaluate(Dictionary* dict, UserDictionary* user_dict) { } bool ScriptTranslation::Next() { - if (exhausted()) - return false; - if (sentence_) { - sentence_.reset(); - return !CheckEmpty(); - } - int user_phrase_code_length = 0; - if (user_phrase_ && user_phrase_iter_ != user_phrase_->rend()) { - user_phrase_code_length = user_phrase_iter_->first; - } - int phrase_code_length = 0; - if (phrase_ && phrase_iter_ != phrase_->rend()) { - phrase_code_length = phrase_iter_->first; - } - if (user_phrase_code_length > 0 && - user_phrase_code_length >= phrase_code_length) { - DictEntryList& entries(user_phrase_iter_->second); - if (++user_phrase_index_ >= entries.size()) { - ++user_phrase_iter_; - user_phrase_index_ = 0; + bool is_correction; + do { + is_correction = false; + if (exhausted()) + return false; + if (sentence_) { + sentence_.reset(); + return !CheckEmpty(); } - } - else if (phrase_code_length > 0) { - DictEntryIterator& iter(phrase_iter_->second); - if (!iter.Next()) { - ++phrase_iter_; + int user_phrase_code_length = 0; + if (user_phrase_ && user_phrase_iter_ != user_phrase_->rend()) { + user_phrase_code_length = user_phrase_iter_->first; + } + int phrase_code_length = 0; + if (phrase_ && phrase_iter_ != phrase_->rend()) { + phrase_code_length = phrase_iter_->first; + } + if (user_phrase_code_length > 0 && + user_phrase_code_length >= phrase_code_length) { + DictEntryList& entries(user_phrase_iter_->second); + if (++user_phrase_index_ >= entries.size()) { + ++user_phrase_iter_; + user_phrase_index_ = 0; + } + } + else if (phrase_code_length > 0) { + DictEntryIterator& iter(phrase_iter_->second); + if (!iter.Next()) { + ++phrase_iter_; + } } + if (enable_correction_) { + PrepareCandidate(); + if (!candidate_) { + break; + } + is_correction = syllabifier_->IsCandidateCorrection(*candidate_); + } + } while ( // limit the number of correction candidates + enable_correction_ && + is_correction && + correction_count_ > max_corrections_); + if (is_correction) { + ++correction_count_; } return !CheckEmpty(); } @@ -339,8 +416,28 @@ bool ScriptTranslation::IsNormalSpelling() const { } an ScriptTranslation::Peek() { - if (exhausted()) + PrepareCandidate(); + if (!candidate_) { return nullptr; + } + if (candidate_->preedit().empty()) { + candidate_->set_preedit(syllabifier_->GetPreeditString(*candidate_)); + } + if (candidate_->comment().empty()) { + auto spelling = syllabifier_->GetOriginalSpelling(*candidate_); + if (!spelling.empty() && spelling != candidate_->preedit()) { + candidate_->set_comment(/*quote_left + */spelling/* + quote_right*/); + } + } + candidate_->set_syllabifier(syllabifier_); + return candidate_; +} + +void ScriptTranslation::PrepareCandidate() { + if (exhausted()) { + candidate_ = nullptr; + return; + } if (sentence_) { if (sentence_->preedit().empty()) { sentence_->set_preedit(syllabifier_->GetPreeditString(*sentence_)); @@ -349,11 +446,12 @@ an ScriptTranslation::Peek() { auto spelling = syllabifier_->GetOriginalSpelling(*sentence_); if (!spelling.empty() && (translator_->always_show_comments() || - spelling != sentence_->preedit())) { + spelling != sentence_->preedit())) { sentence_->set_comment(/*quote_left + */spelling/* + quote_right*/); } } - return sentence_; + candidate_ = sentence_; + return; } size_t user_phrase_code_length = 0; if (user_phrase_ && user_phrase_iter_ != user_phrase_->rend()) { @@ -376,8 +474,8 @@ an ScriptTranslation::Peek() { start_ + user_phrase_code_length, entry); cand->set_quality(entry->weight + - translator_->initial_quality() + - (IsNormalSpelling() ? 0.5 : -0.5)); + translator_->initial_quality() + + (IsNormalSpelling() ? 0.5 : -0.5)); } else if (phrase_code_length > 0) { DictEntryIterator& iter(phrase_iter_->second); @@ -390,20 +488,10 @@ an ScriptTranslation::Peek() { start_ + phrase_code_length, entry); cand->set_quality(entry->weight + - translator_->initial_quality() + - (IsNormalSpelling() ? 0 : -1)); - } - if (cand->preedit().empty()) { - cand->set_preedit(syllabifier_->GetPreeditString(*cand)); - } - if (cand->comment().empty()) { - auto spelling = syllabifier_->GetOriginalSpelling(*cand); - if (!spelling.empty() && spelling != cand->preedit()) { - cand->set_comment(/*quote_left + */spelling/* + quote_right*/); - } + translator_->initial_quality() + + (IsNormalSpelling() ? 0 : -1)); } - cand->set_syllabifier(syllabifier_); - return cand; + candidate_ = cand; } bool ScriptTranslation::CheckEmpty() { diff --git a/src/rime/gear/script_translator.h b/src/rime/gear/script_translator.h index 2eab5cf55..ed0bcb0a2 100644 --- a/src/rime/gear/script_translator.h +++ b/src/rime/gear/script_translator.h @@ -21,6 +21,7 @@ struct DictEntry; struct DictEntryCollector; class Dictionary; class UserDictionary; +class EditDistanceCorrector; struct SyllableGraph; class ScriptTranslator : public Translator, @@ -39,10 +40,14 @@ class ScriptTranslator : public Translator, // options int spelling_hints() const { return spelling_hints_; } bool always_show_comments() const { return always_show_comments_; } + bool enable_correction() const { return enable_correction_; } + an corrector() const { return corrector_; } protected: int spelling_hints_ = 0; bool always_show_comments_ = false; + bool enable_correction_ = false; + an corrector_ = nullptr; }; } // namespace rime diff --git a/test/corrector_test.cc b/test/corrector_test.cc new file mode 100644 index 000000000..06b980fd4 --- /dev/null +++ b/test/corrector_test.cc @@ -0,0 +1,151 @@ +// +// Copyright RIME Developers +// Distributed under the BSD License +// +// Created by nameoverflow on 2018/11/21. +// +#include +#include +#include +#include +#include +#include +#include + +class RimeCorrectorSearchTest : public ::testing::Test { + public: + void SetUp() override { + rime::vector syllables; + syllables.emplace_back("chang"); // 0 + syllables.emplace_back("tuan"); // 1 + std::sort(syllables.begin(), syllables.end()); + for (size_t i = 0; i < syllables.size(); ++i) { + syllable_id_[syllables[i]] = i; + } + + prism_.reset(new rime::Prism("corrector_simple_test.prism.bin")); + rime::set keyset; + std::copy(syllables.begin(), syllables.end(), + std::inserter(keyset, keyset.begin())); + prism_->Build(keyset); + + } + void TearDown() override {} + protected: + rime::map syllable_id_; + rime::the prism_; +}; + +class RimeCorrectorTest : public ::testing::Test { + public: + void SetUp() override { + rime::vector syllables; + syllables.emplace_back("j"); // 0 == id + syllables.emplace_back("ji"); // 1 + syllables.emplace_back("jie"); // 2 + syllables.emplace_back("ju"); // 3 + syllables.emplace_back("jue"); // 4 + syllables.emplace_back("shen"); // 5 + std::sort(syllables.begin(), syllables.end()); + for (size_t i = 0; i < syllables.size(); ++i) { + syllable_id_[syllables[i]] = i; + } + + prism_.reset(new rime::Prism("corrector_test.prism.bin")); + rime::set keyset; + std::copy(syllables.begin(), syllables.end(), + std::inserter(keyset, keyset.begin())); + prism_->Build(keyset); + } + + virtual void TearDown() { + } + + protected: + rime::map syllable_id_; + rime::the prism_; +}; + +TEST_F(RimeCorrectorSearchTest, CaseNearSubstitute) { + rime::Syllabifier s; + s.EnableCorrection(std::make_shared()); + rime::SyllableGraph g; + const rime::string input("chsng"); + s.BuildSyllableGraph(input, *prism_, &g); + EXPECT_EQ(input.length(), g.input_length); + EXPECT_EQ(input.length(), g.interpreted_length); + EXPECT_EQ(2, g.vertices.size()); + ASSERT_FALSE(g.vertices.end() == g.vertices.find(5)); + rime::SpellingMap& sp(g.edges[0][5]); + EXPECT_EQ(1, sp.size()); + ASSERT_FALSE(sp.end() == sp.find(syllable_id_["chang"])); +} + +TEST_F(RimeCorrectorSearchTest, CaseFarSubstitute) { + rime::Syllabifier s; + s.EnableCorrection(std::make_shared()); + rime::SyllableGraph g; + const rime::string input("chpng"); + s.BuildSyllableGraph(input, *prism_, &g); + EXPECT_EQ(input.length(), g.input_length); + EXPECT_EQ(0, g.interpreted_length); + EXPECT_EQ(1, g.vertices.size()); + ASSERT_TRUE(g.vertices.end() == g.vertices.find(5)); +} + +TEST_F(RimeCorrectorSearchTest, DISABLED_CaseTranspose) { + rime::Syllabifier s; + s.EnableCorrection(std::make_shared()); + rime::SyllableGraph g; + const rime::string input("cahng"); + s.BuildSyllableGraph(input, *prism_, &g); + EXPECT_EQ(input.length(), g.input_length); + EXPECT_EQ(input.length(), g.interpreted_length); + EXPECT_EQ(2, g.vertices.size()); + ASSERT_FALSE(g.vertices.end() == g.vertices.find(5)); + rime::SpellingMap& sp(g.edges[0][5]); + EXPECT_EQ(1, sp.size()); + ASSERT_FALSE(sp.end() == sp.find(syllable_id_["chang"])); +} + +TEST_F(RimeCorrectorSearchTest, CaseCorrectionSyllabify) { + rime::Syllabifier s; + s.EnableCorrection(std::make_shared()); + rime::SyllableGraph g; + const rime::string input("chabgtyan"); + s.BuildSyllableGraph(input, *prism_, &g); + EXPECT_EQ(input.length(), g.input_length); + EXPECT_EQ(input.length(), g.interpreted_length); + EXPECT_EQ(3, g.vertices.size()); + ASSERT_FALSE(g.vertices.end() == g.vertices.find(9)); + rime::SpellingMap& sp1(g.edges[0][5]); + EXPECT_EQ(1, sp1.size()); + ASSERT_FALSE(sp1.end() == sp1.find(syllable_id_["chang"])); + ASSERT_TRUE(sp1[0].is_correction); + rime::SpellingMap& sp2(g.edges[5][9]); + EXPECT_EQ(1, sp2.size()); + ASSERT_FALSE(sp2.end() == sp2.find(syllable_id_["tuan"])); + ASSERT_TRUE(sp2[1].is_correction); +} + +TEST_F(RimeCorrectorTest, CaseMultipleEdges1) { + rime::Syllabifier s; + s.EnableCorrection(std::make_shared()); + rime::SyllableGraph g; + const rime::string input("jiejue"); // jie'jue jie'jie jue'jue jue'jie + s.BuildSyllableGraph(input, *prism_, &g); + EXPECT_EQ(input.length(), g.input_length); + EXPECT_EQ(input.length(), g.interpreted_length); + rime::SpellingMap& sp1(g.edges[0][3]); + EXPECT_EQ(2, sp1.size()); + ASSERT_FALSE(sp1.end() == sp1.find(syllable_id_["jie"])); + ASSERT_TRUE(sp1[syllable_id_["jie"]].type == rime::kNormalSpelling); + ASSERT_FALSE(sp1.end() == sp1.find(syllable_id_["jue"])); + ASSERT_TRUE(sp1[syllable_id_["jue"]].is_correction); + rime::SpellingMap& sp2(g.edges[3][6]); + EXPECT_EQ(2, sp2.size()); + ASSERT_FALSE(sp2.end() == sp2.find(syllable_id_["jie"])); + ASSERT_TRUE(sp2[syllable_id_["jie"]].is_correction); + ASSERT_FALSE(sp2.end() == sp2.find(syllable_id_["jue"])); + ASSERT_TRUE(sp2[syllable_id_["jue"]].type == rime::kNormalSpelling); +}