Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce Signal parameter types in Connect #71952

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/object/callable_method_pointer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ CallableCustom::CompareLessFunc CallableCustomMethodPointerBase::get_compare_les
return compare_less;
}

bool CallableCustomMethodPointerBase::get_method_info(MethodInfo *r_method_info) const {
return false; // Ignore for C++ method pointers
}

uint32_t CallableCustomMethodPointerBase::hash() const {
return h;
}
Expand Down
1 change: 1 addition & 0 deletions core/object/callable_method_pointer.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class CallableCustomMethodPointerBase : public CallableCustom {
#endif
virtual CompareEqualFunc get_compare_equal_func() const;
virtual CompareLessFunc get_compare_less_func() const;
virtual bool get_method_info(MethodInfo *r_method_info) const;

virtual uint32_t hash() const;
};
Expand Down
96 changes: 87 additions & 9 deletions core/object/object.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,52 @@ Variant Object::property_get_revert(const StringName &p_name) const {
return Variant();
}

bool Object::get_method_info(const StringName &p_method, MethodInfo *r_info) const {
if (r_info == nullptr) {
return false;
}
List<MethodInfo> method_list;
int index = 0;

// Check script instance
if (script_instance) {
script_instance->get_method_list(&method_list);
while (index < method_list.size()) {
if (method_list[index].name == p_method) {
// Duplicate method info to return, since method_list only contains references
*r_info = method_list[index];
return true;
}
index++;
}
}
// Check base class
MethodBind *method = ClassDB::get_method(get_class_name(), p_method);
if (method == nullptr) {
return false;
}
// Create MethodInfo from MethodBind.
MethodInfo minfo;
minfo.name = method->get_name();
minfo.id = method->get_method_id();

for (int i = 0; i < method->get_argument_count(); i++) {
minfo.arguments.push_back(method->get_argument_info(i));
}

minfo.return_val = method->get_return_info();
minfo.flags = method->get_hint_flags();

for (int i = 0; i < method->get_argument_count(); i++) {
if (method->has_default_argument(i)) {
minfo.default_arguments.push_back(method->get_default_argument(i));
}
}

*r_info = minfo;
return true;
}

void Object::get_method_list(List<MethodInfo> *p_list) const {
ClassDB::get_method_list(get_class_name(), p_list);
if (script_instance) {
Expand Down Expand Up @@ -955,15 +1001,15 @@ void Object::add_user_signal(const MethodInfo &p_signal) {
ERR_FAIL_COND_MSG(ClassDB::has_signal(get_class_name(), p_signal.name), "User signal's name conflicts with a built-in signal of '" + get_class_name() + "'.");
ERR_FAIL_COND_MSG(signal_map.has(p_signal.name), "Trying to add already existing signal '" + p_signal.name + "'.");
SignalData s;
s.user = p_signal;
s.info = p_signal;
signal_map[p_signal.name] = s;
}

bool Object::_has_user_signal(const StringName &p_name) const {
if (!signal_map.has(p_name)) {
return false;
}
return signal_map[p_name].user.name.length() > 0;
return signal_map[p_name].info.name.length() > 0;
}

struct _ObjectSignalDisconnectData {
Expand Down Expand Up @@ -1179,9 +1225,9 @@ void Object::get_signal_list(List<MethodInfo> *p_signals) const {
//find maybe usersignals?

for (const KeyValue<StringName, SignalData> &E : signal_map) {
if (!E.value.user.name.is_empty()) {
if (!E.value.info.name.is_empty()) {
//user signal
p_signals->push_back(E.value.user);
p_signals->push_back(E.value.info);
}
}
}
Expand Down Expand Up @@ -1237,14 +1283,21 @@ Error Object::connect(const StringName &p_signal, const Callable &p_callable, ui

SignalData *s = signal_map.getptr(p_signal);
if (!s) {
bool signal_is_valid = ClassDB::has_signal(get_class_name(), p_signal);
MethodInfo signal;
bool signal_is_valid = ClassDB::get_signal(get_class_name(), p_signal, &signal);
//check in script
if (!signal_is_valid && !script.is_null()) {
if (Ref<Script>(script)->has_script_signal(p_signal)) {
signal_is_valid = true;
List<MethodInfo> signal_list;
Ref<Script>(script)->get_script_signal_list(&signal_list);
for (int i = 0; i < signal_list.size(); i++) {
if (signal_list[i].name == p_signal) {
signal = signal_list[i];
signal_is_valid = true;
break;
}
}
#ifdef TOOLS_ENABLED
else {
if (!signal_is_valid) {
//allow connecting signals anyway if script is invalid, see issue #17070
if (!Ref<Script>(script)->is_valid()) {
signal_is_valid = true;
Expand All @@ -1255,12 +1308,37 @@ Error Object::connect(const StringName &p_signal, const Callable &p_callable, ui

ERR_FAIL_COND_V_MSG(!signal_is_valid, ERR_INVALID_PARAMETER, "In Object of type '" + String(get_class()) + "': Attempt to connect nonexistent signal '" + p_signal + "' to callable '" + p_callable + "'.");

signal_map[p_signal] = SignalData();
signal_map[p_signal] = SignalData(signal);
s = &signal_map[p_signal];
}

Callable target = p_callable;

MethodInfo target_method_info;
if (target.get_method_info(&target_method_info)) {
int target_required_arguments_count = target_method_info.arguments.size();
int signal_required_argument_count = s->info.arguments.size() - s->info.default_arguments.size();
if (target_required_arguments_count < signal_required_argument_count) {
ERR_FAIL_V_MSG(ERR_INVALID_PARAMETER, vformat("Callable '%s' could not be connected to '%s': '%s' requires at least %s argument(s).", target.operator String(), p_signal, p_signal, signal_required_argument_count));
}
for (int i = 0; i < signal_required_argument_count; i++) {
PropertyInfo target_arg_info = target_method_info.arguments[i];
PropertyInfo signal_arg_info = s->info.arguments[i];
if (Variant::can_convert_strict(target_arg_info.type, signal_arg_info.type) && target_arg_info.class_name == signal_arg_info.class_name) {
continue;
}
String signature = "";
for (int arg = 0; arg < signal_required_argument_count; arg++) {
signature += Variant::get_type_name(s->info.arguments[arg].type);
if (arg < signal_required_argument_count - 1) {
signature += ", ";
}
}
ERR_FAIL_V_MSG(ERR_INVALID_PARAMETER, vformat("Callable '%s' could not be connected to '%s': '%s' does not have the required arguments '%s'", target.operator String(), p_signal, target.get_method(), signature));
}
}


//compare with the base callable, so binds can be ignored
if (s->slot_map.has(*target.get_base_comparator())) {
if (p_flags & CONNECT_REFERENCE_COUNTED) {
Expand Down
11 changes: 10 additions & 1 deletion core/object/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -578,8 +578,16 @@ class Object {
List<Connection>::Element *cE = nullptr;
};

MethodInfo user;
MethodInfo info;
VMap<Callable, Slot> slot_map;

SignalData() {

}

SignalData(MethodInfo p_info) {
info = p_info;
}
};

HashMap<StringName, SignalData> signal_map;
Expand Down Expand Up @@ -799,6 +807,7 @@ class Object {
Variant property_get_revert(const StringName &p_name) const;

bool has_method(const StringName &p_method) const;
bool get_method_info(const StringName &p_method, MethodInfo *r_info) const;
void get_method_list(List<MethodInfo> *p_list) const;
Variant callv(const StringName &p_method, const Array &p_args);
virtual Variant callp(const StringName &p_method, const Variant **p_args, int p_argcount, Callable::CallError &r_error);
Expand Down
22 changes: 22 additions & 0 deletions core/variant/callable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,24 @@ StringName Callable::get_method() const {
return method;
}

bool Callable::get_method_info(MethodInfo *r_method_info) const {
if (r_method_info == nullptr) {
return false;
}
if (is_null()) {
return false;
}
if (is_custom()) {
return custom->get_method_info(r_method_info);
}
MethodInfo method_info;
if (!get_object()->get_method_info(method, &method_info)) {
return false;
}
*r_method_info = method_info;
return true;
}

int Callable::get_bound_arguments_count() const {
if (!is_null() && is_custom()) {
return custom->get_bound_arguments_count();
Expand Down Expand Up @@ -388,6 +406,10 @@ const Callable *CallableCustom::get_base_comparator() const {
return nullptr;
}

bool CallableCustom::get_method_info(MethodInfo *r_method_info) const {
return false;
}

int CallableCustom::get_bound_arguments_count() const {
return 0;
}
Expand Down
3 changes: 3 additions & 0 deletions core/variant/callable.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "core/templates/list.h"

class Object;
struct MethodInfo;
class Variant;
class CallableCustom;

Expand Down Expand Up @@ -106,6 +107,7 @@ class Callable {
Object *get_object() const;
ObjectID get_object_id() const;
StringName get_method() const;
bool get_method_info(MethodInfo *r_method_info) const;
CallableCustom *get_custom() const;
int get_bound_arguments_count() const;
void get_bound_arguments_ref(Vector<Variant> &r_arguments, int &r_argcount) const; // Internal engine use, the exposed one is below.
Expand Down Expand Up @@ -150,6 +152,7 @@ class CallableCustom {
virtual void call(const Variant **p_arguments, int p_argcount, Variant &r_return_value, Callable::CallError &r_call_error) const = 0;
virtual Error rpc(int p_peer_id, const Variant **p_arguments, int p_argcount, Callable::CallError &r_call_error) const;
virtual const Callable *get_base_comparator() const;
virtual bool get_method_info(MethodInfo *r_method_info) const;
virtual int get_bound_arguments_count() const;
virtual void get_bound_arguments(Vector<Variant> &r_arguments, int &r_argcount) const;

Expand Down
8 changes: 8 additions & 0 deletions core/variant/callable_bind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ const Callable *CallableCustomBind::get_base_comparator() const {
return &callable;
}

bool CallableCustomBind::get_method_info(MethodInfo *r_method_info) const {
return callable.get_method_info(r_method_info);
}

int CallableCustomBind::get_bound_arguments_count() const {
return callable.get_bound_arguments_count() + binds.size();
}
Expand Down Expand Up @@ -205,6 +209,10 @@ const Callable *CallableCustomUnbind::get_base_comparator() const {
return &callable;
}

bool CallableCustomUnbind::get_method_info(MethodInfo *r_method_info) const {
return callable.get_method_info(r_method_info);
}

int CallableCustomUnbind::get_bound_arguments_count() const {
return callable.get_bound_arguments_count() - argcount;
}
Expand Down
2 changes: 2 additions & 0 deletions core/variant/callable_bind.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class CallableCustomBind : public CallableCustom {
virtual ObjectID get_object() const override; //must always be able to provide an object
virtual void call(const Variant **p_arguments, int p_argcount, Variant &r_return_value, Callable::CallError &r_call_error) const override;
virtual const Callable *get_base_comparator() const override;
virtual bool get_method_info(MethodInfo *r_method_info) const override;
virtual int get_bound_arguments_count() const override;
virtual void get_bound_arguments(Vector<Variant> &r_arguments, int &r_argcount) const override;
Callable get_callable() { return callable; }
Expand All @@ -77,6 +78,7 @@ class CallableCustomUnbind : public CallableCustom {
virtual ObjectID get_object() const override; //must always be able to provide an object
virtual void call(const Variant **p_arguments, int p_argcount, Variant &r_return_value, Callable::CallError &r_call_error) const override;
virtual const Callable *get_base_comparator() const override;
virtual bool get_method_info(MethodInfo *r_method_info) const override;
virtual int get_bound_arguments_count() const override;
virtual void get_bound_arguments(Vector<Variant> &r_arguments, int &r_argcount) const override;

Expand Down
28 changes: 11 additions & 17 deletions modules/gdscript/gdscript.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -715,12 +715,7 @@ bool GDScript::_update_exports(bool *r_err, bool p_recursive_call, PlaceHolderSc
} break;
case GDScriptParser::ClassNode::Member::SIGNAL: {
// TODO: Cache this in parser to avoid loops like this.
Vector<StringName> parameters_names;
parameters_names.resize(member.signal->parameters.size());
for (int j = 0; j < member.signal->parameters.size(); j++) {
parameters_names.write[j] = member.signal->parameters[j]->identifier->name;
}
_signals[member.signal->identifier->name] = parameters_names;
_signals[member.signal->identifier->name] = member.signal->datatype.method_info;
} break;
case GDScriptParser::ClassNode::Member::GROUP: {
members_cache.push_back(member.annotation->export_info);
Expand Down Expand Up @@ -1246,15 +1241,8 @@ bool GDScript::has_script_signal(const StringName &p_signal) const {
}

void GDScript::_get_script_signal_list(List<MethodInfo> *r_list, bool p_include_base) const {
for (const KeyValue<StringName, Vector<StringName>> &E : _signals) {
MethodInfo mi;
mi.name = E.key;
for (int i = 0; i < E.value.size(); i++) {
PropertyInfo arg;
arg.name = E.value[i];
mi.arguments.push_back(arg);
}
r_list->push_back(mi);
for (const KeyValue<StringName, MethodInfo> &E : _signals) {
r_list->push_back(E.value);
}

if (!p_include_base) {
Expand Down Expand Up @@ -1616,7 +1604,7 @@ bool GDScriptInstance::get(const StringName &p_name, Variant &r_ret) const {
// Signals.
const GDScript *sl = sptr;
while (sl) {
HashMap<StringName, Vector<StringName>>::ConstIterator E = sl->_signals.find(p_name);
HashMap<StringName, MethodInfo>::ConstIterator E = sl->_signals.find(p_name);
if (E) {
r_ret = Signal(this->owner, E->key);
return true; //index found
Expand Down Expand Up @@ -1801,7 +1789,13 @@ void GDScriptInstance::get_method_list(List<MethodInfo> *p_list) const {
MethodInfo mi;
mi.name = E.key;
for (int i = 0; i < E.value->get_argument_count(); i++) {
mi.arguments.push_back(PropertyInfo(Variant::NIL, "arg" + itos(i)));
PropertyInfo arg = E.value->get_argument_type(i);
#ifdef TOOLS_ENABLED
arg.name = E.value->get_argument_name(i);
#else
arg.name = "arg" + itos(i);
#endif
mi.arguments.push_back(arg);
}
p_list->push_back(mi);
}
Expand Down
2 changes: 1 addition & 1 deletion modules/gdscript/gdscript.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class GDScript : public Script {
HashMap<StringName, GDScriptFunction *> member_functions;
HashMap<StringName, MemberInfo> member_indices; //members are just indices to the instantiated script.
HashMap<StringName, Ref<GDScript>> subclasses;
HashMap<StringName, Vector<StringName>> _signals;
HashMap<StringName, MethodInfo> _signals;
Dictionary rpc_config;

#ifdef TOOLS_ENABLED
Expand Down
Loading