diff --git a/pkgs/google_generative_ai/lib/src/client.dart b/pkgs/google_generative_ai/lib/src/client.dart index 8f80afb..8ecf94f 100644 --- a/pkgs/google_generative_ai/lib/src/client.dart +++ b/pkgs/google_generative_ai/lib/src/client.dart @@ -28,13 +28,16 @@ const _clientName = 'genai-dart/$_packageVersion'; final class HttpApiClient implements ApiClient { final String _apiKey; + final http.Client? _httpClient; late final _headers = { 'x-goog-api-key': _apiKey, 'x-goog-api-client': _clientName }; - HttpApiClient({required String model, required String apiKey}) - : _apiKey = apiKey; + HttpApiClient( + {required String model, required String apiKey, http.Client? httpClient}) + : _apiKey = apiKey, + _httpClient = httpClient ?? http.Client(); @override Future makeRequest(Uri uri, Uint8List body) async { @@ -52,7 +55,9 @@ final class HttpApiClient implements ApiClient { ..bodyBytes = body ..headers.addAll(_headers) ..headers['Content-Type'] = 'application/json'; - final response = await request.send(); + final response = _httpClient == null + ? await request.send() + : await _httpClient.send(request); await response.stream .toStringStream() .transform(const LineSplitter()) diff --git a/pkgs/google_generative_ai/lib/src/model.dart b/pkgs/google_generative_ai/lib/src/model.dart index cbdcc5c..4bca087 100644 --- a/pkgs/google_generative_ai/lib/src/model.dart +++ b/pkgs/google_generative_ai/lib/src/model.dart @@ -15,6 +15,8 @@ import 'dart:async'; import 'dart:convert'; +import 'package:http/http.dart' as http; + import 'api.dart'; import 'client.dart'; import 'content.dart'; @@ -46,14 +48,16 @@ final class GenerativeModel { {required String model, required String apiKey, List safetySettings = const [], - GenerationConfig? generationConfig}) + GenerationConfig? generationConfig, + http.Client? httpClient}) : // TODO: Allow `models/` prefix and strip it. // https://github.com/google/generative-ai-js/blob/2be48f8e5427f2f6191f24bcb8000b450715a0de/packages/main/src/models/generative-model.ts#L59 _model = model, _safetySettings = safetySettings, _generationConfig = generationConfig, - _client = HttpApiClient(model: model, apiKey: apiKey); + _client = + HttpApiClient(model: model, apiKey: apiKey, httpClient: httpClient); Future _makeRequest( Task task, Map parameters) async {