Skip to content

Commit

Permalink
Rework message request and response
Browse files Browse the repository at this point in the history
  • Loading branch information
iychoi committed Sep 12, 2024
1 parent 20d132c commit 8c63d48
Show file tree
Hide file tree
Showing 105 changed files with 1,097 additions and 433 deletions.
2 changes: 2 additions & 0 deletions fs/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type MetadataCacheTimeoutSetting struct {
Inherit bool
}

// CacheConfig defines cache config
type CacheConfig struct {
Timeout time.Duration // cache timeout
CleanupTime time.Duration //
Expand All @@ -30,6 +31,7 @@ type CacheConfig struct {
StartNewTransaction bool
}

// NewDefaultCacheConfig creates a new default CacheConfig
func NewDefaultCacheConfig() CacheConfig {
return CacheConfig{
Timeout: FileSystemTimeoutDefault,
Expand Down
4 changes: 2 additions & 2 deletions icommands/environment.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ func (manager *ICommandsEnvironmentManager) Load(processID int) error {
// continue
} else {
authScheme := types.GetAuthScheme(manager.Environment.AuthenticationScheme)
if authScheme == types.AuthSchemePAM {
if authScheme.IsPAM() {
manager.Password = ""
manager.PamToken = password
} else {
Expand Down Expand Up @@ -247,7 +247,7 @@ func (manager *ICommandsEnvironmentManager) SaveEnvironment() error {
authScheme := types.GetAuthScheme(manager.Environment.AuthenticationScheme)

password := manager.Password
if authScheme == types.AuthSchemePAM {
if authScheme.IsPAM() {
password = manager.PamToken
}

Expand Down
180 changes: 107 additions & 73 deletions irods/connection/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ func (conn *IRODSConnection) SupportParallelUpload() bool {
return conn.serverVersion.HasHigherVersionThan(4, 2, 9)
}

func (conn *IRODSConnection) requirePAMPassword() bool {
return conn.serverVersion.HasHigherVersionThan(4, 3, 0)
}

func (conn *IRODSConnection) requiresCSNegotiation() bool {
return conn.account.ClientServerNegotiation
}
Expand Down Expand Up @@ -203,27 +207,13 @@ func (conn *IRODSConnection) setSocketOpt(socket net.Conn, bufferSize int) {
}
}

// Connect connects to iRODS
func (conn *IRODSConnection) Connect() error {
func (conn *IRODSConnection) connectTCP() error {
logger := log.WithFields(log.Fields{
"package": "connection",
"struct": "IRODSConnection",
"function": "Connect",
"function": "connectTCP",
})

conn.connected = false

conn.account.FixAuthConfiguration()

err := conn.account.Validate()
if err != nil {
return xerrors.Errorf("invalid account (%q): %w", err.Error(), types.NewConnectionConfigError(conn.account))
}

// lock the connection
conn.Lock()
defer conn.Unlock()

server := fmt.Sprintf("%s:%d", conn.account.Host, conn.account.Port)
logger.Debugf("Connecting to %s", server)

Expand All @@ -250,16 +240,37 @@ func (conn *IRODSConnection) Connect() error {
}

conn.socket = socket
var irodsVersion *types.IRODSVersion
return nil
}

if conn.requiresCSNegotiation() {
// client-server negotiation
irodsVersion, err = conn.connectWithCSNegotiation()
} else {
// No client-server negotiation
irodsVersion, err = conn.connectWithoutCSNegotiation()
// Connect connects to iRODS
func (conn *IRODSConnection) Connect() error {
logger := log.WithFields(log.Fields{
"package": "connection",
"struct": "IRODSConnection",
"function": "Connect",
})

conn.account.FixAuthConfiguration()

err := conn.account.Validate()
if err != nil {
return xerrors.Errorf("invalid account (%q): %w", err.Error(), types.NewConnectionConfigError(conn.account))
}

conn.connected = false

// lock the connection
conn.Lock()
defer conn.Unlock()

// connect TCP
err = conn.connectTCP()
if err != nil {
return err
}

irodsVersion, err := conn.startup()
if err != nil {
connErr := xerrors.Errorf("failed to startup an iRODS connection to server %q and port %d (%s): %w", conn.account.Host, conn.account.Port, err.Error(), types.NewConnectionError())
logger.Errorf("%+v", connErr)
Expand All @@ -277,11 +288,38 @@ func (conn *IRODSConnection) Connect() error {
err = conn.loginNative()
case types.AuthSchemeGSI:
err = conn.loginGSI()
case types.AuthSchemePAM:
case types.AuthSchemePAM, types.AuthSchemePAMPassword:
if len(conn.account.PamToken) > 0 {
err = conn.loginPAMWithToken()
} else {
err = conn.loginPAMWithPassword()
if err != nil {
connErr := xerrors.Errorf("failed to login to irods using PAM authentication: %w", err)
logger.Errorf("%+v", connErr)
return connErr
}

// reconnect when success
conn.disconnectNow()

// connect TCP
err = conn.connectTCP()
if err != nil {
return err
}

_, err = conn.startup()
if err != nil {
connErr := xerrors.Errorf("failed to startup an iRODS connection to server %q and port %d (%s): %w", conn.account.Host, conn.account.Port, err.Error(), types.NewConnectionError())
logger.Errorf("%+v", connErr)
_ = conn.disconnectNow()
if conn.metrics != nil {
conn.metrics.IncreaseCounterForConnectionFailures(1)
}
return connErr
}

err = conn.loginPAMWithToken()
}
default:
err = xerrors.Errorf("unknown Authentication Scheme %q: %w", conn.account.AuthenticationScheme, types.NewConnectionConfigError(conn.account))
Expand All @@ -296,41 +334,54 @@ func (conn *IRODSConnection) Connect() error {

if conn.account.UseTicket() {
req := message.NewIRODSMessageTicketAdminRequest("session", conn.account.Ticket)
err := conn.RequestAndCheck(req, &message.IRODSMessageAdminResponse{}, nil)
err := conn.RequestAndCheck(req, &message.IRODSMessageTicketAdminResponse{}, nil)
if err != nil {
return xerrors.Errorf("received supply ticket error (%s): %w", err.Error(), types.NewAuthError(conn.account))
}
}

conn.connected = true
conn.lastSuccessfulAccess = time.Now()

return nil
}

func (conn *IRODSConnection) connectWithCSNegotiation() (*types.IRODSVersion, error) {
func (conn *IRODSConnection) startup() (*types.IRODSVersion, error) {
logger := log.WithFields(log.Fields{
"package": "connection",
"struct": "IRODSConnection",
"function": "connectWithCSNegotiation",
"function": "startup",
})

// Get client negotiation policy
clientPolicy := types.CSNegotiationRequireTCP
if len(conn.account.CSNegotiationPolicy) > 0 {
clientPolicy = conn.account.CSNegotiationPolicy
if conn.requiresCSNegotiation() {
// Get client negotiation policy
if len(conn.account.CSNegotiationPolicy) > 0 {
clientPolicy = conn.account.CSNegotiationPolicy
}
}

logger.Debug("Start up an iRODS connection")

// Send a startup message
logger.Debug("Start up a connection with CS Negotiation")
startup := message.NewIRODSMessageStartupPack(conn.account, conn.applicationName, conn.requiresCSNegotiation())

startup := message.NewIRODSMessageStartupPack(conn.account, conn.applicationName, true)
err := conn.RequestWithoutResponse(startup)
if err != nil {
return nil, xerrors.Errorf("failed to send startup (%s): %w", err.Error(), types.NewConnectionError())
if conn.requiresCSNegotiation() {
err := conn.RequestWithoutResponse(startup)
if err != nil {
return nil, xerrors.Errorf("failed to send startup (%s): %w", err.Error(), types.NewConnectionError())
}
} else {
// no cs negotiation
version := message.IRODSMessageVersion{}
err := conn.Request(startup, &version, nil)
if err != nil {
return nil, xerrors.Errorf("failed to receive version message (%s): %w", err.Error(), types.NewConnectionError())
}

return version.GetVersion(), nil
}

// Server responds with negotiation response
// cs negotiation response
negotiationMessage, err := conn.ReadMessage(nil)
if err != nil {
return nil, xerrors.Errorf("failed to receive negotiation message (%s): %w", err.Error(), types.NewConnectionError())
Expand Down Expand Up @@ -394,27 +445,7 @@ func (conn *IRODSConnection) connectWithCSNegotiation() (*types.IRODSVersion, er
}

return nil, xerrors.Errorf("unknown response message %q: %w", negotiationMessage.Body.Type, types.NewConnectionError())
}

func (conn *IRODSConnection) connectWithoutCSNegotiation() (*types.IRODSVersion, error) {
logger := log.WithFields(log.Fields{
"package": "connection",
"struct": "IRODSConnection",
"function": "connectWithoutCSNegotiation",
})

// No client-server negotiation
// Send a startup message
logger.Debug("Start up connection without CS Negotiation")

startup := message.NewIRODSMessageStartupPack(conn.account, conn.applicationName, false)
version := message.IRODSMessageVersion{}
err := conn.Request(startup, &version, nil)
if err != nil {
return nil, xerrors.Errorf("failed to receive version message (%s): %w", err.Error(), types.NewConnectionError())
}

return version.GetVersion(), nil
}

func (conn *IRODSConnection) sslStartup() error {
Expand Down Expand Up @@ -476,7 +507,7 @@ func (conn *IRODSConnection) sslStartup() error {

// Send a shared secret
sslSharedSecret := message.NewIRODSMessageSSLSharedSecret(encryptionKey)
err = conn.RequestWithoutResponseNoXML(sslSharedSecret)
err = conn.RequestWithoutResponse(sslSharedSecret)
if err != nil {
return xerrors.Errorf("failed to send ssl shared secret message (%s): %w", err.Error(), types.NewConnectionError())
}
Expand All @@ -490,7 +521,7 @@ func (conn *IRODSConnection) login(password string) error {
// authenticate
authRequest := message.NewIRODSMessageAuthRequest()
authChallenge := message.IRODSMessageAuthChallengeResponse{}
err := conn.Request(authRequest, &authChallenge, nil)
err := conn.RequestAndCheck(authRequest, &authChallenge, nil)
if err != nil {
return xerrors.Errorf("failed to receive authentication challenge message body (%s): %w", err.Error(), types.NewAuthError(conn.account))
}
Expand Down Expand Up @@ -551,7 +582,7 @@ func (conn *IRODSConnection) loginPAMWithPassword() error {

ttl := conn.account.PamTTL
if ttl < 0 {
ttl = 0 // decided by server
ttl = 0 // server decides
}

pamPassword := conn.getSafePAMPassword(conn.account.Password)
Expand All @@ -562,41 +593,43 @@ func (conn *IRODSConnection) loginPAMWithPassword() error {

authContext := strings.Join([]string{userKV, passwordKV, ttlKV}, ";")

useDedicatedPAMApi := false
if strings.ContainsAny(pamPassword, ";=") {
useDedicatedPAMApi = true
} else {
// from python-irodsclient code
if len(authContext) >= 1024+64 {
useDedicatedPAMApi = true
}
useDedicatedPAMApi := true
if conn.requirePAMPassword() {
useDedicatedPAMApi = strings.ContainsAny(pamPassword, ";=") || len(authContext) >= 1024+64
}

// authenticate
pamToken := ""

if useDedicatedPAMApi {
logger.Debugf("use dedicated PAM api")

pamAuthRequest := message.NewIRODSMessagePamAuthRequest(conn.account.ClientUser, pamPassword, ttl)
pamAuthResponse := message.IRODSMessagePamAuthResponse{}
err := conn.Request(pamAuthRequest, &pamAuthResponse, nil)
err := conn.RequestAndCheck(pamAuthRequest, &pamAuthResponse, nil)
if err != nil {
return xerrors.Errorf("failed to receive an authentication challenge message (%s): %w", err.Error(), types.NewAuthError(conn.account))
return xerrors.Errorf("failed to receive a PAM token (%s): %w", err.Error(), types.NewAuthError(conn.account))
}

pamToken = pamAuthResponse.GeneratedPassword
} else {
logger.Debugf("use auth plugin api: scheme %q", string(types.AuthSchemePAM))

pamAuthRequest := message.NewIRODSMessageAuthPluginRequest(string(types.AuthSchemePAM), authContext)
pamAuthResponse := message.IRODSMessageAuthPluginResponse{}
err := conn.Request(pamAuthRequest, &pamAuthResponse, nil)
err := conn.RequestAndCheck(pamAuthRequest, &pamAuthResponse, nil)
if err != nil {
return xerrors.Errorf("failed to receive an authentication challenge message (%s): %w", err.Error(), types.NewAuthError(conn.account))
return xerrors.Errorf("failed to receive a PAM token (%s): %w", err.Error(), types.NewAuthError(conn.account))
}

pamToken = pamAuthResponse.Result
pamToken = string(pamAuthResponse.GeneratedPassword)
}

// save irods generated password for possible future use
conn.account.PamToken = pamToken

// disconnect and connect

// retry native auth with generated password
return conn.login(conn.account.PamToken)
}
Expand Down Expand Up @@ -1053,6 +1086,7 @@ func (conn *IRODSConnection) poorMansEndTransaction(dummyCol string, commit bool
if commit {
request.AddKeyVal(common.COLLECTION_TYPE_KW, "NULL_SPECIAL_VALUE")
}

response := message.IRODSMessageModifyCollectionResponse{}
err := conn.Request(request, &response, nil)
if err != nil {
Expand Down
Loading

0 comments on commit 8c63d48

Please sign in to comment.