diff --git a/airbyte-integrations/connectors/source-github/source_github/source.py b/airbyte-integrations/connectors/source-github/source_github/source.py
index 75d10581fce4..b5f00b743643 100644
--- a/airbyte-integrations/connectors/source-github/source_github/source.py
+++ b/airbyte-integrations/connectors/source-github/source_github/source.py
@@ -107,13 +107,42 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
         incremental_args = {**full_refresh_args, "start_date": config["start_date"]}
         organization_args = {"authenticator": authenticator, "organizations": organizations}
 
+        # Look up each repository's default branch; it is the fallback for repositories with no configured branch.
+        default_branches = {}
+        repository_stats_stream = RepositoryStats(
+            authenticator=authenticator,
+            repositories=repositories,
+        )
+        for stream_slice in repository_stats_stream.stream_slices(sync_mode=SyncMode.full_refresh):
+            records = repository_stats_stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=stream_slice)
+            default_branches.update({r["full_name"]: r["default_branch"] for r in records})
+
+        # Map each repository to the branches to pull commits for. Entries of the
+        # whitespace-delimited `branch` option look like `<owner>/<repo>/<branch>`; the branch
+        # name may itself contain slashes, hence the bounded split.
+        branches_to_pull: Mapping[str, List[str]] = {}
+        for branch in dict.fromkeys((config.get("branch") or "").split()):
+            parts = branch.split("/", 2)
+            if len(parts) < 3 or not all(parts):
+                # Malformed entry: skip it instead of failing with an IndexError.
+                continue
+            repo = f"{parts[0]}/{parts[1]}"
+            if repo not in repositories:
+                continue
+            branches_to_pull.setdefault(repo, []).append(parts[2])
+        # Repositories without a configured branch fall back to their default branch. A repository
+        # whose default branch could not be determined gets no entry and yields no commit slices.
+        for repo in repositories:
+            if not branches_to_pull.get(repo) and repo in default_branches:
+                branches_to_pull[repo] = [default_branches[repo]]
+
         return [
             Assignees(**full_refresh_args),
             Branches(**full_refresh_args),
             Collaborators(**full_refresh_args),
             Comments(**incremental_args),
             CommitComments(**incremental_args),
-            Commits(**incremental_args),
+            Commits(**incremental_args, branches_to_pull=branches_to_pull, default_branches=default_branches),
             Events(**incremental_args),
             IssueEvents(**incremental_args),
             IssueLabels(**full_refresh_args),
diff --git a/airbyte-integrations/connectors/source-github/source_github/spec.json b/airbyte-integrations/connectors/source-github/source_github/spec.json
index 35ccbcb4e641..166e52a162bd 100644
--- a/airbyte-integrations/connectors/source-github/source_github/spec.json
+++ b/airbyte-integrations/connectors/source-github/source_github/spec.json
@@ -23,6 +23,11 @@
         "description": "The date from which you'd like to replicate data for GitHub in the format YYYY-MM-DDT00:00:00Z. All data generated after this date will be replicated. Note that it will be used only in the following incremental streams: comments, commits and issues.",
         "examples": ["2021-03-01T00:00:00Z"],
         "pattern": "^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}Z$"
+      },
+      "branch": {
+        "type": "string",
+        "examples": ["airbytehq/airbyte/master"],
+        "description": "Space-delimited list of GitHub repository branches to pull commits for, e.g. `airbytehq/airbyte/master`. If no branches are specified for a repository, the default branch will be pulled."
       }
     }
   }
diff --git a/airbyte-integrations/connectors/source-github/source_github/streams.py b/airbyte-integrations/connectors/source-github/source_github/streams.py
index aa652f7ea867..79c674694d4f 100644
--- a/airbyte-integrations/connectors/source-github/source_github/streams.py
+++ b/airbyte-integrations/connectors/source-github/source_github/streams.py
@@ -129,6 +129,10 @@ def read_records(self, stream_slice: Mapping[str, any] = None, **kwargs) -> Iter
                 # For private repositories `Teams` stream is not available and we get "404 Client Error: Not Found for
                 # url: https://api.github.com/orgs/sherifnada/teams?per_page=100" error.
                 error_msg = f"Syncing `Team` stream isn't available for repository `{stream_slice['repository']}`."
+            elif e.response.status_code == requests.codes.NOT_FOUND and "/repos?" in error_msg:
+                # `Repositories` stream is not available for repositories not in an organization.
+                # Handle "404 Client Error: Not Found for url: https://api.github.com/orgs/cjwooo/repos?per_page=100" error.
+                error_msg = f"Syncing `Repositories` stream isn't available for organization `{stream_slice['organization']}`."
             elif e.response.status_code == requests.codes.CONFLICT:
                 error_msg = (
                     f"Syncing `{self.name}` stream isn't available for repository "
@@ -645,17 +649,111 @@ class Commits(IncrementalGithubStream):
         "committer",
     )
 
-    def transform(self, record: MutableMapping[str, Any], repository: str = None, **kwargs) -> MutableMapping[str, Any]:
+    def __init__(self, branches_to_pull: Mapping[str, List[str]], default_branches: Mapping[str, str], **kwargs):
+        """
+        Pull commits from each branch of each repository, tracking state for each branch.
+
+        :param branches_to_pull: repository full name -> names of the branches to pull commits from
+        :param default_branches: repository full name -> name of its default branch
+        """
+        super().__init__(**kwargs)
+        self.branches_to_pull = branches_to_pull
+        self.default_branches = default_branches
+
+    def request_params(self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, Any] = None, **kwargs) -> MutableMapping[str, Any]:
+        # Start the lookup after IncrementalGithubStream in the MRO so its request_params is bypassed;
+        # `since` is set below from the starting point of the slice's branch.
+        params = super(IncrementalGithubStream, self).request_params(stream_state=stream_state, stream_slice=stream_slice, **kwargs)
+        params["since"] = self.get_starting_point(stream_state=stream_state, repository=stream_slice["repository"], branch=stream_slice["branch"])
+        params["sha"] = stream_slice["branch"]
+        return params
+
+    def stream_slices(self, **kwargs) -> Iterable[Optional[Mapping[str, Any]]]:
+        # One slice per (repository, branch) pair.
+        for stream_slice in super().stream_slices(**kwargs):
+            repository = stream_slice["repository"]
+            for branch in self.branches_to_pull.get(repository, []):
+                yield {"branch": branch, "repository": repository}
+
+    def parse_response(
+        self,
+        response: requests.Response,
+        stream_state: Mapping[str, Any],
+        stream_slice: Mapping[str, Any] = None,
+        next_page_token: Mapping[str, Any] = None,
+    ) -> Iterable[Mapping]:
+        for record in response.json():  # GitHub puts records in an array.
+            yield self.transform(record=record, repository=stream_slice["repository"], branch=stream_slice["branch"])
+
+    def transform(self, record: MutableMapping[str, Any], repository: str = None, branch: str = None, **kwargs) -> MutableMapping[str, Any]:
         record = super().transform(record=record, repository=repository)
 
         # Record of the `commits` stream doesn't have an updated_at/created_at field at the top level (so we could
         # just write `record["updated_at"]` or `record["created_at"]`). Instead each record has such value in
         # `commit.author.date`. So the easiest way is to just enrich the record returned from API with top level
         # field `created_at` and use it as cursor_field.
+        # Include the branch in the record
         record["created_at"] = record["commit"]["author"]["date"]
+        record["branch"] = branch
 
         return record
 
+    def get_updated_state(self, current_stream_state: MutableMapping[str, Any], latest_record: Mapping[str, Any]):
+        """
+        Keep one cursor per branch: {repository: {branch: {cursor_field: value}}}. A branch's cursor
+        never moves backwards, whatever order the records arrive in.
+        """
+        latest_cursor_value = latest_record.get(self.cursor_field)
+        current_repository = latest_record["repository"]
+        current_branch = latest_record["branch"]
+
+        repository_state = current_stream_state.get(current_repository) or {}
+        current_stream_state[current_repository] = repository_state
+
+        candidates = [latest_cursor_value]
+        branch_cursor = repository_state.get(current_branch, {}).get(self.cursor_field)
+        if branch_cursor:
+            candidates.append(branch_cursor)
+        legacy_cursor = repository_state.get(self.cursor_field)
+        if legacy_cursor and current_branch == self.default_branches.get(current_repository):
+            # Transfer state from the old source version ({repository: {cursor_field: value}}) to the
+            # per-branch version; as before, only the default branch inherits it.
+            candidates.append(legacy_cursor)
+            del repository_state[self.cursor_field]
+
+        repository_state[current_branch] = {self.cursor_field: max(candidates)}
+        return current_stream_state
+
+    def get_starting_point(self, stream_state: Mapping[str, Any], repository: str, branch: str) -> str:
+        start_point = self._start_date
+        branch_state = (stream_state or {}).get(repository, {}).get(branch, {})
+        if branch_state.get(self.cursor_field):
+            return max(start_point, branch_state[self.cursor_field])
+        if branch == self.default_branches.get(repository):
+            # No per-branch cursor yet: defer to the parent implementation, which is handed the raw state
+            # (presumably so a cursor saved by the old, per-repository version is honoured).
+            return super().get_starting_point(stream_state=stream_state, repository=repository)
+        return start_point
+
+    def read_records(
+        self,
+        sync_mode: SyncMode,
+        cursor_field: List[str] = None,
+        stream_slice: Mapping[str, Any] = None,
+        stream_state: Mapping[str, Any] = None,
+    ) -> Iterable[Mapping[str, Any]]:
+        # Only the starting point of the branch being read is needed.
+        start_point = self.get_starting_point(
+            stream_state=stream_state, repository=stream_slice["repository"], branch=stream_slice["branch"]
+        )
+        for record in super(SemiIncrementalGithubStream, self).read_records(
+            sync_mode=sync_mode, cursor_field=cursor_field, stream_slice=stream_slice, stream_state=stream_state
+        ):
+            if record.get(self.cursor_field) > start_point:
+                yield record
+            elif self.is_sorted_descending and record.get(self.cursor_field) < start_point:
+                break
+
 
 class Issues(IncrementalGithubStream):
     """