Skip to content

Commit

Permalink
[generator] Refactor dependency management
Browse files Browse the repository at this point in the history
No functional changes intended.
  • Loading branch information
Yannic committed Jul 5, 2020
1 parent 3cf33f9 commit e62abdb
Showing 1 changed file with 108 additions and 141 deletions.
249 changes: 108 additions & 141 deletions javascript/net/grpc/web/grpc_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include <algorithm>
#include <iterator>
#include <set>
#include <string>

using google::protobuf::Descriptor;
Expand Down Expand Up @@ -316,70 +317,73 @@ string ModuleAlias(const string& filename) {
}

string JSMessageType(const Descriptor *desc, const FileDescriptor *file) {
string module_prefix;
if (desc->file() != file) {
module_prefix = ModuleAlias(desc->file()->name()) + ".";
}
string class_name;
class_name = StripPrefixString(desc->full_name(), desc->file()->package());
if (!class_name.empty() && class_name[0] == '.') {
class_name = class_name.substr(1);
}
return module_prefix + class_name;
if (desc->file() == file) {
// [for protobuf .d.ts files only] Do not add the module prefix for local
// messages.
return class_name;
}
return ModuleAlias(desc->file()->name()) + "." + class_name;
}

string JSElementType(const FieldDescriptor *desc, const FileDescriptor *file)
{
string js_field_type;
switch (desc->type())
{
case FieldDescriptor::TYPE_DOUBLE:
case FieldDescriptor::TYPE_FLOAT:
case FieldDescriptor::TYPE_INT32:
case FieldDescriptor::TYPE_UINT32:
case FieldDescriptor::TYPE_SINT32:
case FieldDescriptor::TYPE_FIXED32:
case FieldDescriptor::TYPE_SFIXED32:
js_field_type = "number";
break;
case FieldDescriptor::TYPE_INT64:
case FieldDescriptor::TYPE_UINT64:
case FieldDescriptor::TYPE_SINT64:
case FieldDescriptor::TYPE_FIXED64:
case FieldDescriptor::TYPE_SFIXED64:
if (desc->options().jstype() == FieldOptions::JS_STRING) {
js_field_type = "string";
} else {
js_field_type = "number";
}
break;
case FieldDescriptor::TYPE_BOOL:
js_field_type = "boolean";
break;
case FieldDescriptor::TYPE_STRING:
js_field_type = "string";
break;
case FieldDescriptor::TYPE_BYTES:
js_field_type = "Uint8Array | string";
break;
case FieldDescriptor::TYPE_ENUM:
if (desc->enum_type()->file() != file) {
js_field_type = ModuleAlias(desc->enum_type()->file()->name());
}
js_field_type += StripPrefixString(desc->enum_type()->full_name(),
desc->enum_type()->file()->package());
if (!js_field_type.empty() && js_field_type[0] == '.') {
js_field_type = js_field_type.substr(1);
}
break;
case FieldDescriptor::TYPE_MESSAGE:
js_field_type = JSMessageType(desc->message_type(), file);
break;
default:
js_field_type = "{}";
break;
string JSMessageType(const Descriptor *desc) {
return JSMessageType(desc, nullptr);
}

string JSElementType(const FieldDescriptor *desc, const FileDescriptor *file) {
switch (desc->type()) {
case FieldDescriptor::TYPE_DOUBLE:
case FieldDescriptor::TYPE_FLOAT:
case FieldDescriptor::TYPE_INT32:
case FieldDescriptor::TYPE_UINT32:
case FieldDescriptor::TYPE_SINT32:
case FieldDescriptor::TYPE_FIXED32:
case FieldDescriptor::TYPE_SFIXED32:
return "number";

case FieldDescriptor::TYPE_INT64:
case FieldDescriptor::TYPE_UINT64:
case FieldDescriptor::TYPE_SINT64:
case FieldDescriptor::TYPE_FIXED64:
case FieldDescriptor::TYPE_SFIXED64:
if (desc->options().jstype() == FieldOptions::JS_STRING) {
return "string";
} else {
return "number";
}

case FieldDescriptor::TYPE_BOOL:
return "boolean";

case FieldDescriptor::TYPE_STRING:
return "string";

case FieldDescriptor::TYPE_BYTES:
return "Uint8Array | string";

case FieldDescriptor::TYPE_ENUM:
if (desc->enum_type()->file() == file) {
string enum_name =
StripPrefixString(
desc->enum_type()->full_name(),
desc->enum_type()->file()->package());
return enum_name.substr(1);
}
return ModuleAlias(desc->enum_type()->file()->name())
+ StripPrefixString(
desc->enum_type()->full_name(),
desc->enum_type()->file()->package());

case FieldDescriptor::TYPE_MESSAGE:
return JSMessageType(desc->message_type(), file);

default:
return "{}";
}
return js_field_type;
}

string JSFieldType(const FieldDescriptor *desc, const FileDescriptor *file) {
Expand All @@ -396,8 +400,8 @@ string JSFieldType(const FieldDescriptor *desc, const FileDescriptor *file) {
return js_field_type;
}

string AsObjectFieldType(const FieldDescriptor *desc,
const FileDescriptor *file) {
string AsObjectFieldType(
const FieldDescriptor *desc, const FileDescriptor *file) {
if (desc->type() != FieldDescriptor::TYPE_MESSAGE) {
return JSFieldType(desc, file);
}
Expand Down Expand Up @@ -523,35 +527,26 @@ string GetBasename(string filename) {
return basename;
}

/* Finds all message types used in all services in the file, and returns them
* as a map of fully qualified message type name to message descriptor */
std::map<string, const Descriptor*> GetAllMessages(const FileDescriptor* file) {
std::map<string, const Descriptor*> message_types;
for (int service_index = 0;
service_index < file->service_count();
++service_index) {
const ServiceDescriptor* service = file->service(service_index);
for (int method_index = 0;
method_index < service->method_count();
++method_index) {
const MethodDescriptor *method = service->method(method_index);
message_types[method->input_type()->full_name()] = method->input_type();
message_types[method->output_type()->full_name()] = method->output_type();
// Finds all message types used in all services in the file.
std::set<const Descriptor*> GetAllMessages(const FileDescriptor* file) {
std::set<const Descriptor*> messages;
for (int s = 0; s < file->service_count(); ++s) {
const ServiceDescriptor* service = file->service(s);
for (int m = 0; m < service->method_count(); ++m) {
const MethodDescriptor *method = service->method(m);
messages.insert(method->input_type());
messages.insert(method->output_type());
}
}

return message_types;
return messages;
}

void PrintMessagesDeps(Printer* printer, const FileDescriptor* file) {
std::map<string, const Descriptor*> messages = GetAllMessages(file);
std::map<string, string> vars;
for (std::map<string, const Descriptor*>::iterator it = messages.begin();
it != messages.end(); it++) {
vars["full_name"] = it->first;
void PrintClosureDependencies(Printer* printer, const FileDescriptor* file) {
for (const Descriptor* message : GetAllMessages(file)) {
printer->Print(
vars,
"goog.require('proto.$full_name$');\n");
"goog.require('proto.$full_name$');\n",
"full_name", message->full_name());
}
printer->Print("\n\n\n");
}
Expand Down Expand Up @@ -601,64 +596,26 @@ void PrintCommonJsMessagesDeps(Printer* printer, const FileDescriptor* file) {
}
}

void PrintES6Dependencies(Printer* printer, const FileDescriptor *file) {
std::map<string, string> vars;

for (int i = 0; i < file->dependency_count(); i++) {
const string& name = file->dependency(i)->name();
vars["alias"] = ModuleAlias(name);
vars["dep_filename"] = GetRootPath(file->name(), name) + StripProto(name);
// we need to give each cross-file import an alias
printer->Print(
vars,
"import * as $alias$ from '$dep_filename$_pb';\n");
}

if (file->dependency_count() != 0) {
printer->Print("\n");
}
}

void PrintES6Imports(Printer* printer, const FileDescriptor* file) {
std::map<string, string> vars;

printer->Print("import * as grpcWeb from 'grpc-web';\n\n");
PrintES6Dependencies(printer, file);

std::map<string, const Descriptor*> messages = GetAllMessages(file);
for (std::map<string, const Descriptor*>::iterator it = messages.begin();
it != messages.end();) {
if (it->second->file() != file) {
it = messages.erase(it);
} else {
it++;
std::set<string> imports;
for (const Descriptor* message : GetAllMessages(file)) {
const string& name = message->file()->name();
string dep_filename = GetRootPath(file->name(), name) + StripProto(name);
if (imports.find(dep_filename) != imports.end()) {
continue;
}
imports.insert(dep_filename);
// We need to give each cross-file import an alias.
printer->Print(
"import * as $alias$ from '$dep_filename$_pb';\n",
"alias", ModuleAlias(name),
"dep_filename", dep_filename);
}

if (messages.empty()) {
return;
}

std::map<string, const Descriptor*>::iterator it = messages.begin();
vars["base_name"] = GetBasename(StripProto(file->name()));
vars["class_name"] = it->second->name();

if (messages.size() == 1) {
printer->Print(vars, "import {$class_name$} from './$base_name$_pb';\n\n");
return;
}

printer->Print("import {\n");
printer->Indent();
printer->Print(vars, "$class_name$");

for (it++; it != messages.end(); it++) {
vars["class_name"] = it->second->name();
printer->Print(vars, ",\n$class_name$");
}

printer->Outdent();
printer->Print(vars, "} from './$base_name$_pb';\n\n");
printer->Print("\n\n");
}

void PrintTypescriptFile(Printer* printer, const FileDescriptor* file,
Expand Down Expand Up @@ -698,8 +655,8 @@ void PrintTypescriptFile(Printer* printer, const FileDescriptor* file,
const MethodDescriptor* method = service->method(method_index);
vars["js_method_name"] = LowercaseFirstLetter(method->name());
vars["method_name"] = method->name();
vars["input_type"] = JSMessageType(method->input_type(), file);
vars["output_type"] = JSMessageType(method->output_type(), file);
vars["input_type"] = JSMessageType(method->input_type());
vars["output_type"] = JSMessageType(method->output_type());
vars["serialize_func_name"] = GetSerializeMethodName(vars["mode"]);
vars["deserialize_func_name"] = GetDeserializeMethodName(vars["mode"]);
if (!method->client_streaming()) {
Expand Down Expand Up @@ -813,8 +770,8 @@ void PrintGrpcWebDtsClientClass(Printer* printer, const FileDescriptor* file,
++method_index) {
const MethodDescriptor* method = service->method(method_index);
vars["js_method_name"] = LowercaseFirstLetter(method->name());
vars["input_type"] = JSMessageType(method->input_type(), file);
vars["output_type"] = JSMessageType(method->output_type(), file);
vars["input_type"] = JSMessageType(method->input_type());
vars["output_type"] = JSMessageType(method->output_type());
if (!method->client_streaming()) {
if (method->server_streaming()) {
printer->Print(vars, "$js_method_name$(\n");
Expand Down Expand Up @@ -1027,8 +984,17 @@ void PrintProtoDtsMessage(Printer *printer, const Descriptor *desc,

void PrintProtoDtsFile(Printer *printer, const FileDescriptor *file)
{
printer->Print("import * as jspb from \"google-protobuf\"\n\n");
PrintES6Dependencies(printer, file);
printer->Print("import * as jspb from 'google-protobuf'\n\n");

for (int i = 0; i < file->dependency_count(); i++) {
const string& name = file->dependency(i)->name();
// We need to give each cross-file import an alias.
printer->Print(
"import * as $alias$ from '$dep_filename$_pb';\n",
"alias", ModuleAlias(name),
"dep_filename", GetRootPath(file->name(), name) + StripProto(name));
}
printer->Print("\n\n");

for (int i = 0; i < file->message_type_count(); i++) {
PrintProtoDtsMessage(printer, file->message_type(i), file);
Expand Down Expand Up @@ -1431,8 +1397,8 @@ void PrintMultipleFilesMode(const FileDescriptor* file, string file_name,
printer2.Print(vars, "goog.require('grpc.web.ClientReadableStream');\n");
printer2.Print(vars, "goog.require('grpc.web.Error');\n");

PrintMessagesDeps(&printer1, file);
PrintMessagesDeps(&printer2, file);
PrintClosureDependencies(&printer1, file);
PrintClosureDependencies(&printer2, file);

printer1.Print("goog.scope(function() {\n\n");
printer2.Print("goog.scope(function() {\n\n");
Expand Down Expand Up @@ -1724,7 +1690,8 @@ class GrpcCodeGenerator : public CodeGenerator {
printer.Print(vars, "goog.require('grpc.web.AbstractClientBase');\n");
printer.Print(vars, "goog.require('grpc.web.ClientReadableStream');\n");
printer.Print(vars, "goog.require('grpc.web.Error');\n");
PrintMessagesDeps(&printer, file);

PrintClosureDependencies(&printer, file);

printer.Print("goog.scope(function() {\n\n");
break;
Expand Down

0 comments on commit e62abdb

Please sign in to comment.