Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bounded generics #499

Merged
merged 17 commits into from
Mar 7, 2022
1 change: 1 addition & 0 deletions lib/steep.rb
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

require "steep/range_extension"

require "steep/interface/type_param"
require "steep/interface/function"
require "steep/interface/block"
require "steep/interface/method_type"
Expand Down
287 changes: 149 additions & 138 deletions lib/steep/ast/types/factory.rb
Original file line number Diff line number Diff line change
Expand Up @@ -21,147 +21,154 @@ def type_name_resolver
@type_name_resolver ||= RBS::TypeNameResolver.from_env(definition_builder.env)
end

def type_opt(type)
if type
type(type)
end
end

def type(type)
ty = type_cache[type] and return ty

type_cache[type] = case type
when RBS::Types::Bases::Any
Any.new(location: nil)
when RBS::Types::Bases::Class
Class.new(location: nil)
when RBS::Types::Bases::Instance
Instance.new(location: nil)
when RBS::Types::Bases::Self
Self.new(location: nil)
when RBS::Types::Bases::Top
Top.new(location: nil)
when RBS::Types::Bases::Bottom
Bot.new(location: nil)
when RBS::Types::Bases::Bool
Boolean.new(location: nil)
when RBS::Types::Bases::Void
Void.new(location: nil)
when RBS::Types::Bases::Nil
Nil.new(location: nil)
when RBS::Types::Variable
Var.new(name: type.name, location: nil)
when RBS::Types::ClassSingleton
type_name = type.name
Name::Singleton.new(name: type_name, location: nil)
when RBS::Types::ClassInstance
type_name = type.name
args = type.args.map {|arg| type(arg) }
Name::Instance.new(name: type_name, args: args, location: nil)
when RBS::Types::Interface
type_name = type.name
args = type.args.map {|arg| type(arg) }
Name::Interface.new(name: type_name, args: args, location: nil)
when RBS::Types::Alias
type_name = type.name
args = type.args.map {|arg| type(arg) }
Name::Alias.new(name: type_name, args: args, location: nil)
when RBS::Types::Union
Union.build(types: type.types.map {|ty| type(ty) }, location: nil)
when RBS::Types::Intersection
Intersection.build(types: type.types.map {|ty| type(ty) }, location: nil)
when RBS::Types::Optional
Union.build(types: [type(type.type), Nil.new(location: nil)], location: nil)
when RBS::Types::Literal
Literal.new(value: type.literal, location: nil)
when RBS::Types::Tuple
Tuple.new(types: type.types.map {|ty| type(ty) }, location: nil)
when RBS::Types::Record
elements = type.fields.each.with_object({}) do |(key, value), hash|
hash[key] = type(value)
end
Record.new(elements: elements, location: nil)
when RBS::Types::Proc
func = Interface::Function.new(
params: params(type.type),
return_type: type(type.type.return_type),
location: type.location
)
block = if type.block
Interface::Block.new(
type: Interface::Function.new(
params: params(type.block.type),
return_type: type(type.block.type.return_type),
location: type.location
),
optional: !type.block.required
)
end
type_cache[type] =
case type
when RBS::Types::Bases::Any
Any.new(location: type.location)
when RBS::Types::Bases::Class
Class.new(location: type.location)
when RBS::Types::Bases::Instance
Instance.new(location: type.location)
when RBS::Types::Bases::Self
Self.new(location: type.location)
when RBS::Types::Bases::Top
Top.new(location: type.location)
when RBS::Types::Bases::Bottom
Bot.new(location: type.location)
when RBS::Types::Bases::Bool
Boolean.new(location: type.location)
when RBS::Types::Bases::Void
Void.new(location: type.location)
when RBS::Types::Bases::Nil
Nil.new(location: type.location)
when RBS::Types::Variable
Var.new(name: type.name, location: type.location)
when RBS::Types::ClassSingleton
type_name = type.name
Name::Singleton.new(name: type_name, location: type.location)
when RBS::Types::ClassInstance
type_name = type.name
args = type.args.map {|arg| type(arg) }
Name::Instance.new(name: type_name, args: args, location: type.location)
when RBS::Types::Interface
type_name = type.name
args = type.args.map {|arg| type(arg) }
Name::Interface.new(name: type_name, args: args, location: type.location)
when RBS::Types::Alias
type_name = type.name
args = type.args.map {|arg| type(arg) }
Name::Alias.new(name: type_name, args: args, location: type.location)
when RBS::Types::Union
Union.build(types: type.types.map {|ty| type(ty) }, location: type.location)
when RBS::Types::Intersection
Intersection.build(types: type.types.map {|ty| type(ty) }, location: type.location)
when RBS::Types::Optional
Union.build(types: [type(type.type), Nil.new(location: nil)], location: type.location)
when RBS::Types::Literal
Literal.new(value: type.literal, location: type.location)
when RBS::Types::Tuple
Tuple.new(types: type.types.map {|ty| type(ty) }, location: type.location)
when RBS::Types::Record
elements = type.fields.each.with_object({}) do |(key, value), hash|
hash[key] = type(value)
end
Record.new(elements: elements, location: type.location)
when RBS::Types::Proc
func = Interface::Function.new(
params: params(type.type),
return_type: type(type.type.return_type),
location: type.location
)
block = if type.block
Interface::Block.new(
type: Interface::Function.new(
params: params(type.block.type),
return_type: type(type.block.type.return_type),
location: type.location
),
optional: !type.block.required
)
end

Proc.new(type: func, block: block)
else
raise "Unexpected type given: #{type}"
end
Proc.new(type: func, block: block)
else
raise "Unexpected type given: #{type}"
end
end

def type_1(type)
case type
when Any
RBS::Types::Bases::Any.new(location: nil)
RBS::Types::Bases::Any.new(location: type.location)
when Class
RBS::Types::Bases::Class.new(location: nil)
RBS::Types::Bases::Class.new(location: type.location)
when Instance
RBS::Types::Bases::Instance.new(location: nil)
RBS::Types::Bases::Instance.new(location: type.location)
when Self
RBS::Types::Bases::Self.new(location: nil)
RBS::Types::Bases::Self.new(location: type.location)
when Top
RBS::Types::Bases::Top.new(location: nil)
RBS::Types::Bases::Top.new(location: type.location)
when Bot
RBS::Types::Bases::Bottom.new(location: nil)
RBS::Types::Bases::Bottom.new(location: type.location)
when Boolean
RBS::Types::Bases::Bool.new(location: nil)
RBS::Types::Bases::Bool.new(location: type.location)
when Void
RBS::Types::Bases::Void.new(location: nil)
RBS::Types::Bases::Void.new(location: type.location)
when Nil
RBS::Types::Bases::Nil.new(location: nil)
RBS::Types::Bases::Nil.new(location: type.location)
when Var
RBS::Types::Variable.new(name: type.name, location: nil)
RBS::Types::Variable.new(name: type.name, location: type.location)
when Name::Singleton
RBS::Types::ClassSingleton.new(name: type.name, location: nil)
RBS::Types::ClassSingleton.new(name: type.name, location: type.location)
when Name::Instance
RBS::Types::ClassInstance.new(
name: type.name,
args: type.args.map {|arg| type_1(arg) },
location: nil
location: type.location
)
when Name::Interface
RBS::Types::Interface.new(
name: type.name,
args: type.args.map {|arg| type_1(arg) },
location: nil
location: type.location
)
when Name::Alias
RBS::Types::Alias.new(
name: type.name,
args: type.args.map {|arg| type_1(arg) },
location: nil
location: type.location
)
when Union
RBS::Types::Union.new(
types: type.types.map {|ty| type_1(ty) },
location: nil
location: type.location
)
when Intersection
RBS::Types::Intersection.new(
types: type.types.map {|ty| type_1(ty) },
location: nil
location: type.location
)
when Literal
RBS::Types::Literal.new(literal: type.value, location: nil)
RBS::Types::Literal.new(literal: type.value, location: type.location)
when Tuple
RBS::Types::Tuple.new(
types: type.types.map {|ty| type_1(ty) },
location: nil
location: type.location
)
when Record
fields = type.elements.each.with_object({}) do |(key, value), hash|
hash[key] = type_1(value)
end
RBS::Types::Record.new(fields: fields, location: nil)
RBS::Types::Record.new(fields: fields, location: type.location)
when Proc
block = if type.block
RBS::Types::Block.new(
Expand All @@ -172,10 +179,10 @@ def type_1(type)
RBS::Types::Proc.new(
type: function_1(type.type),
block: block,
location: nil
location: type.location
)
when Logic::Base
RBS::Types::Bases::Bool.new(location: nil)
RBS::Types::Bases::Bool.new(location: type.location)
else
raise "Unexpected type given: #{type} (#{type.class})"
end
Expand Down Expand Up @@ -208,47 +215,61 @@ def params(type)
)
end

def type_param(type_param)
Interface::TypeParam.new(
name: type_param.name,
upper_bound: type_param.upper_bound&.yield_self {|u| type(u) },
variance: type_param.variance,
unchecked: type_param.unchecked?
)
end

def type_param_1(type_param)
RBS::AST::TypeParam.new(
name: type_param.name,
variance: type_param.variance,
upper_bound: type_param.upper_bound&.yield_self {|u| type_1(u) },
location: type_param.location
).unchecked!(type_param.unchecked)
end

def method_type(method_type, self_type:, subst2: nil, method_decls:)
fvs = self_type.free_variables()

type_params = []
alpha_vars = []
alpha_types = []

method_type.type_params.map do |type_param|
name = type_param.name
conflicting_names = []

if fvs.include?(name)
type = Types::Var.fresh(name)
alpha_vars << name
alpha_types << type
type_params << type.name
else
type_params << name
type_params = method_type.type_params.map do |type_param|
if fvs.include?(type_param.name)
conflicting_names << type_param.name
end

type_param(type_param)
end
subst = Interface::Substitution.build(alpha_vars, alpha_types)

type_params, subst = Interface::TypeParam.rename(type_params, conflicting_names)
subst.merge!(subst2, overwrite: true) if subst2

type = Interface::MethodType.new(
type_params: type_params,
type: Interface::Function.new(
params: params(method_type.type).subst(subst),
return_type: type(method_type.type.return_type).subst(subst),
location: method_type.location
),
block: method_type.block&.yield_self do |block|
Interface::Block.new(
optional: !block.required,
type: Interface::Function.new(
params: params(block.type).subst(subst),
return_type: type(block.type.return_type).subst(subst),
location: nil
type =
Interface::MethodType.new(
type_params: type_params,
type: Interface::Function.new(
params: params(method_type.type).subst(subst),
return_type: type(method_type.type.return_type).subst(subst),
location: method_type.location
),
block: method_type.block&.yield_self do |block|
Interface::Block.new(
optional: !block.required,
type: Interface::Function.new(
params: params(block.type).subst(subst),
return_type: type(block.type.return_type).subst(subst),
location: nil
)
)
)
end,
method_decls: method_decls
)
end,
method_decls: method_decls
)

if block_given?
yield type
Expand All @@ -264,24 +285,14 @@ def method_type_1(method_type, self_type:)
alpha_vars = []
alpha_types = []

method_type.type_params.map do |name|
if fvs.include?(name)
type = RBS::Types::Variable.new(name: name, location: nil),
alpha_vars << name
alpha_types << type
end

type_params << RBS::AST::TypeParam.new(
name: name,
variance: :invariant,
upper_bound: nil,
location: nil
)
conflicting_names = method_type.type_params.each.with_object([]) do |param, names|
names << params.name if fvs.include?(param.name)
end
subst = Interface::Substitution.build(alpha_vars, alpha_types)

type_params, subst = Interface::TypeParam.rename(method_type.type_params, conflicting_names)

type = RBS::MethodType.new(
type_params: type_params,
type_params: type_params.map {|param| type_param_1(param) },
type: function_1(method_type.type.subst(subst)),
block: method_type.block&.yield_self do |block|
block_type = block.type.subst(subst)
Expand Down
Loading