Skip to content

Commit

Permalink
Merge pull request #5132 from rmosolgo/limit-dataloader-fibers
Browse files Browse the repository at this point in the history
Add Dataloader fiber_limit option
  • Loading branch information
rmosolgo authored Oct 24, 2024
2 parents 2ae5393 + 59ec797 commit 9268dc3
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 17 deletions.
41 changes: 31 additions & 10 deletions lib/graphql/dataloader.rb
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,23 @@ module GraphQL
#
class Dataloader
class << self
attr_accessor :default_nonblocking
attr_accessor :default_nonblocking, :default_fiber_limit
end

NonblockingDataloader = Class.new(self) { self.default_nonblocking = true }

def self.use(schema, nonblocking: nil)
schema.dataloader_class = if nonblocking
def self.use(schema, nonblocking: nil, fiber_limit: nil)
dataloader_class = if nonblocking
warn("`nonblocking: true` is deprecated from `GraphQL::Dataloader`, please use `GraphQL::Dataloader::AsyncDataloader` instead. Docs: https://graphql-ruby.org/dataloader/async_dataloader.")
NonblockingDataloader
Class.new(self) { self.default_nonblocking = true }
else
self
end

if fiber_limit
dataloader_class = Class.new(dataloader_class)
dataloader_class.default_fiber_limit = fiber_limit
end

schema.dataloader_class = dataloader_class
end

# Call the block with a Dataloader instance,
Expand All @@ -50,14 +55,18 @@ def self.with_dataloading(&block)
result
end

def initialize(nonblocking: self.class.default_nonblocking)
# @param nonblocking [Boolean, nil] when non-nil, stored as the nonblocking flag
#   (defaults to the class-level `default_nonblocking`)
# @param fiber_limit [Integer, nil] maximum number of fibers `#run` may keep
#   alive at once (defaults to the class-level `default_fiber_limit`; nil = no limit)
def initialize(nonblocking: self.class.default_nonblocking, fiber_limit: self.class.default_fiber_limit)
  # Nested hash of source instances; inner hashes are auto-created per key.
  # (Presumably keyed by source class, then batch key -- populated elsewhere in this class.)
  @source_cache = Hash.new { |h, k| h[k] = {} }
  @pending_jobs = []
  # Only assign when a value was explicitly given (or a class default exists),
  # so `@nonblocking` stays unset otherwise.
  if !nonblocking.nil?
    @nonblocking = nonblocking
  end
  @fiber_limit = fiber_limit
end

# @return [Integer, nil] the configured fiber limit, or nil when unlimited
attr_reader :fiber_limit

# @return [Boolean, nil] whether this dataloader was configured as nonblocking
def nonblocking?
  @nonblocking
end
Expand Down Expand Up @@ -178,6 +187,7 @@ def run_isolated
end

def run
jobs_fiber_limit, total_fiber_limit = calculate_fiber_limit
job_fibers = []
next_job_fibers = []
source_fibers = []
Expand All @@ -187,7 +197,7 @@ def run
while first_pass || job_fibers.any?
first_pass = false

while (f = (job_fibers.shift || spawn_job_fiber))
while (f = (job_fibers.shift || (((next_job_fibers.size + job_fibers.size) < jobs_fiber_limit) && spawn_job_fiber)))
if f.alive?
finished = run_fiber(f)
if !finished
Expand All @@ -197,8 +207,8 @@ def run
end
join_queues(job_fibers, next_job_fibers)

while source_fibers.any? || @source_cache.each_value.any? { |group_sources| group_sources.each_value.any?(&:pending?) }
while (f = source_fibers.shift || spawn_source_fiber)
while (source_fibers.any? || @source_cache.each_value.any? { |group_sources| group_sources.each_value.any?(&:pending?) })
while (f = source_fibers.shift || (((job_fibers.size + source_fibers.size + next_source_fibers.size + next_job_fibers.size) < total_fiber_limit) && spawn_source_fiber))
if f.alive?
finished = run_fiber(f)
if !finished
Expand Down Expand Up @@ -242,6 +252,17 @@ def spawn_fiber

private

# Derive the per-phase fiber caps from `@fiber_limit`.
#
# @return [Array(Numeric, Numeric)] `[jobs_fiber_limit, total_fiber_limit]`,
#   both `Float::INFINITY` when no limit is configured
# @raise [ArgumentError] when the configured limit is below 4 (one fiber is
#   reserved for the manager and two more headroom for source fibers)
def calculate_fiber_limit
  total_fiber_limit = @fiber_limit || Float::INFINITY
  raise ArgumentError, "Dataloader fiber limit is too low (#{total_fiber_limit}), it must be at least 4" if total_fiber_limit < 4
  # One fiber is reserved for `manager`:
  total_fiber_limit -= 1
  # Job fibers get two fewer, leaving headroom for source fibers:
  jobs_fiber_limit = total_fiber_limit - 2
  [jobs_fiber_limit, total_fiber_limit]
end

def join_queues(prev_queue, new_queue)
@nonblocking && Fiber.scheduler.run
prev_queue.concat(new_queue)
Expand Down
5 changes: 3 additions & 2 deletions lib/graphql/dataloader/async_dataloader.rb
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def yield
end

def run
jobs_fiber_limit, total_fiber_limit = calculate_fiber_limit
job_fibers = []
next_job_fibers = []
source_tasks = []
Expand All @@ -23,7 +24,7 @@ def run
first_pass = false
fiber_vars = get_fiber_variables

while (f = (job_fibers.shift || spawn_job_fiber))
while (f = (job_fibers.shift || (((job_fibers.size + next_job_fibers.size + source_tasks.size) < jobs_fiber_limit) && spawn_job_fiber)))
if f.alive?
finished = run_fiber(f)
if !finished
Expand All @@ -37,7 +38,7 @@ def run
Sync do |root_task|
set_fiber_variables(fiber_vars)
while source_tasks.any? || @source_cache.each_value.any? { |group_sources| group_sources.each_value.any?(&:pending?) }
while (task = source_tasks.shift || spawn_source_task(root_task, sources_condition))
while (task = (source_tasks.shift || (((job_fibers.size + next_job_fibers.size + source_tasks.size + next_source_tasks.size) < total_fiber_limit) && spawn_source_task(root_task, sources_condition))))
if task.alive?
root_task.yield # give the source task a chance to run
next_source_tasks << task
Expand Down
2 changes: 1 addition & 1 deletion lib/graphql/dataloader/source.rb
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def sync(pending_result_keys)
while pending_result_keys.any? { |key| !@results.key?(key) }
iterations += 1
if iterations > MAX_ITERATIONS
raise "#{self.class}#sync tried #{MAX_ITERATIONS} times to load pending keys (#{pending_result_keys}), but they still weren't loaded. There is likely a circular dependency."
raise "#{self.class}#sync tried #{MAX_ITERATIONS} times to load pending keys (#{pending_result_keys}), but they still weren't loaded. There is likely a circular dependency#{@dataloader.fiber_limit ? " or `fiber_limit: #{@dataloader.fiber_limit}` is set too low" : ""}."
end
@dataloader.yield
end
Expand Down
8 changes: 4 additions & 4 deletions spec/graphql/dataloader/nonblocking_dataloader_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
require "spec_helper"

if Fiber.respond_to?(:scheduler) # Ruby 3+
describe GraphQL::Dataloader::NonblockingDataloader do
describe "GraphQL::Dataloader::NonblockingDataloader" do
class NonblockingSchema < GraphQL::Schema
class SleepSource < GraphQL::Dataloader::Source
def fetch(keys)
Expand Down Expand Up @@ -84,7 +84,7 @@ def wait_for(tag:, wait:)
end

query(Query)
use GraphQL::Dataloader::NonblockingDataloader
use GraphQL::Dataloader, nonblocking: true
end

def with_scheduler
Expand All @@ -99,7 +99,7 @@ def self.included(child_class)
child_class.class_eval do

it "runs IO in parallel by default" do
dataloader = GraphQL::Dataloader::NonblockingDataloader.new
dataloader = GraphQL::Dataloader.new(nonblocking: true)
results = {}
dataloader.append_job { sleep(0.1); results[:a] = 1 }
dataloader.append_job { sleep(0.2); results[:b] = 2 }
Expand All @@ -115,7 +115,7 @@ def self.included(child_class)
end

it "works with sources" do
dataloader = GraphQL::Dataloader::NonblockingDataloader.new
dataloader = GraphQL::Dataloader.new(nonblocking: true)
r1 = dataloader.with(NonblockingSchema::SleepSource).request(0.1)
r2 = dataloader.with(NonblockingSchema::SleepSource).request(0.2)
r3 = dataloader.with(NonblockingSchema::SleepSource).request(0.3)
Expand Down
8 changes: 8 additions & 0 deletions spec/graphql/dataloader/source_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@ def fetch(keys)
end
expected_message = "FailsToLoadSource#sync tried 1000 times to load pending keys ([1]), but they still weren't loaded. There is likely a circular dependency."
assert_equal expected_message, err.message

dl = GraphQL::Dataloader.new(fiber_limit: 10000)
dl.append_job { dl.with(FailsToLoadSource).load(1) }
err = assert_raises RuntimeError do
dl.run
end
expected_message = "FailsToLoadSource#sync tried 1000 times to load pending keys ([1]), but they still weren't loaded. There is likely a circular dependency or `fiber_limit: 10000` is set too low."
assert_equal expected_message, err.message
end

it "is pending when waiting for false and nil" do
Expand Down
132 changes: 132 additions & 0 deletions spec/graphql/dataloader_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,52 @@ class Query < GraphQL::Schema::Object
end

module DataloaderAssertions
# Test-helper mixin for a Dataloader class: records how many fibers it
# spawns in total and the peak number alive at once, in module-level counters.
module FiberCounting
  class << self
    attr_accessor :starting_fiber_count, :last_spawn_fiber_count, :last_max_fiber_count

    # Number of live fibers created since the counters were last reset.
    def current_fiber_count
      count_active_fibers - starting_fiber_count
    end

    # Counts every Fiber object alive in the VM, GC'ing first so
    # already-collected fibers don't inflate the number.
    def count_active_fibers
      GC.start
      ObjectSpace.each_object(Fiber).count
    end
  end

  # Reset the shared counters each time a counting dataloader is built.
  def initialize(*args, **kwargs, &block)
    super
    FiberCounting.starting_fiber_count = FiberCounting.count_active_fibers
    FiberCounting.last_spawn_fiber_count = 0
    FiberCounting.last_max_fiber_count = 0
  end

  def spawn_fiber
    super.tap { update_fiber_counts }
  end

  def spawn_source_task(parent_task, condition)
    # Only count when a task was actually created.
    super.tap { |task| update_fiber_counts if task }
  end

  private

  def update_fiber_counts
    FiberCounting.last_spawn_fiber_count += 1
    live_now = FiberCounting.current_fiber_count
    if live_now > FiberCounting.last_max_fiber_count
      FiberCounting.last_max_fiber_count = live_now
    end
  end
end

def self.included(child_class)
child_class.class_eval do
let(:schema) { make_schema_from(FiberSchema) }
Expand Down Expand Up @@ -1038,6 +1084,92 @@ def self.included(child_class)
response = parts_schema.execute(query).to_h
assert_equal [4, 4, 4, 4], response["data"]["manufacturers"].map { |parts_obj| parts_obj["parts"].size }
end

describe "fiber_limit" do
# Asserts the peak concurrent fiber count recorded by FiberCounting.
# AsyncDataloader sometimes records one extra fiber; that case is tolerated
# with a warning rather than a failure.
def assert_last_max_fiber_count(expected_last_max_fiber_count)
  if schema.dataloader_class == GraphQL::Dataloader::AsyncDataloader && FiberCounting.last_max_fiber_count == (expected_last_max_fiber_count + 1)
    # TODO why does this happen sometimes?
    # NOTE(review): root cause unknown -- possibly an extra fiber belonging to
    # the async reactor; confirm before tightening this assertion.
    warn "AsyncDataloader had +1 last_max_fiber_count"
    assert_equal (expected_last_max_fiber_count + 1), FiberCounting.last_max_fiber_count
  else
    assert_equal expected_last_max_fiber_count, FiberCounting.last_max_fiber_count
  end
end

# Runs the same multi-source query three ways -- no limit (baseline),
# fiber_limit: 4, and fiber_limit: 6 -- asserting the total number of fibers
# spawned and the peak number alive at once for each run.
it "respects a configured fiber_limit" do
  query_str = <<-GRAPHQL
{
recipes {
ingredients {
name
}
}
nestedIngredient(id: 2) {
name
}
keyIngredient(id: 4) {
name
}
commonIngredientsWithLoad(recipe1Id: 5, recipe2Id: 6) {
name
}
}
  GRAPHQL

  # Wrap the schema's dataloader in the fiber-counting mixin.
  fiber_counting_dataloader_class = Class.new(schema.dataloader_class)
  fiber_counting_dataloader_class.include(FiberCounting)

  # Baseline: no limit configured.
  res = schema.execute(query_str, context: { dataloader: fiber_counting_dataloader_class.new })
  assert_nil res.context.dataloader.fiber_limit
  assert_equal 12, FiberCounting.last_spawn_fiber_count
  assert_last_max_fiber_count(9)

  # Tight limit: fewer fibers alive at once, but more spawned in total
  # (presumably because fibers are re-created as slots free up).
  res = schema.execute(query_str, context: { dataloader: fiber_counting_dataloader_class.new(fiber_limit: 4) })
  assert_equal 4, res.context.dataloader.fiber_limit
  assert_equal 14, FiberCounting.last_spawn_fiber_count
  assert_last_max_fiber_count(4)

  # Looser limit.
  res = schema.execute(query_str, context: { dataloader: fiber_counting_dataloader_class.new(fiber_limit: 6) })
  assert_equal 6, res.context.dataloader.fiber_limit
  assert_equal 10, FiberCounting.last_spawn_fiber_count
  assert_last_max_fiber_count(6)
end

# `use GraphQL::Dataloader, fiber_limit: ...` should flow through to the
# dataloader instance built for each query.
it "accepts a default fiber_limit config" do
  schema = Class.new(FiberSchema) do
    use GraphQL::Dataloader, fiber_limit: 4
  end
  query_str = <<-GRAPHQL
{
recipes {
ingredients {
name
}
}
nestedIngredient(id: 2) {
name
}
keyIngredient(id: 4) {
name
}
commonIngredientsWithLoad(recipe1Id: 5, recipe2Id: 6) {
name
}
}
  GRAPHQL
  res = schema.execute(query_str)
  assert_equal 4, res.context.dataloader.fiber_limit
  # The query still completes successfully under the limit:
  assert_nil res["errors"]
end

# Fix: the description said "at least three fibers", but the asserted error
# message (and the validation in Dataloader#calculate_fiber_limit) require a
# fiber_limit of at least 4. Rename the example so it matches the behavior.
it "requires a fiber_limit of at least 4" do
  dl = GraphQL::Dataloader.new(fiber_limit: 2)
  err = assert_raises ArgumentError do
    dl.run
  end
  assert_equal "Dataloader fiber limit is too low (2), it must be at least 4", err.message
end
end
end
end
end
Expand Down

0 comments on commit 9268dc3

Please sign in to comment.