diff --git a/runner/internal/runner/api/http_test.go b/runner/internal/runner/api/http_test.go index ddc587592..98a6f5180 100644 --- a/runner/internal/runner/api/http_test.go +++ b/runner/internal/runner/api/http_test.go @@ -21,7 +21,7 @@ func (ds DummyRunner) GetState() (shim.RunnerStatus, shim.ContainerStatus, strin return ds.State, ds.ContainerStatus, "", ds.JobResult } -func (ds DummyRunner) Run(context.Context, shim.DockerImageConfig) error { +func (ds DummyRunner) Run(context.Context, shim.TaskConfig) error { return nil } diff --git a/runner/internal/shim/api/http.go b/runner/internal/shim/api/http.go index fac4103df..a5f8db052 100644 --- a/runner/internal/shim/api/http.go +++ b/runner/internal/shim/api/http.go @@ -28,18 +28,18 @@ func (s *ShimServer) SubmitPostHandler(w http.ResponseWriter, r *http.Request) ( return nil, &api.Error{Status: http.StatusConflict} } - var body DockerTaskBody + var body TaskConfigBody if err := api.DecodeJSONBody(w, r, &body, true); err != nil { log.Println("Failed to decode submit body", "err", err) return nil, err } - go func(taskParams shim.DockerImageConfig) { - err := s.runner.Run(context.Background(), taskParams) + go func(taskConfig shim.TaskConfig) { + err := s.runner.Run(context.Background(), taskConfig) if err != nil { fmt.Printf("failed Run %v\n", err) } - }(body.GetDockerImageConfig()) + }(body.GetTaskConfig()) return nil, nil } diff --git a/runner/internal/shim/api/schemas.go b/runner/internal/shim/api/schemas.go index c51451781..0eb7a2961 100644 --- a/runner/internal/shim/api/schemas.go +++ b/runner/internal/shim/api/schemas.go @@ -2,13 +2,15 @@ package api import "github.com/dstackai/dstack/runner/internal/shim" -type DockerTaskBody struct { +type TaskConfigBody struct { Username string `json:"username"` Password string `json:"password"` ImageName string `json:"image_name"` ContainerName string `json:"container_name"` ShmSize int64 `json:"shm_size"` PublicKeys []string `json:"public_keys"` + SshUser string `json:"ssh_user"` + SshKey string `json:"ssh_key"` } type StopBody struct { @@ -37,14 +39,16 @@ type StopResponse struct { State string `json:"state"` } -func (ra DockerTaskBody) GetDockerImageConfig() shim.DockerImageConfig { - res := shim.DockerImageConfig{ +func (ra TaskConfigBody) GetTaskConfig() shim.TaskConfig { + res := shim.TaskConfig{ ImageName: ra.ImageName, Username: ra.Username, Password: ra.Password, ContainerName: ra.ContainerName, ShmSize: ra.ShmSize, PublicKeys: ra.PublicKeys, + SshUser: ra.SshUser, + SshKey: ra.SshKey, } return res } diff --git a/runner/internal/shim/api/server.go b/runner/internal/shim/api/server.go index a6da3c276..03ed8a7ac 100644 --- a/runner/internal/shim/api/server.go +++ b/runner/internal/shim/api/server.go @@ -10,7 +10,7 @@ import ( ) type TaskRunner interface { - Run(context.Context, shim.DockerImageConfig) error + Run(context.Context, shim.TaskConfig) error GetState() (shim.RunnerStatus, shim.ContainerStatus, string, shim.JobResult) Stop(bool) } diff --git a/runner/internal/shim/authorized_keys.go b/runner/internal/shim/authorized_keys.go new file mode 100644 index 000000000..c61785104 --- /dev/null +++ b/runner/internal/shim/authorized_keys.go @@ -0,0 +1,141 @@ +package shim + +import ( + "bufio" + "fmt" + "io" + "os" + "slices" + + "github.com/ztrue/tracerr" + "golang.org/x/crypto/ssh" +) + +func PublicKeyFingerprint(key string) (string, error) { + pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(key)) + if err != nil { + return "", tracerr.Wrap(err) + } + keyFingerprint := ssh.FingerprintSHA256(pk) + return keyFingerprint, nil +} + +func IsPublicKeysEqual(left string, right string) bool { + leftFingerprint, err := PublicKeyFingerprint(left) + if err != nil { + return false + } + + rightFingerprint, err := PublicKeyFingerprint(right) + if err != nil { + return false + } + + return leftFingerprint == rightFingerprint +} + +func RemovePublicKeys(fileKeys []string, keysToRemove []string) []string { + newKeys := slices.DeleteFunc(fileKeys, func(fileKey string) bool { + delete := slices.ContainsFunc(keysToRemove, func(removeKey string) bool { + return IsPublicKeysEqual(fileKey, removeKey) + }) + return delete + }) + return newKeys +} + +func AppendPublicKeys(fileKeys []string, keysToAppend []string) []string { + newKeys := []string{} + newKeys = append(newKeys, fileKeys...) + newKeys = append(newKeys, keysToAppend...) + return newKeys +} + +type AuthorizedKeys struct { + user string + rootPath string +} + +func (ak AuthorizedKeys) AppendPublicKeys(publicKeys []string) error { + return ak.transformAuthorizedKeys(AppendPublicKeys, publicKeys) +} + +func (ak AuthorizedKeys) RemovePublicKeys(publicKeys []string) error { + return ak.transformAuthorizedKeys(RemovePublicKeys, publicKeys) +} + +func (ak AuthorizedKeys) read(r io.Reader) ([]string, error) { + lines := []string{} + scanner := bufio.NewScanner(r) + for scanner.Scan() { + text := scanner.Text() + lines = append(lines, text) + } + if err := scanner.Err(); err != nil { + return []string{}, tracerr.Wrap(err) + } + return lines, nil +} + +func (ak AuthorizedKeys) write(w io.Writer, lines []string) error { + wr := bufio.NewWriter(w) + for _, line := range lines { + _, err := fmt.Fprintln(wr, line) + if err != nil { + return tracerr.Wrap(err) + + } + } + return wr.Flush() +} + +func (ak AuthorizedKeys) GetAuthorizedKeysPath() string { + return fmt.Sprintf("%s/home/%s/.ssh/authorized_keys", ak.rootPath, ak.user) +} + +func (ak AuthorizedKeys) transformAuthorizedKeys(transform func([]string, []string) []string, publicKeys []string) error { + authorizedKeysPath := ak.GetAuthorizedKeysPath() + info, err := os.Stat(authorizedKeysPath) + if err != nil { + return tracerr.Wrap(err) + } + fileMode := info.Mode().Perm() + + authorizedKeysFile, err := os.OpenFile(authorizedKeysPath, os.O_RDWR, fileMode) + if err != nil { + return tracerr.Wrap(err) + } + defer authorizedKeysFile.Close() + + lines, err := ak.read(authorizedKeysFile) + if err != nil { + return tracerr.Wrap(err) + } + + // write backup + authorizedKeysPathBackup := ak.GetAuthorizedKeysPath() + ".bak" + authorizedKeysBackup, err := os.OpenFile(authorizedKeysPathBackup, os.O_RDWR|os.O_CREATE|os.O_TRUNC, fileMode) + if err != nil { + return tracerr.Wrap(err) + } + defer authorizedKeysBackup.Close() + if err := ak.write(authorizedKeysBackup, lines); err != nil { + return tracerr.Wrap(err) + } + + // transform lines + newLines := transform(lines, publicKeys) + + // write authorized_keys + if err := authorizedKeysFile.Truncate(0); err != nil { + return tracerr.Wrap(err) + } + if _, err := authorizedKeysFile.Seek(0, 0); err != nil { + return tracerr.Wrap(err) + } + if err := ak.write(authorizedKeysFile, newLines); err != nil { + return tracerr.Wrap(err) + } + + return nil +} diff --git a/runner/internal/shim/authorized_keys_test.go b/runner/internal/shim/authorized_keys_test.go new file mode 100644 index 000000000..82db45f51 --- /dev/null +++ b/runner/internal/shim/authorized_keys_test.go @@ -0,0 +1,179 @@ +package shim + +import ( + "os" + "path" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPublicKeyFingerprint(t *testing.T) { + key := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + expectedFingerprint := "SHA256:9HymzYAtJKNh8gKufl3EVoRSauL4E7Mbmuzqlcvii50" + fingerprint, err := PublicKeyFingerprint(key) + require.NoError(t, err) + require.Equal(t, expectedFingerprint, fingerprint) +} + +func TestPublicKeyFingerprintError(t *testing.T) { + key := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQ= thebits@barracuda" + fingerprint, err := PublicKeyFingerprint(key) + require.Error(t, err) + require.Empty(t, fingerprint) +} + +func TestIsPublicKeysEqual(t *testing.T) { + keyLeft := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + keyRight := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + + result := IsPublicKeysEqual(keyLeft, keyRight) + require.True(t, result) +} + +func TestIsPublicKeysEqualBrokenKey(t *testing.T) { + keyLeft := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + keyRight := "ssh-rsa AAAAP66um5MadfhB5dSnEM= thebits@barracuda" + + resultFwd := IsPublicKeysEqual(keyLeft, keyRight) + require.False(t, resultFwd) + + resultBck := IsPublicKeysEqual(keyRight, keyLeft) + require.False(t, resultBck) +} +func TestIsPublicKeysNotEqual(t *testing.T) { + keyLeft := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + keyRight := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCfAFHwfyMKPFbKq+D/vYNaXjqer4uV5+zvlrPY2bvkdRT4GiH4hm2s1Z7+fUEYQBNfw5O9SgxGotqyguUJbuVUc2BCNdD8HC3PxKtEev35ga4G3jjyuVeHcL2T9pn+F8IW1o3SpDGATAHJyFtArPYz31Hwg6PiuggPNdPLMSzZNrwNVuPwT1uDMKFqAh+1ryIVi7389fjZ7aBR9F06VIPpWIVVKqSVD+NbHtwWqCw8AsprJE3bPwVW09OJeQX8GXryKasaX4t4HMXmO/UI8tprnyf05dAl7NQOPY9Iut5PgfzEVY/T0M1RSnZi7i+1x7WBWX3aMM/Hv+NUeX2YtuAN" + + result := IsPublicKeysEqual(keyLeft, keyRight) + require.False(t, result) +} + +func TestRemovePublicKeys(t *testing.T) { + keyLeft := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + keyRight := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCfAFHwfyMKPFbKq+D/vYNaXjqer4uV5+zvlrPY2bvkdRT4GiH4hm2s1Z7+fUEYQBNfw5O9SgxGotqyguUJbuVUc2BCNdD8HC3PxKtEev35ga4G3jjyuVeHcL2T9pn+F8IW1o3SpDGATAHJyFtArPYz31Hwg6PiuggPNdPLMSzZNrwNVuPwT1uDMKFqAh+1ryIVi7389fjZ7aBR9F06VIPpWIVVKqSVD+NbHtwWqCw8AsprJE3bPwVW09OJeQX8GXryKasaX4t4HMXmO/UI8tprnyf05dAl7NQOPY9Iut5PgfzEVY/T0M1RSnZi7i+1x7WBWX3aMM/Hv+NUeX2YtuAN" + + keys := []string{keyLeft, keyRight} + newKeys := RemovePublicKeys(keys, []string{keyRight}) + + require.Len(t, newKeys, 1) + require.Equal(t, newKeys, []string{keyLeft}) +} + +func TestRemovePublicKeysRemoveAll(t *testing.T) { + keyLeft := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + keyRight := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCfAFHwfyMKPFbKq+D/vYNaXjqer4uV5+zvlrPY2bvkdRT4GiH4hm2s1Z7+fUEYQBNfw5O9SgxGotqyguUJbuVUc2BCNdD8HC3PxKtEev35ga4G3jjyuVeHcL2T9pn+F8IW1o3SpDGATAHJyFtArPYz31Hwg6PiuggPNdPLMSzZNrwNVuPwT1uDMKFqAh+1ryIVi7389fjZ7aBR9F06VIPpWIVVKqSVD+NbHtwWqCw8AsprJE3bPwVW09OJeQX8GXryKasaX4t4HMXmO/UI8tprnyf05dAl7NQOPY9Iut5PgfzEVY/T0M1RSnZi7i+1x7WBWX3aMM/Hv+NUeX2YtuAN" + + keys := []string{keyLeft, keyRight} + newKeys := RemovePublicKeys(keys, []string{keyRight, keyLeft}) + + require.Empty(t, newKeys) +} + +func TestRemovePublicKeysRemoveNotContained(t *testing.T) { + keyLeft := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + keyRight := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCfAFHwfyMKPFbKq+D/vYNaXjqer4uV5+zvlrPY2bvkdRT4GiH4hm2s1Z7+fUEYQBNfw5O9SgxGotqyguUJbuVUc2BCNdD8HC3PxKtEev35ga4G3jjyuVeHcL2T9pn+F8IW1o3SpDGATAHJyFtArPYz31Hwg6PiuggPNdPLMSzZNrwNVuPwT1uDMKFqAh+1ryIVi7389fjZ7aBR9F06VIPpWIVVKqSVD+NbHtwWqCw8AsprJE3bPwVW09OJeQX8GXryKasaX4t4HMXmO/UI8tprnyf05dAl7NQOPY9Iut5PgfzEVY/T0M1RSnZi7i+1x7WBWX3aMM/Hv+NUeX2YtuAN" + + keys := []string{keyLeft, keyRight} + newKeys := RemovePublicKeys(keys, []string{"# line with comment"}) + + require.Equal(t, keys, newKeys) +} + +func TestAppendPublicKeys(t *testing.T) { + keyLeft := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + keyRight := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCfAFHwfyMKPFbKq+D/vYNaXjqer4uV5+zvlrPY2bvkdRT4GiH4hm2s1Z7+fUEYQBNfw5O9SgxGotqyguUJbuVUc2BCNdD8HC3PxKtEev35ga4G3jjyuVeHcL2T9pn+F8IW1o3SpDGATAHJyFtArPYz31Hwg6PiuggPNdPLMSzZNrwNVuPwT1uDMKFqAh+1ryIVi7389fjZ7aBR9F06VIPpWIVVKqSVD+NbHtwWqCw8AsprJE3bPwVW09OJeQX8GXryKasaX4t4HMXmO/UI8tprnyf05dAl7NQOPY9Iut5PgfzEVY/T0M1RSnZi7i+1x7WBWX3aMM/Hv+NUeX2YtuAN" + comment := "# line with coment" + + keys := []string{keyLeft, keyRight} + newKeys := AppendPublicKeys(keys, []string{comment}) + + require.Equal(t, []string{keyLeft, keyRight, comment}, newKeys) +} + +func TestAppendKey(t *testing.T) { + ak := AuthorizedKeys{user: "test_user", rootPath: t.TempDir()} + filePath := ak.GetAuthorizedKeysPath() + err := os.MkdirAll(path.Dir(filePath), os.ModePerm) + require.NoError(t, err) + + key := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + err = os.WriteFile(filePath, []byte(key), os.ModePerm) + require.NoError(t, err) + + commentLine := "# comment line" + err = ak.AppendPublicKeys([]string{commentLine}) + require.NoError(t, err) + + b, err := os.ReadFile(filePath) + require.NoError(t, err) + require.Contains(t, string(b), commentLine) +} + +func TestRemoveKey(t *testing.T) { + ak := AuthorizedKeys{user: "test_user", rootPath: t.TempDir()} + filePath := ak.GetAuthorizedKeysPath() + err := os.MkdirAll(path.Dir(filePath), os.ModePerm) + require.NoError(t, err) + + key := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + err = os.WriteFile(filePath, []byte(key), os.ModePerm) + require.NoError(t, err) + + err = ak.RemovePublicKeys([]string{key}) + require.NoError(t, err) + + b, err := os.ReadFile(filePath) + require.NoError(t, err) + require.Empty(t, string(b)) + + back, err := os.ReadFile(filePath + ".bak") + require.NoError(t, err) + require.Contains(t, string(back), key) +} + +func TestRemoveTwoKey(t *testing.T) { + ak := AuthorizedKeys{user: "test_user", rootPath: t.TempDir()} + filePath := ak.GetAuthorizedKeysPath() + err := os.MkdirAll(path.Dir(filePath), os.ModePerm) + require.NoError(t, err) + + first := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + second := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCfAFHwfyMKPFbKq+D/vYNaXjqer4uV5+zvlrPY2bvkdRT4GiH4hm2s1Z7+fUEYQBNfw5O9SgxGotqyguUJbuVUc2BCNdD8HC3PxKtEev35ga4G3jjyuVeHcL2T9pn+F8IW1o3SpDGATAHJyFtArPYz31Hwg6PiuggPNdPLMSzZNrwNVuPwT1uDMKFqAh+1ryIVi7389fjZ7aBR9F06VIPpWIVVKqSVD+NbHtwWqCw8AsprJE3bPwVW09OJeQX8GXryKasaX4t4HMXmO/UI8tprnyf05dAl7NQOPY9Iut5PgfzEVY/T0M1RSnZi7i+1x7WBWX3aMM/Hv+NUeX2YtuAN" + third := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDIAGg0prDVeane6xLvMPBKQHxNUpt4q/hmuAAxjOUW0GWMPS2qE3l8YkmWeK80nKvio4M/IYWe67HIVeibdvKPoFJTtgm93WeJT9KD6h7MCschAf78mAIBhzUMK+9UYl5pE2jpfqc0SXkUsXDxMVN+ST9lN7fXUVsCPXO6qJG+0hLA3vs5r0aY1Td72vI4h45DhwjdpYkY1KTNJwfSwyvZpoN9n85JjaqXsjLG/NhieDBKu0VJE1a44aWuFwmULmpDZcUcWtk074pPMMvuh/Go5gbTaIf1gsniBKNLrfTeGjIHE/Hu9o1G3GGpq6CDqOjb0ykukWZbD2qfV0gERwIR dstack" + + err = os.WriteFile(filePath, []byte(first+"\n"+second+"\n"+third), os.ModePerm) + require.NoError(t, err) + + err = ak.RemovePublicKeys([]string{first, third}) + require.NoError(t, err) + + b, err := os.ReadFile(filePath) + require.NoError(t, err) + require.NotContains(t, string(b), first) + require.NotContains(t, string(b), third) + require.Contains(t, string(b), second) +} + +func TestAppendTwoKey(t *testing.T) { + ak := AuthorizedKeys{user: "test_user", rootPath: t.TempDir()} + filePath := ak.GetAuthorizedKeysPath() + err := os.MkdirAll(path.Dir(filePath), os.ModePerm) + require.NoError(t, err) + + first := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + second := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCfAFHwfyMKPFbKq+D/vYNaXjqer4uV5+zvlrPY2bvkdRT4GiH4hm2s1Z7+fUEYQBNfw5O9SgxGotqyguUJbuVUc2BCNdD8HC3PxKtEev35ga4G3jjyuVeHcL2T9pn+F8IW1o3SpDGATAHJyFtArPYz31Hwg6PiuggPNdPLMSzZNrwNVuPwT1uDMKFqAh+1ryIVi7389fjZ7aBR9F06VIPpWIVVKqSVD+NbHtwWqCw8AsprJE3bPwVW09OJeQX8GXryKasaX4t4HMXmO/UI8tprnyf05dAl7NQOPY9Iut5PgfzEVY/T0M1RSnZi7i+1x7WBWX3aMM/Hv+NUeX2YtuAN" + third := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDIAGg0prDVeane6xLvMPBKQHxNUpt4q/hmuAAxjOUW0GWMPS2qE3l8YkmWeK80nKvio4M/IYWe67HIVeibdvKPoFJTtgm93WeJT9KD6h7MCschAf78mAIBhzUMK+9UYl5pE2jpfqc0SXkUsXDxMVN+ST9lN7fXUVsCPXO6qJG+0hLA3vs5r0aY1Td72vI4h45DhwjdpYkY1KTNJwfSwyvZpoN9n85JjaqXsjLG/NhieDBKu0VJE1a44aWuFwmULmpDZcUcWtk074pPMMvuh/Go5gbTaIf1gsniBKNLrfTeGjIHE/Hu9o1G3GGpq6CDqOjb0ykukWZbD2qfV0gERwIR dstack" + + err = os.WriteFile(filePath, []byte(first), os.ModePerm) + require.NoError(t, err) + + err = ak.AppendPublicKeys([]string{second, third}) + require.NoError(t, err) + + b, err := os.ReadFile(filePath) + require.NoError(t, err) + require.Contains(t, string(b), first) + require.Contains(t, string(b), second) + require.Contains(t, string(b), third) +} diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index 34eb90abc..d54ced2b6 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -74,9 +74,22 @@ func NewDockerRunner(dockerParams DockerParameters) (*DockerRunner, error) { return runner, nil } -func (d *DockerRunner) Run(ctx context.Context, cfg DockerImageConfig) error { +func (d *DockerRunner) Run(ctx context.Context, cfg TaskConfig) error { var err error + if cfg.SshKey != "" { + ak := AuthorizedKeys{user: cfg.SshUser} + if err := ak.AppendPublicKeys([]string{cfg.SshKey}); err != nil { + return tracerr.Wrap(err) + } + defer func(cfg TaskConfig) { + err := ak.RemovePublicKeys([]string{cfg.SshKey}) + if err != nil { + log.Printf("Error RemovePublicKeys: %s\n", err.Error()) + } + }(cfg) + } + d.containerStatus = ContainerStatus{ ContainerName: cfg.ContainerName, } @@ -186,30 +199,30 @@ func (d DockerRunner) GetState() (RunnerStatus, ContainerStatus, string, JobResu return d.state, d.containerStatus, d.executorError, d.jobResult } -func pullImage(ctx context.Context, client docker.APIClient, taskParams DockerImageConfig) error { - if !strings.Contains(taskParams.ImageName, ":") { - taskParams.ImageName += ":latest" +func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConfig) error { + if !strings.Contains(taskConfig.ImageName, ":") { + taskConfig.ImageName += ":latest" } images, err := client.ImageList(ctx, image.ListOptions{ - Filters: filters.NewArgs(filters.Arg("reference", taskParams.ImageName)), + Filters: filters.NewArgs(filters.Arg("reference", taskConfig.ImageName)), }) if err != nil { return tracerr.Wrap(err) } // TODO: force pull latset - if len(images) > 0 && !strings.Contains(taskParams.ImageName, ":latest") { + if len(images) > 0 && !strings.Contains(taskConfig.ImageName, ":latest") { return nil } opts := image.PullOptions{} - regAuth, _ := taskParams.EncodeRegistryAuth() + regAuth, _ := taskConfig.EncodeRegistryAuth() if regAuth != "" { opts.RegistryAuth = regAuth } startTime := time.Now() - reader, err := client.ImagePull(ctx, taskParams.ImageName, opts) + reader, err := client.ImagePull(ctx, taskConfig.ImageName, opts) if err != nil { return tracerr.Wrap(err) } @@ -275,16 +288,16 @@ func pullImage(ctx context.Context, client docker.APIClient, taskParams DockerIm return nil } -func createContainer(ctx context.Context, client docker.APIClient, runnerDir string, dockerParams DockerParameters, dockerImageConfig DockerImageConfig) (string, error) { +func createContainer(ctx context.Context, client docker.APIClient, runnerDir string, dockerParams DockerParameters, taskConfig TaskConfig) (string, error) { timeout := int(0) stopOptions := container.StopOptions{Timeout: &timeout} - err := client.ContainerStop(ctx, dockerImageConfig.ContainerName, stopOptions) + err := client.ContainerStop(ctx, taskConfig.ContainerName, stopOptions) if err != nil { log.Printf("Cleanup routine: Cannot stop container: %s", err) } removeOptions := container.RemoveOptions{Force: true} - err = client.ContainerRemove(ctx, dockerImageConfig.ContainerName, removeOptions) + err = client.ContainerRemove(ctx, taskConfig.ContainerName, removeOptions) if err != nil { log.Printf("Cleanup routine: Cannot remove container: %s", err) } @@ -299,8 +312,8 @@ func createContainer(ctx context.Context, client docker.APIClient, runnerDir str } containerConfig := &container.Config{ - Image: dockerImageConfig.ImageName, - Cmd: []string{strings.Join(dockerParams.DockerShellCommands(dockerImageConfig.PublicKeys), " && ")}, + Image: taskConfig.ImageName, + Cmd: []string{strings.Join(dockerParams.DockerShellCommands(taskConfig.PublicKeys), " && ")}, Entrypoint: []string{"/bin/sh", "-c"}, ExposedPorts: exposePorts(dockerParams.DockerPorts()...), } @@ -313,9 +326,9 @@ func createContainer(ctx context.Context, client docker.APIClient, runnerDir str DeviceRequests: gpuRequest, }, Mounts: mounts, - ShmSize: dockerImageConfig.ShmSize, + ShmSize: taskConfig.ShmSize, } - resp, err := client.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, dockerImageConfig.ContainerName) + resp, err := client.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, taskConfig.ContainerName) if err != nil { return "", tracerr.Wrap(err) } diff --git a/runner/internal/shim/docker_test.go b/runner/internal/shim/docker_test.go index 20c008225..baead505a 100644 --- a/runner/internal/shim/docker_test.go +++ b/runner/internal/shim/docker_test.go @@ -33,7 +33,7 @@ func TestDocker_SSHServer(t *testing.T) { defer cancel() dockerRunner, _ := NewDockerRunner(params) - assert.NoError(t, dockerRunner.Run(ctx, DockerImageConfig{ImageName: "ubuntu"})) + assert.NoError(t, dockerRunner.Run(ctx, TaskConfig{ImageName: "ubuntu"})) } // TestDocker_SSHServerConnect pulls ubuntu image (without sshd), installs openssh-server and tries to connect via SSH @@ -64,7 +64,7 @@ func TestDocker_SSHServerConnect(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - assert.NoError(t, dockerRunner.Run(ctx, DockerImageConfig{ImageName: "ubuntu"})) + assert.NoError(t, dockerRunner.Run(ctx, TaskConfig{ImageName: "ubuntu"})) }() for i := 0; i < timeout; i++ { diff --git a/runner/internal/shim/models.go b/runner/internal/shim/models.go index 353d57f90..9de45d863 100644 --- a/runner/internal/shim/models.go +++ b/runner/internal/shim/models.go @@ -41,16 +41,18 @@ type CLIArgs struct { } } -type DockerImageConfig struct { +type TaskConfig struct { Username string Password string ImageName string ContainerName string ShmSize int64 PublicKeys []string + SshUser string + SshKey string } -func (ra DockerImageConfig) EncodeRegistryAuth() (string, error) { +func (ra TaskConfig) EncodeRegistryAuth() (string, error) { if ra.Username == "" && ra.Password == "" { return "", nil } diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py index 225439105..265c32b2e 100644 --- a/src/dstack/_internal/core/backends/aws/compute.py +++ b/src/dstack/_internal/core/backends/aws/compute.py @@ -187,7 +187,6 @@ def run_job( project_name=run.project_name, instance_name=get_instance_name(run, job), # TODO: generate name ssh_keys=[ - SSHKey(public=run.run_spec.ssh_key_pub.strip()), SSHKey(public=project_ssh_public_key.strip()), ], job_docker_config=None, diff --git a/src/dstack/_internal/core/backends/azure/compute.py b/src/dstack/_internal/core/backends/azure/compute.py index 1dd70624f..5b7f4b1b3 100644 --- a/src/dstack/_internal/core/backends/azure/compute.py +++ b/src/dstack/_internal/core/backends/azure/compute.py @@ -165,7 +165,6 @@ def run_job( project_name=run.project_name, instance_name=get_instance_name(run, job), # TODO: generate name ssh_keys=[ - SSHKey(public=run.run_spec.ssh_key_pub.strip()), SSHKey(public=project_ssh_public_key.strip()), ], job_docker_config=None, diff --git a/src/dstack/_internal/core/backends/cudo/compute.py b/src/dstack/_internal/core/backends/cudo/compute.py index 1b6df307b..de850306f 100644 --- a/src/dstack/_internal/core/backends/cudo/compute.py +++ b/src/dstack/_internal/core/backends/cudo/compute.py @@ -57,7 +57,6 @@ def run_job( project_name=run.project_name, instance_name=get_instance_name(run, job), ssh_keys=[ - SSHKey(public=run.run_spec.ssh_key_pub.strip()), SSHKey(public=project_ssh_public_key.strip()), ], job_docker_config=None, diff --git a/src/dstack/_internal/core/backends/datacrunch/compute.py b/src/dstack/_internal/core/backends/datacrunch/compute.py index 637c4dc62..1dbb7490b 100644 --- a/src/dstack/_internal/core/backends/datacrunch/compute.py +++ b/src/dstack/_internal/core/backends/datacrunch/compute.py @@ -147,7 +147,6 @@ def run_job( project_name=run.project_name, instance_name=job.job_spec.job_name, # TODO: generate name ssh_keys=[ - SSHKey(public=run.run_spec.ssh_key_pub.strip()), SSHKey(public=project_ssh_public_key.strip()), ], job_docker_config=None, diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index d47bc790b..dfcb798e8 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -167,7 +167,6 @@ def run_job( project_name=run.project_name, instance_name=get_instance_name(run, job), # TODO: generate name ssh_keys=[ - SSHKey(public=run.run_spec.ssh_key_pub.strip()), SSHKey(public=project_ssh_public_key.strip()), ], job_docker_config=None, diff --git a/src/dstack/_internal/core/backends/nebius/compute.py b/src/dstack/_internal/core/backends/nebius/compute.py index 054f6c643..614ec644f 100644 --- a/src/dstack/_internal/core/backends/nebius/compute.py +++ b/src/dstack/_internal/core/backends/nebius/compute.py @@ -125,7 +125,6 @@ def run_job( project_name=run.project_name, instance_name=get_instance_name(run, job), # TODO: generate name ssh_keys=[ - SSHKey(public=run.run_spec.ssh_key_pub.strip()), SSHKey(public=project_ssh_public_key.strip()), ], job_docker_config=None, diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index 360a03e96..c724f9a16 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -152,10 +152,9 @@ async def _process_job(job_id: UUID): fmt(job_model), job_submission.age, ) - public_keys = [ - project.ssh_public_key.strip(), - run.run_spec.ssh_key_pub.strip(), - ] + ssh_user = job_provisioning_data.username + user_ssh_key = run.run_spec.ssh_key_pub.strip() + public_keys = [project.ssh_public_key.strip(), user_ssh_key] success = await run_async( _process_provisioning_with_shim, server_ssh_private_key, @@ -164,6 +163,8 @@ async def _process_job(job_id: UUID): secrets, job.job_spec.registry_auth, public_keys, + ssh_user, + user_ssh_key, ) else: logger.debug( @@ -345,6 +346,8 @@ def _process_provisioning_with_shim( secrets: Dict[str, str], registry_auth: Optional[RegistryAuth], public_keys: List[str], + ssh_user: str, + ssh_key: str, *, ports: Dict[int, int], ) -> bool: @@ -380,6 +383,8 @@ def _process_provisioning_with_shim( container_name=job_model.job_name, shm_size=job_spec.requirements.resources.shm_size, public_keys=public_keys, + ssh_user=ssh_user, + ssh_key=ssh_key, ) job_model.status = JobStatus.PULLING diff --git a/src/dstack/_internal/server/schemas/runner.py b/src/dstack/_internal/server/schemas/runner.py index 9ade69c09..1e282aa8d 100644 --- a/src/dstack/_internal/server/schemas/runner.py +++ b/src/dstack/_internal/server/schemas/runner.py @@ -71,13 +71,15 @@ class HealthcheckResponse(CoreModel): version: str -class DockerImageBody(CoreModel): +class TaskConfigBody(CoreModel): username: str password: str image_name: str container_name: str shm_size: int public_keys: List[str] + ssh_user: str + ssh_key: str class StopBody(CoreModel): diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index 2ccdb96d6..9ed12ed9f 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -8,12 +8,12 @@ from dstack._internal.core.models.resources import Memory from dstack._internal.core.models.runs import ClusterInfo, JobSpec, RunSpec from dstack._internal.server.schemas.runner import ( - DockerImageBody, HealthcheckResponse, PullBody, PullResponse, StopBody, SubmitBody, + TaskConfigBody, ) REMOTE_SHIM_PORT = 10998 @@ -126,15 +126,19 @@ def submit( container_name: str, shm_size: Optional[Memory], public_keys: List[str], + ssh_user: str, + ssh_key: str, ): _shm_size = int(shm_size * 1024 * 1024 * 1014) if shm_size else 0 - post_body = DockerImageBody( + post_body = TaskConfigBody( username=username, password=password, image_name=image_name, container_name=container_name, shm_size=_shm_size, public_keys=public_keys, + ssh_user=ssh_user, + ssh_key=ssh_key, ).dict() resp = requests.post( self._url("/api/submit"), diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index a4ad52ac1..99ce0b0b2 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -257,7 +257,6 @@ def attach( if control_sock_path_and_port_locks is None: if self._ports_lock is None: self._ports_lock = _reserve_ports(job.job_spec) - logger.debug( "Attaching to %s (%s: %s)", self.name, @@ -266,7 +265,6 @@ def attach( ) else: self._ports_lock = control_sock_path_and_port_locks[1] - logger.debug( "Reusing the existing tunnel to %s (%s: %s)", self.name, diff --git a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py index d1ed4c034..cc832587f 100644 --- a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py @@ -234,6 +234,8 @@ async def test_provisioning_shim(self, test_db, session: AsyncSession): container_name="test-run-0-0", shm_size=None, public_keys=[project_ssh_pub_key, "user_ssh_key"], + ssh_user="ubuntu", + ssh_key="user_ssh_key", ) await session.refresh(job) assert job is not None