diff --git a/thrift/compiler/sema/patch_mutator.cc b/thrift/compiler/sema/patch_mutator.cc index cfc1c04231a..cf57951bb01 100644 --- a/thrift/compiler/sema/patch_mutator.cc +++ b/thrift/compiler/sema/patch_mutator.cc @@ -80,6 +80,10 @@ const char* getOptionalPatchTypeName(t_base_type::type base_type) { } } +std::string getSibName(const std::string& sibling, const std::string& name) { + return sibling.substr(0, sibling.find_last_of("/")) + name; +} + // A fluent function to set the doc string on a given node. template N& doc(std::string txt, N& node) { @@ -377,18 +381,54 @@ t_type_ref patch_generator::find_patch_type( os << "Could not find expected patch type: " << name; }); return {}; - } else if (auto* structured = dynamic_cast(type)) { - // Try to find the generated patch type. + } + + // Check the field for a custom patch type. + if (auto* custom = field.find_annotation_or_null("thrift.patch.uri")) { + if (const auto* result = + dynamic_cast(program_.scope()->find_def(*custom))) { + return t_type_ref::from_ptr(result); + } + ctx_.warning(field, "Could not find custom type: {}", *custom); + } + + // Check the field type for a custom patch type. + if (auto* custom = t_typedef::get_first_annotation_or_null( + field.type().get_type(), {"thrift.patch.uri"})) { + if (const auto* result = + dynamic_cast(program_.scope()->find_def(*custom))) { + return t_type_ref::from_ptr(result); + } + ctx_.warning(*field.type(), "Could not find custom type: {}", *custom); + } + + if (auto* structured = dynamic_cast(type)) { std::string name = structured->name() + "ValuePatch"; if (field.qualifier() == t_field_qualifier::optional) { name = "Optional" + std::move(name); } - // It should be in the same program as the type itself. + + if (!structured->uri().empty()) { // Try to look up by URI. + if (auto* result = dynamic_cast(program_.scope()->find_def( + getSibName(structured->uri(), name)))) { + return t_type_ref::from_ptr(result); + } + } + + // Try to look up by Name. + // Look for it in the same program as the type itself. t_type_ref result = program_.scope()->ref_type( *structured->program(), name, field.src_range()); if (auto* ph = result.get_unresolved_type()) { + // Try the current program. + t_type_ref fallback = + program_.scope()->ref_type(program_, name, field.src_range()); + if (fallback.resolved()) { + return fallback; // TODO(afuller): Remove support for local fallbacks. + } // Set the location info, in case the type can't be resolved later. ph->set_lineno(field.lineno()); + ph->set_src_range(field.src_range()); ph->set_generated(); } return result;