Skip to content

Commit

Permalink
fix(csharp/src/Drivers/Apache/Spark): correct BatchSize implementatio…
Browse files Browse the repository at this point in the history
…n for base reader (#2199)

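Fixes:
1. How HiveServer2Statement.BatchSize is used in the HiveServer2Reader (the reader now passes the statement's BatchSize to each fetch request instead of a value stored at construction).
2. The valid range of the batch size (the value is now parsed as a 64-bit integer and must be greater than zero).
3. The test cases for the batch size and poll time options.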
  • Loading branch information
birschick-bq authored Sep 30, 2024
1 parent b6b2377 commit 155c2f5
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 12 deletions.
7 changes: 2 additions & 5 deletions csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ internal class HiveServer2Reader : IArrowArrayStream
private const char AsciiPeriod = '.';

private HiveServer2Statement? _statement;
private readonly long _batchSize;
private readonly DataTypeConversion _dataTypeConversion;
private static readonly IReadOnlyDictionary<ArrowTypeId, Func<StringArray, IArrowType, IArrowArray>> s_arrowStringConverters =
new Dictionary<ArrowTypeId, Func<StringArray, IArrowType, IArrowArray>>()
Expand All @@ -51,12 +50,10 @@ internal class HiveServer2Reader : IArrowArrayStream
public HiveServer2Reader(
HiveServer2Statement statement,
Schema schema,
DataTypeConversion dataTypeConversion,
long batchSize = HiveServer2Connection.BatchSizeDefault)
DataTypeConversion dataTypeConversion)
{
_statement = statement;
Schema = schema;
_batchSize = batchSize;
_dataTypeConversion = dataTypeConversion;
}

Expand All @@ -69,7 +66,7 @@ public HiveServer2Reader(
return null;
}

var request = new TFetchResultsReq(_statement.OperationHandle, TFetchOrientation.FETCH_NEXT, _batchSize);
var request = new TFetchResultsReq(_statement.OperationHandle, TFetchOrientation.FETCH_NEXT, _statement.BatchSize);
TFetchResultsResp response = await _statement.Connection.Client.FetchResults(request, cancellationToken);

int columnCount = response.Results.Columns.Count;
Expand Down
6 changes: 3 additions & 3 deletions csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,11 @@ public class Options

private void UpdatePollTimeIfValid(string key, string value) => PollTimeMilliseconds = !string.IsNullOrEmpty(key) && int.TryParse(value, result: out int pollTimeMilliseconds) && pollTimeMilliseconds >= 0
? pollTimeMilliseconds
: throw new ArgumentException($"The value '{value}' for option '{key}' is invalid. Must be a numeric value greater than or equal to zero.", nameof(value));
: throw new ArgumentOutOfRangeException(key, value, $"The value '{value}' for option '{key}' is invalid. Must be a numeric value greater than or equal to zero.");

private void UpdateBatchSizeIfValid(string key, string value) => BatchSize = !string.IsNullOrEmpty(value) && int.TryParse(value, out int batchSize) && batchSize > 0
private void UpdateBatchSizeIfValid(string key, string value) => BatchSize = !string.IsNullOrEmpty(value) && long.TryParse(value, out long batchSize) && batchSize > 0
? batchSize
: throw new ArgumentException($"The value '{value}' for option '{key}' is invalid. Must be a numeric value greater than zero.", nameof(value));
: throw new ArgumentOutOfRangeException(key, value, $"The value '{value}' for option '{key}' is invalid. Must be a numeric value greater than zero.");

public override void Dispose()
{
Expand Down
5 changes: 4 additions & 1 deletion csharp/test/Apache.Arrow.Adbc.Tests/TestConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
* limitations under the License.
*/

using System;
using System.Text.Json.Serialization;

namespace Apache.Arrow.Adbc.Tests
{
/// <summary>
/// Base test configuration values.
/// </summary>
public abstract class TestConfiguration
public abstract class TestConfiguration : ICloneable
{
/// <summary>
/// The query to run.
Expand All @@ -41,6 +42,8 @@ public abstract class TestConfiguration
/// </summary>
[JsonPropertyName("metadata")]
public TestMetadata Metadata { get; set; } = new TestMetadata();

/// <summary>
/// Creates a copy of this configuration by calling <see cref="object.MemberwiseClone"/>.
/// </summary>
/// <remarks>
/// The copy is shallow: reference-type members are shared between the original and the copy.
/// </remarks>
/// <returns>The copied configuration instance.</returns>
public virtual object Clone()
{
    return MemberwiseClone();
}
}

/// <summary>
Expand Down
6 changes: 6 additions & 0 deletions csharp/test/Drivers/Apache/ApacheTestConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,11 @@ public class ApacheTestConfiguration : TestConfiguration
[JsonPropertyName("uri"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public string Uri { get; set; } = string.Empty;

/// <summary>
/// The batch size value, mapped to the "batch_size" JSON property.
/// Held as a string and defaults to an empty string.
/// </summary>
[JsonPropertyName("batch_size"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public string BatchSize { get; set; } = string.Empty;

/// <summary>
/// The poll time value in milliseconds, mapped to the "polltime_milliseconds" JSON property.
/// Held as a string and defaults to an empty string.
/// </summary>
[JsonPropertyName("polltime_milliseconds"), JsonIgnore(Condition = JsonIgnoreCondition.WhenWritingDefault)]
public string PollTimeMilliseconds { get; set; } = string.Empty;

}
}
8 changes: 8 additions & 0 deletions csharp/test/Drivers/Apache/Spark/SparkTestEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,14 @@ public override Dictionary<string, string> GetDriverParameters(SparkTestConfigur
{
parameters.Add(SparkParameters.TLSOptions, testConfiguration.TlsOptions!);
}
if (!string.IsNullOrEmpty(testConfiguration.BatchSize))
{
parameters.Add(HiveServer2Statement.Options.BatchSize, testConfiguration.BatchSize!);
}
if (!string.IsNullOrEmpty(testConfiguration.PollTimeMilliseconds))
{
parameters.Add(HiveServer2Statement.Options.PollTimeMilliseconds, testConfiguration.PollTimeMilliseconds!);
}

return parameters;
}
Expand Down
22 changes: 19 additions & 3 deletions csharp/test/Drivers/Apache/Spark/StatementTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,17 @@ public class StatementTests : TestBase<SparkTestConfiguration, SparkTestEnvironm
[InlineData("2147483647")]
public void CanSetOptionPollTime(string value, bool throws = false)
{
var testConfiguration = TestConfiguration.Clone() as SparkTestConfiguration;
testConfiguration!.PollTimeMilliseconds = value;
if (throws)
{
Assert.Throws<ArgumentOutOfRangeException>(() => NewConnection(testConfiguration).CreateStatement());
}

AdbcStatement statement = NewConnection().CreateStatement();
if (throws)
{
Assert.Throws<ArgumentException>(() => statement.SetOption(SparkStatement.Options.PollTimeMilliseconds, value));
Assert.Throws<ArgumentOutOfRangeException>(() => statement.SetOption(SparkStatement.Options.PollTimeMilliseconds, value));
}
else
{
Expand All @@ -73,16 +80,25 @@ public void CanSetOptionPollTime(string value, bool throws = false)
[InlineData("-1", true)]
[InlineData("one", true)]
[InlineData("-2147483648", true)]
[InlineData("2147483648", true)]
[InlineData("2147483648", false)]
[InlineData("9223372036854775807", false)]
[InlineData("9223372036854775808", true)]
[InlineData("0", true)]
[InlineData("1")]
[InlineData("2147483647")]
public void CanSetOptionBatchSize(string value, bool throws = false)
{
var testConfiguration = TestConfiguration.Clone() as SparkTestConfiguration;
testConfiguration!.BatchSize = value;
if (throws)
{
Assert.Throws<ArgumentOutOfRangeException>(() => NewConnection(testConfiguration).CreateStatement());
}

AdbcStatement statement = NewConnection().CreateStatement();
if (throws)
{
Assert.Throws<ArgumentException>(() => statement.SetOption(SparkStatement.Options.BatchSize, value));
Assert.Throws<ArgumentOutOfRangeException>(() => statement!.SetOption(SparkStatement.Options.BatchSize, value));
}
else
{
Expand Down

0 comments on commit 155c2f5

Please sign in to comment.