Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Source GitHub: Add option to pull commits from user-specified branches #5931

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,41 @@ def streams(self, config: Mapping[str, Any]) -> List[Stream]:
incremental_args = {**full_refresh_args, "start_date": config["start_date"]}
organization_args = {"authenticator": authenticator, "organizations": organizations}

# Get the default branch for each repository
default_branches = {}
repository_stats_stream = RepositoryStats(
authenticator=authenticator,
repositories=repositories,
)
for stream_slice in repository_stats_stream.stream_slices(sync_mode=SyncMode.full_refresh):
default_branches.update({
r["full_name"]: r["default_branch"]
for r in repository_stats_stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=stream_slice)
})

# Create mapping of repository to list of branches to pull commits for
# If no branches are specified for a repo, use its default branch
branches = set(filter(None, config.get("branch", "").split(" ")))
branches_to_pull: Mapping[str, List[str]] = {}
for branch in branches:
parts = branch.split("/", 2)
repo = parts[0] + "/" + parts[1]
if repo not in repositories:
continue
if repo not in branches_to_pull:
branches_to_pull[repo] = []
branches_to_pull[repo].append(parts[2])
for repo in repositories:
if not branches_to_pull.get(repo, []):
branches_to_pull[repo] = [default_branches[repo]]

return [
Assignees(**full_refresh_args),
Branches(**full_refresh_args),
Collaborators(**full_refresh_args),
Comments(**incremental_args),
CommitComments(**incremental_args),
Commits(**incremental_args),
Commits(**{**incremental_args, "branches_to_pull": branches_to_pull, "default_branches": default_branches}),
Events(**incremental_args),
IssueEvents(**incremental_args),
IssueLabels(**full_refresh_args),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
"description": "The date from which you'd like to replicate data for GitHub in the format YYYY-MM-DDT00:00:00Z. All data generated after this date will be replicated. Note that it will be used only in the following incremental streams: comments, commits and issues.",
"examples": ["2021-03-01T00:00:00Z"],
"pattern": "^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}Z$"
},
"branch": {
"type": "string",
"examples": ["airbytehq/airbyte/master"],
"description": "Space-delimited list of GitHub repository branches to pull commits for, e.g. `airbytehq/airbyte/master`. If no branches are specified for a repository, the default branch will be pulled."
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ def read_records(self, stream_slice: Mapping[str, any] = None, **kwargs) -> Iter
# For private repositories `Teams` stream is not available and we get "404 Client Error: Not Found for
# url: https://api.github.com/orgs/sherifnada/teams?per_page=100" error.
error_msg = f"Syncing `Team` stream isn't available for repository `{stream_slice['repository']}`."
elif e.response.status_code == requests.codes.NOT_FOUND and "/repos?" in error_msg:
# `Repositories` stream is not available for repositories not in an organization.
# Handle "404 Client Error: Not Found for url: https://api.github.com/orgs/cjwooo/repos?per_page=100" error.
error_msg = f"Syncing `Repositories` stream isn't available for organization `{stream_slice['organization']}`."
elif e.response.status_code == requests.codes.CONFLICT:
error_msg = (
f"Syncing `{self.name}` stream isn't available for repository "
Expand Down Expand Up @@ -645,17 +649,97 @@ class Commits(IncrementalGithubStream):
"committer",
)

def transform(self, record: MutableMapping[str, Any], repository: str = None, **kwargs) -> MutableMapping[str, Any]:
def __init__(self, branches_to_pull: Mapping[str, List[str]], default_branches: Mapping[str, str], **kwargs):
super().__init__(**kwargs)
self.branches_to_pull = branches_to_pull
self.default_branches = default_branches

"""
Pull commits from each branch of each repository, tracking state for each branch
"""

def request_params(self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, Any] = None, **kwargs) -> MutableMapping[str, Any]:
params = super(IncrementalGithubStream, self).request_params(stream_state=stream_state, stream_slice=stream_slice, **kwargs)
params["since"] = self.get_starting_point(stream_state=stream_state, repository=stream_slice["repository"], branch=stream_slice["branch"])
params["sha"] = stream_slice["branch"]
return params

def stream_slices(self, **kwargs) -> Iterable[Optional[Mapping[str, Any]]]:
for stream_slice in super().stream_slices(**kwargs):
repository = stream_slice["repository"]
for branch in self.branches_to_pull.get(repository, []):
yield {"branch": branch, "repository": repository}

def parse_response(
self,
response: requests.Response,
stream_state: Mapping[str, Any],
stream_slice: Mapping[str, Any] = None,
next_page_token: Mapping[str, Any] = None,
) -> Iterable[Mapping]:
for record in response.json(): # GitHub puts records in an array.
yield self.transform(record=record, repository=stream_slice["repository"], branch=stream_slice["branch"])

def transform(self, record: MutableMapping[str, Any], repository: str = None, branch: str = None, **kwargs) -> MutableMapping[str, Any]:
record = super().transform(record=record, repository=repository)

# Record of the `commits` stream doesn't have an updated_at/created_at field at the top level (so we could
# just write `record["updated_at"]` or `record["created_at"]`). Instead each record has such value in
# `commit.author.date`. So the easiest way is to just enrich the record returned from API with top level
# field `created_at` and use it as cursor_field.
# Include the branch in the record
record["created_at"] = record["commit"]["author"]["date"]
record["branch"] = branch

return record

def get_updated_state(self, current_stream_state: MutableMapping[str, Any], latest_record: Mapping[str, Any]):
state_value = latest_cursor_value = latest_record.get(self.cursor_field)
current_repository = latest_record["repository"]
current_branch = latest_record["branch"]

if current_stream_state.get(current_repository):
repository_commits_state = current_stream_state[current_repository]
if repository_commits_state.get(self.cursor_field):
# transfer state from old source version to per-branch version
if current_branch == self.default_branches[current_repository]:
state_value = max(latest_cursor_value, repository_commits_state[self.cursor_field])
del repository_commits_state[self.cursor_field]
elif repository_commits_state.get(current_branch, {}).get(self.cursor_field):
state_value = max(latest_cursor_value, repository_commits_state[current_branch][self.cursor_field])
if current_repository not in current_stream_state:
current_stream_state[current_repository] = {}
current_stream_state[current_repository][current_branch] = {self.cursor_field: state_value}
return current_stream_state

def get_starting_point(self, stream_state: Mapping[str, Any], repository: str, branch: str) -> str:
start_point = self._start_date
if stream_state and stream_state.get(repository, {}).get(branch, {}).get(self.cursor_field):
return max(start_point, stream_state[repository][branch][self.cursor_field])
if branch == self.default_branches[repository]:
return super().get_starting_point(stream_state=stream_state, repository=repository)
return start_point

def read_records(
self,
sync_mode: SyncMode,
cursor_field: List[str] = None,
stream_slice: Mapping[str, Any] = None,
stream_state: Mapping[str, Any] = None,
) -> Iterable[Mapping[str, Any]]:
repository = stream_slice["repository"]
start_point_map = {
branch: self.get_starting_point(stream_state=stream_state, repository=repository, branch=branch)
for branch in self.branches_to_pull.get(repository, [])
}
for record in super(SemiIncrementalGithubStream, self).read_records(
sync_mode=sync_mode, cursor_field=cursor_field, stream_slice=stream_slice, stream_state=stream_state
):
if record.get(self.cursor_field) > start_point_map[stream_slice["branch"]]:
yield record
elif self.is_sorted_descending and record.get(self.cursor_field) < start_point_map[stream_slice["branch"]]:
break


class Issues(IncrementalGithubStream):
"""
Expand Down