Skip to content

Commit

Permalink
Add support for Tower client bearer authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
pditommaso committed Oct 24, 2020
1 parent e4529d4 commit 532282a
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,12 @@ class SimpleHttpClient {
/**
* Basic auth token
*/
private String authToken
private String basicToken

/**
* Basic Bearer token
*/
private String bearerToken

/**
* Http user agent
Expand All @@ -72,34 +77,68 @@ class SimpleHttpClient {

private int backOffDelay = DEFAULT_BACK_OFF_DELAY

private CookieManager cookieManager

SimpleHttpClient() {
cookieManager = new CookieManager()
CookieHandler.setDefault(cookieManager)
}


HttpCookie getCookie(final String cookieName) {
for( HttpCookie it : cookieManager.cookieStore.cookies ) {
if( it.name == cookieName )
return it
}
return null
}

/**
* Send a json formatted string as HTTP POST request
* @param json Message content as JSON
*/
void sendHttpMessage(String url, String json, String method = 'POST') throws IllegalStateException, IllegalArgumentException{

if (!url)
throw new IllegalStateException("URL needs to be set!")
void sendHttpMessage(String url, String json, String method = 'POST') throws IllegalStateException, IllegalArgumentException {
sendHttpMessage(url, body: json, method: method)
}

void sendHttpMessage(Map args, String url) throws IllegalStateException, IllegalArgumentException{

final method = args.method as String
final body = args.body as String
final contentType = args.contentType as String ?: "application/json"
final charset = args.charset as String

// reset the error count
if( !url )
throw new IllegalStateException("Missing 'url' argument")
if( !method )
throw new IllegalStateException("Missing 'method' argument")

// reset the error count
errorCount = 0

while( true ) {
// Open a connection to the target url
def con = getHttpConnection(url)
// Make header settings
con.setRequestMethod(method)
con.setRequestProperty("Content-Type", "application/json")
con.setRequestProperty("Content-Type", contentType)
con.setRequestProperty("User-Agent", userAgent)
if( authToken )
con.setRequestProperty("Authorization","Basic ${authToken.bytes.encodeBase64()}")
// set charset
if( charset )
con.setRequestProperty("charset", "utf-8")
// set auth
if( bearerToken )
con.setRequestProperty("Authorization","Bearer ${bearerToken}")
else if( basicToken )
con.setRequestProperty("Authorization","Basic ${basicToken.bytes.encodeBase64()}")

con.setDoOutput(true)

// Send POST request
if( json ) {
if( body ) {
DataOutputStream output = new DataOutputStream(con.getOutputStream())
output.writeBytes(json)
output.writeBytes(body)
output.flush()
output.close()
}
Expand Down Expand Up @@ -191,8 +230,13 @@ class SimpleHttpClient {
return this
}

SimpleHttpClient setAuthToken(String tkn) {
authToken = tkn
SimpleHttpClient setBasicToken(String tkn) {
basicToken = tkn
return this
}

SimpleHttpClient setBearerToken(String tkn) {
bearerToken = tkn
return this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class SimpleHttpClientTest extends Specification{
when:
def client = new SimpleHttpClient()
client.setAuthToken(TOKEN)
client.setBasicToken(TOKEN)
client.setUserAgent(AGENT)
client.sendHttpMessage( ENDPOINT, PAYLOAD )
then:
Expand Down
123 changes: 102 additions & 21 deletions modules/nf-tower/src/main/io/seqera/tower/plugin/TowerClient.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ import nextflow.util.SimpleHttpClient
@CompileStatic
class TowerClient implements TraceObserver {

static final String DEF_ENDPOINT_URL = 'https://api.tower.nf'
static final public String DEF_ENDPOINT_URL = 'https://api.tower.nf'

static private final int TASKS_PER_REQUEST = 100

Expand Down Expand Up @@ -121,6 +121,8 @@ class TowerClient implements TraceObserver {

private boolean towerLaunch

private String refreshToken

/**
* Constructor that consumes a URL and creates
* a basic HTTP client.
Expand Down Expand Up @@ -228,7 +230,9 @@ class TowerClient implements TraceObserver {
this.aggregator = new ResourcesAggregator(session)
this.runName = session.getRunName()
this.runId = session.getUniqueId()
this.httpClient = new SimpleHttpClient().setAuthToken(TOKEN_PREFIX + getAccessToken())
this.httpClient = new SimpleHttpClient()
// set the auth token
setAuthToken( httpClient, getAccessToken() )

// send hello to verify auth
final req = makeCreateReq(session)
Expand All @@ -243,6 +247,31 @@ class TowerClient implements TraceObserver {
log.warn(ret.message.toString())
}

protected void setAuthToken(SimpleHttpClient client, String token) {
// check for plain jwt token
if( token.count('.')==2 ) {
client.setBearerToken(token)
return
}

// try checking personal access token
try {
final plain = new String(token.decodeBase64())
final p = plain.indexOf('.')
if( p!=-1 && new JsonSlurper().parseText( plain.substring(0, p) ) ) {
// ok this is bearer token
client.setBearerToken(token)
return
}
}
catch ( Exception e ) {
log.trace "Enable to set bearer token ~ Reason: $e.message"
}

// fallback on simple token
client.setBasicToken(TOKEN_PREFIX + token)
}

protected Map makeCreateReq(Session session) {
def result = new HashMap(5)
result.sessionId = session.uniqueId.toString()
Expand Down Expand Up @@ -374,6 +403,37 @@ class TowerClient implements TraceObserver {
events << new ProcessEvent(trace: trace)
}

protected void refreshToken(String refresh) {
log.debug "Token refresh request >> $refresh"
final url = "$endpoint/oauth/access_token"
httpClient.sendHttpMessage(
url,
method: 'POST',
contentType: "application/x-www-form-urlencoded",
body: "grant_type=refresh_token&refresh_token=${URLEncoder.encode(refresh, 'UTF-8')}" )

final authCookie = httpClient.getCookie('JWT')
final refreshCookie = httpClient.getCookie('JWT_REFRESH_TOKEN')

// set the new bearer token
if( authCookie?.value ) {
log.trace "Updating http client bearer token=$authCookie.value"
httpClient.setBearerToken(authCookie.value)
}
else {
log.warn "Missing JWT cookie from refresh token response ~ $authCookie"
}

// set the new refresh token
if( refreshCookie?.value ) {
log.trace "Updating http client refresh token=$refreshCookie.value"
refreshToken = refreshCookie.value
}
else {
log.warn "Missing JWT_REFRESH_TOKEN cookie from refresh token response ~ $refreshCookie"
}
}

/**
* Little helper method that sends a HTTP POST message as JSON with
* the current run status, ISO 8601 UTC timestamp, run name and the TraceRecord
Expand All @@ -383,24 +443,45 @@ class TowerClient implements TraceObserver {
* @param payload An additional object to send. Must be of type TraceRecord or Manifest
*/
protected Response sendHttpMessage(String url, Map payload, String method='POST'){
// The actual HTTP request
final String json = payload != null ? generator.toJson(payload) : null
final String debug = json != null ? JsonOutput.prettyPrint(json).indent() : '-'
log.trace "HTTP url=$url; payload:\n${debug}\n"
try {
httpClient.sendHttpMessage(url, json, method)
return new Response(httpClient.responseCode, httpClient.getResponse())
}
catch( ConnectException e ) {
String msg = "Unable to connect Tower host: ${getHostUrl(url)}"
return new Response(0, msg)
}
catch (IOException e) {
int code = httpClient.responseCode
String msg = ( code == 401
? 'Unauthorized Tower access -- Make sure you have specified the correct access token'
: "Unexpected response code $code for request $url" )
return new Response(code, msg, httpClient.response)

int refreshTries=0
final currentRefresh = refreshToken ?: env.get('TOWER_REFRESH_TOKEN')

while ( true ) {
// The actual HTTP request
final String json = payload != null ? generator.toJson(payload) : null
final String debug = json != null ? JsonOutput.prettyPrint(json).indent() : '-'
log.trace "HTTP url=$url; payload:\n${debug}\n"
try {
if( refreshTries==1 ) {
refreshToken(currentRefresh)
}

httpClient.sendHttpMessage(url, json, method)
return new Response(httpClient.responseCode, httpClient.getResponse())
}
catch( ConnectException e ) {
String msg = "Unable to connect Tower host: ${getHostUrl(url)}"
return new Response(0, msg)
}
catch (IOException e) {
int code = httpClient.responseCode
if( code == 401 && ++refreshTries==1 && currentRefresh ) {
// when 401 Unauthorized error is returned - only the very first time -
// and a refresh token is available, make another iteration trying
// having refreshed the authorization token (see 'refreshToken' invocation above)
log.trace "Got 401 Unauthorized response ~ tries refreshing auth token"
continue
}
else {
log.trace("Got error $code - refreshTries=$refreshTries - currentRefresh=$currentRefresh")
}

String msg = ( code == 401
? 'Unauthorized Tower access -- Make sure you have specified the correct access token'
: "Unexpected response code $code for request $url" )
return new Response(code, msg, httpClient.response)
}
}
}

Expand Down Expand Up @@ -597,7 +678,7 @@ class TowerClient implements TraceObserver {
final long delay = period / 10 as long

while( !complete ) {
final ev = events.poll(delay, TimeUnit.MILLISECONDS)
final ProcessEvent ev = events.poll(delay, TimeUnit.MILLISECONDS)
// reconcile task events ie. send out only the last event
if( ev ) {
log.trace "Tower event=$ev"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class TowerFactory implements TraceObserverFactory {
return Collections.emptyList()

if ( !endpoint || endpoint=='-' )
endpoint = TowerClient.DEF_ENDPOINT_URL
endpoint = System.getenv('TOWER_API_ENDPOINT') ?: TowerClient.DEF_ENDPOINT_URL

final tower = new TowerClient(endpoint)
if( aliveInterval )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,4 +375,28 @@ class TowerClientTest extends Specification {
tower.getUrlTraceHeartbeat() == 'https://tower.nf/trace/12345/heartbeat'
tower.getUrlTraceComplete() == 'https://tower.nf/trace/12345/complete'
}

def 'should set the auth token' () {
given:
def http = Mock(SimpleHttpClient)
TowerClient client = Spy(TowerClient, constructorArgs: ['https://tower.nf'])
and:
def SIMPLE = '4ffbf1009ebabea77db3d72efefa836dfbb71271'
def BEARER = 'eyJ0aWQiOiA1fS5jZmM1YjVhOThjZjM2MTk1NjBjZWU1YmMwODUxYzA1ZjkzMDdmN2Iz'

when:
client.setAuthToken(http, SIMPLE)
then:
http.setBasicToken('@token:' + SIMPLE) >> null

when:
client.setAuthToken(http, SIMPLE)
then:
http.setBasicToken('@token:' + SIMPLE) >> null

when:
client.setAuthToken(http, BEARER)
then:
http.setBearerToken(BEARER) >> null
}
}

0 comments on commit 532282a

Please sign in to comment.