Skip to content

Commit

Permalink
Use static registration for cross-package validation (#88)
Browse files Browse the repository at this point in the history
When compiling with MSVC, there is not an equivalent to__attribute__(weak).
In order for cross-package validation to work with MSVC, use static
registration instead

Signed-off-by: Arjun Sreedharan <[email protected]>
Signed-off-by: Sam Smith <[email protected]>
  • Loading branch information
sesmith177 authored and rodaine committed Aug 6, 2018
1 parent 10eb9db commit 05ff6bc
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 61 deletions.
12 changes: 3 additions & 9 deletions templates/cc/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,10 @@ using std::string;
{{ range .AllMessages }}
{{- if not (disabled .) -}}
bool CheckMessage(const {{ class . }}& m, pgv::ValidationMsg* err) {
return {{ package . }}::Validate(m, err);
}
{{ end }}
{{ end }}
{{ range (weakCheckMsgs .AllMessages) }}
bool __attribute__((weak)) CheckMessage(const {{ . }}& m, pgv::ValidationMsg* err) {
return true;
}
pgv::Validator<{{ class . }}> {{ staticVarName . }}(static_cast<bool(*)(const {{ class .}}&, pgv::ValidationMsg*)>({{ package .}}::Validate));
{{ end }}
{{ end }}
} // namespace validate
Expand Down
3 changes: 1 addition & 2 deletions templates/cc/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@ package tpl
const messageTpl = `
{{ $f := .Field }}{{ $r := .Rules }}
{{ template "required" . }}
{{ if $r.GetSkip }}
// skipping validation for {{ $f.Name }}
{{ else }}
{
pgv::ValidationMsg inner_err;
if ({{ hasAccessor .}} && !pgv::validate::CheckMessage({{ accessor . }}, &inner_err)) {
if ({{ hasAccessor .}} && !pgv::Validator<{{ ctype $f.Type }}>::CheckMessage({{ accessor . }}, &inner_err)) {
{{ errCause . "inner_err" "embedded message failed validation" }}
}
}
Expand Down
48 changes: 8 additions & 40 deletions templates/cc/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ func RegisterModule(tpl *template.Template) {
"accessor": accessor,
"hasAccessor": hasAccessor,
"ctype": cType,
"weakCheckMsgs": weakCheckMsgs,
"err": err,
"errCause": errCause,
"errIdx": errIdx,
Expand All @@ -43,6 +42,7 @@ func RegisterModule(tpl *template.Template) {
"tsStr": tsStr,
"unwrap": unwrap,
"unimplemented": failUnimplemented,
"staticVarName": staticVarName,
})
template.Must(tpl.Parse(moduleFileTpl))
template.Must(tpl.New("msg").Parse(msgTpl))
Expand Down Expand Up @@ -86,8 +86,8 @@ func RegisterModule(tpl *template.Template) {

func RegisterHeader(tpl *template.Template) {
tpl.Funcs(map[string]interface{}{
"class": className,
"upper": strings.ToUpper,
"class": className,
"upper": strings.ToUpper,
})

template.Must(tpl.Parse(headerFileTpl))
Expand Down Expand Up @@ -274,7 +274,7 @@ func cType(t pgs.FieldType) string {
if t.IsEmbed() {
return className(t.Embed())
}
if t.IsRepeated(){
if t.IsRepeated() {
if t.ProtoType() == pgs.MessageT {
return className(t.Element().Embed())
}
Expand All @@ -285,42 +285,6 @@ func cType(t pgs.FieldType) string {
return cTypeOfString(t.Name().String())
}

// Compute unique C++ types that correspond to all message fields in a
// compilation unit that need to be weak (i.e. not already defined). Used to
// generate weak default definitions for CheckMessage.
func weakCheckMsgs(msgs []pgs.Message) []string {
already_defined := map[string]bool{}
// First compute the C++ type names for things we're going to provide an explicit
// CheckMessage() with Validate(..) body in this file. We can't define the
// same CheckMessage() signature twice in a compilation unit, even if one of
// them is weak.
for _, msg := range msgs {
already_defined[className(msg)] = true
}
// Compute the set of C++ type names we need weak definitions for.
ctype_map := map[string]bool{}
for _, msg := range msgs {
if disabled, _ := shared.Disabled(msg); disabled {
continue
}
for _, f := range msg.Fields() {
ctype := cType(f.Type())
if already_defined[ctype] {
continue
}
if f.Type().IsEmbed() || (f.Type().IsRepeated() && f.Type().Element().IsEmbed()) {
ctype_map[ctype] = true
}
}
}
// Convert to array.
ctypes := make([]string, 0, len(ctype_map))
for ctype := range ctype_map {
ctypes = append(ctypes, ctype)
}
return ctypes
}

func cTypeOfString(s string) string {
switch s {
case "float32":
Expand Down Expand Up @@ -412,3 +376,7 @@ func unwrap(ctx shared.RuleContext, name string) (shared.RuleContext, error) {
func failUnimplemented() string {
return "throw pgv::UnimplementedException();"
}

func staticVarName(msg pgs.Message) string {
return "validator_" + strings.Replace(className(msg), ":", "_", -1)
}
49 changes: 39 additions & 10 deletions validate/validate.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <stdexcept>
#include <string>
#include <typeinfo>

#include <google/protobuf/message.h>
#include <google/protobuf/util/time_util.h>
Expand All @@ -14,30 +15,58 @@ namespace protobuf = google::protobuf;
namespace protobuf_wkt = google::protobuf;

class UnimplementedException : public std::runtime_error {
public:
public:
UnimplementedException() : std::runtime_error("not yet implemented") {}
// Thrown by C++ validation code that is not yet implemented.
};

using ValidationMsg = std::string;

static inline std::string String(const ValidationMsg& msg) {
class BaseValidator {
protected:
static std::map<size_t, BaseValidator*> validators;
};

std::map<size_t, BaseValidator*> __attribute__((weak)) BaseValidator::validators;

template <typename T>
class Validator : public BaseValidator {
public:
Validator(std::function<bool(const T&, ValidationMsg*)> check) : check_(check)
{
validators[typeid(T).hash_code()] = this;
}

static bool CheckMessage(const T& m, ValidationMsg* err)
{
auto val = static_cast<Validator<T>*>(validators[typeid(T).hash_code()]);
if (val) {
return val->check_(m, err);
}
return true;
}

private:
std::function<bool(const T&, ValidationMsg*)> check_;
};

static inline std::string String(const ValidationMsg& msg)
{
return std::string(msg);
}

static inline bool IsPrefix(const string& maybe_prefix,
const string& search_in) {
static inline bool IsPrefix(const string& maybe_prefix, const string& search_in)
{
return search_in.compare(0, maybe_prefix.size(), maybe_prefix) == 0;
}

static inline bool IsSuffix(const string& maybe_suffix,
const string& search_in) {
return maybe_suffix.size() <= search_in.size() &&
search_in.compare(search_in.size() - maybe_suffix.size(),
maybe_suffix.size(), maybe_suffix) == 0;
static inline bool IsSuffix(const string& maybe_suffix, const string& search_in)
{
return maybe_suffix.size() <= search_in.size() && search_in.compare(search_in.size() - maybe_suffix.size(), maybe_suffix.size(), maybe_suffix) == 0;
}

static inline bool Contains(const string& search_in, const string& to_find) {
static inline bool Contains(const string& search_in, const string& to_find)
{
return search_in.find(to_find) != string::npos;
}

Expand Down

0 comments on commit 05ff6bc

Please sign in to comment.