Skip to content

Commit

Permalink
fix(langchain_tools_demo): Add ID Token credential flow for GCE (#198)
Browse files Browse the repository at this point in the history
Fixes:
#190
  • Loading branch information
duwenxin99 authored Jan 30, 2024
1 parent 960f6b1 commit ed9b6c2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/run_langchain_demo.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
```bash
gcloud auth application-default login
```
* Tip: if you are running into `403` error, check to make sure the service account you are using has the `Cloud Run Invoker` IAM in the retrieval service project.

1. Change into the `langchain_tools_demo` directory:

Expand Down
13 changes: 12 additions & 1 deletion langchain_tools_demo/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import aiohttp
import google.oauth2.id_token # type: ignore
from google.auth import compute_engine # type: ignore
from google.auth.transport.requests import Request # type: ignore
from langchain.agents.agent import ExceptionTool # type: ignore
from langchain.tools import StructuredTool
Expand All @@ -34,9 +35,19 @@ def get_id_token():
global CREDENTIALS
if CREDENTIALS is None:
CREDENTIALS, _ = google.auth.default()
if not hasattr(CREDENTIALS, "id_token"):
# Use Compute Engine default credential
CREDENTIALS = compute_engine.IDTokenCredentials(
request=Request(),
target_audience=BASE_URL,
use_metadata_identity_endpoint=True,
)
if not CREDENTIALS.valid:
CREDENTIALS.refresh(Request())
return CREDENTIALS.id_token
if hasattr(CREDENTIALS, "id_token"):
return CREDENTIALS.id_token
else:
return CREDENTIALS.token


def get_headers(client: aiohttp.ClientSession):
Expand Down

0 comments on commit ed9b6c2

Please sign in to comment.