From 789c4c8e0351f8cb468d6846b1bc91baf1beb98d Mon Sep 17 00:00:00 2001 From: Sergey Mezentsev Date: Fri, 10 May 2024 09:08:02 +0300 Subject: [PATCH] The instance cannot be reused by another user (#1204) * Upload the ssh public key of the user in addition to the job configuration to the shim * Install only the project key on the instance * Remove golangci-lint warnings from Github Actions * Allow to use Run.attach() --- .github/workflows/build.yml | 28 +-- .github/workflows/release.yml | 12 +- .pre-commit-config.yaml | 4 +- runner/cmd/runner/main.go | 2 +- runner/cmd/shim/main.go | 18 +- runner/internal/runner/api/http_test.go | 2 +- runner/internal/shim/api/http.go | 8 +- runner/internal/shim/api/schemas.go | 22 ++- runner/internal/shim/api/server.go | 2 +- runner/internal/shim/authorized_keys.go | 140 ++++++++++++++ runner/internal/shim/authorized_keys_test.go | 180 ++++++++++++++++++ runner/internal/shim/docker.go | 55 ++++-- runner/internal/shim/docker_test.go | 13 +- runner/internal/shim/models.go | 15 +- runner/internal/shim/runner.go | 2 +- src/dstack/_internal/cli/commands/pool.py | 25 +-- .../_internal/core/backends/aws/compute.py | 1 - .../_internal/core/backends/azure/compute.py | 1 - .../_internal/core/backends/cudo/compute.py | 1 - .../core/backends/datacrunch/compute.py | 1 - .../_internal/core/backends/gcp/compute.py | 1 - .../_internal/core/backends/nebius/compute.py | 1 - .../background/tasks/process_running_jobs.py | 41 ++-- src/dstack/_internal/server/routers/runs.py | 1 - src/dstack/_internal/server/schemas/runner.py | 5 +- src/dstack/_internal/server/schemas/runs.py | 1 - .../server/services/runner/client.py | 12 +- src/dstack/_internal/server/services/runs.py | 4 +- src/dstack/_internal/server/testing/common.py | 2 +- src/dstack/api/_public/runs.py | 9 +- src/dstack/api/server/_runs.py | 4 +- .../tasks/test_process_running_jobs.py | 6 +- .../_internal/server/routers/test_runs.py | 4 - 33 files changed, 479 insertions(+), 144 deletions(-) create mode 100644 runner/internal/shim/authorized_keys.go create mode 100644 runner/internal/shim/authorized_keys_test.go diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0a8d662c7..70b520311 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -66,16 +66,16 @@ jobs: working-directory: runner runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: - go-version: 1.21.1 - - name: golangci-lint - uses: golangci/golangci-lint-action@v3 + go-version: "1.22" + - name: Run golangci-lint + uses: golangci/golangci-lint-action@v6 with: - version: v1.51.2 - args: --issues-exit-code=0 --timeout=20m + version: v1.58 + args: --timeout=20m working-directory: runner - name: Test run: | @@ -85,7 +85,7 @@ jobs: go test -race $(go list ./... | grep -v /vendor/) runner-compile: - needs: [ runner-test ] + needs: [runner-test] defaults: run: working-directory: runner @@ -94,14 +94,14 @@ jobs: strategy: matrix: include: - - {goos: "linux", goarch: "amd64", runson: "ubuntu-latest"} + - { goos: "linux", goarch: "amd64", runson: "ubuntu-latest" } runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: - go-version: 1.21.1 + go-version: "1.22" - name: build env: GOOS: ${{ matrix.goos }} @@ -122,11 +122,11 @@ jobs: retention-days: 1 runner-upload: - needs: [ runner-compile ] + needs: [runner-compile] runs-on: ubuntu-latest steps: - name: Install AWS - run: pip install awscli + run: pip install awscli - name: Download Runner uses: actions/download-artifact@v3 with: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8f44baf5b..be3c666b9 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -68,7 +68,7 @@ jobs: go test -race $(go list ./... | grep -v /vendor/) runner-compile: - needs: [ runner-test ] + needs: [runner-test] defaults: run: working-directory: runner @@ -77,14 +77,14 @@ jobs: strategy: matrix: include: - - {goos: "linux", goarch: "amd64", runson: "ubuntu-latest"} + - { goos: "linux", goarch: "amd64", runson: "ubuntu-latest" } runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Go - uses: actions/setup-go@v3 + uses: actions/setup-go@v5 with: - go-version: 1.21.1 + go-version: "1.22" - name: build env: GOOS: ${{ matrix.goos }} @@ -107,7 +107,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Install AWS - run: pip install awscli + run: pip install awscli - name: Download Runner uses: actions/download-artifact@v3 with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7a647e272..405145522 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,13 +1,13 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.1 + rev: v0.4.4 hooks: - id: ruff name: ruff common args: ['--fix'] - id: ruff-format - repo: https://github.com/golangci/golangci-lint - rev: v1.56.2 + rev: v1.58.1 hooks: - id: golangci-lint-full entry: bash -c 'cd runner && golangci-lint run -D depguard --presets import,module,unused "$@"' diff --git a/runner/cmd/runner/main.go b/runner/cmd/runner/main.go index d3e06b1f5..443eefa92 100644 --- a/runner/cmd/runner/main.go +++ b/runner/cmd/runner/main.go @@ -20,7 +20,7 @@ func main() { } func start(tempDir string, homeDir string, workingDir string, httpPort int, logLevel int, version string) error { - if err := os.MkdirAll(tempDir, 0755); err != nil { + if err := os.MkdirAll(tempDir, 0o755); err != nil { return tracerr.Errorf("Failed to create temp directory: %w", err) } diff --git a/runner/cmd/shim/main.go b/runner/cmd/shim/main.go index 1160c4f0b..8ae390b61 100644 --- a/runner/cmd/shim/main.go +++ b/runner/cmd/shim/main.go @@ -101,7 +101,7 @@ func main() { Name: "ssh-key", Usage: "Public SSH key", Required: true, - Destination: &args.Docker.PublicSSHKey, + Destination: &args.Docker.ConcatinatedPublicSSHKeys, EnvVars: []string{"DSTACK_PUBLIC_SSH_KEY"}, }, &cli.BoolFlag{ @@ -112,7 +112,6 @@ func main() { }, }, Action: func(c *cli.Context) error { - if args.Runner.BinaryPath == "" { if err := args.DownloadRunner(); err != nil { return cli.Exit(err, 1) @@ -230,7 +229,10 @@ func writeHostInfo() { panic(err) } - f.Sync() + err = f.Sync() + if err != nil { + panic(err) + } } func getGpuInfo() [][]string { @@ -272,7 +274,7 @@ func getGpuInfo() [][]string { if err != nil { log.Fatal(err) } - fmt.Printf("gpu record %v\n", record) + gpus = append(gpus, record) } return gpus @@ -284,6 +286,7 @@ func getInterfaces() []string { if err != nil { panic("cannot get interfaces") } + for _, i := range ifaces { addrs, err := i.Addrs() if err != nil { @@ -293,10 +296,10 @@ func getInterfaces() []string { for _, addr := range addrs { switch v := addr.(type) { case *net.IPNet: - fmt.Println(v.IP) if v.IP.IsLoopback() { continue } + addresses = append(addresses, addr.String()) } } @@ -307,10 +310,13 @@ func getInterfaces() []string { func getDiskSize() uint64 { var stat unix.Statfs_t wd, err := os.Getwd() + if err != nil { + panic("cannot get current disk") + } + err = unix.Statfs(wd, &stat) if err != nil { panic("cannot get disk size") } - unix.Statfs(wd, &stat) size := stat.Bavail * uint64(stat.Bsize) return size } diff --git a/runner/internal/runner/api/http_test.go b/runner/internal/runner/api/http_test.go index ddc587592..98a6f5180 100644 --- a/runner/internal/runner/api/http_test.go +++ b/runner/internal/runner/api/http_test.go @@ -21,7 +21,7 @@ func (ds DummyRunner) GetState() (shim.RunnerStatus, shim.ContainerStatus, strin return ds.State, ds.ContainerStatus, "", ds.JobResult } -func (ds DummyRunner) Run(context.Context, shim.DockerImageConfig) error { +func (ds DummyRunner) Run(context.Context, shim.TaskConfig) error { return nil } diff --git a/runner/internal/shim/api/http.go b/runner/internal/shim/api/http.go index 449c74c27..a5f8db052 100644 --- a/runner/internal/shim/api/http.go +++ b/runner/internal/shim/api/http.go @@ -28,18 +28,18 @@ func (s *ShimServer) SubmitPostHandler(w http.ResponseWriter, r *http.Request) ( return nil, &api.Error{Status: http.StatusConflict} } - var body DockerTaskBody + var body TaskConfigBody if err := api.DecodeJSONBody(w, r, &body, true); err != nil { log.Println("Failed to decode submit body", "err", err) return nil, err } - go func(taskParams shim.DockerImageConfig) { - err := s.runner.Run(context.Background(), taskParams) + go func(taskConfig shim.TaskConfig) { + err := s.runner.Run(context.Background(), taskConfig) if err != nil { fmt.Printf("failed Run %v\n", err) } - }(body.TaskParams()) + }(body.GetTaskConfig()) return nil, nil } diff --git a/runner/internal/shim/api/schemas.go b/runner/internal/shim/api/schemas.go index 2a60c45a6..0eb7a2961 100644 --- a/runner/internal/shim/api/schemas.go +++ b/runner/internal/shim/api/schemas.go @@ -2,12 +2,15 @@ package api import "github.com/dstackai/dstack/runner/internal/shim" -type DockerTaskBody struct { - Username string `json:"username"` - Password string `json:"password"` - ImageName string `json:"image_name"` - ContainerName string `json:"container_name"` - ShmSize int64 `json:"shm_size"` +type TaskConfigBody struct { + Username string `json:"username"` + Password string `json:"password"` + ImageName string `json:"image_name"` + ContainerName string `json:"container_name"` + ShmSize int64 `json:"shm_size"` + PublicKeys []string `json:"public_keys"` + SshUser string `json:"ssh_user"` + SshKey string `json:"ssh_key"` } type StopBody struct { @@ -36,13 +39,16 @@ type StopResponse struct { State string `json:"state"` } -func (ra DockerTaskBody) TaskParams() shim.DockerImageConfig { - res := shim.DockerImageConfig{ +func (ra TaskConfigBody) GetTaskConfig() shim.TaskConfig { + res := shim.TaskConfig{ ImageName: ra.ImageName, Username: ra.Username, Password: ra.Password, ContainerName: ra.ContainerName, ShmSize: ra.ShmSize, + PublicKeys: ra.PublicKeys, + SshUser: ra.SshUser, + SshKey: ra.SshKey, } return res } diff --git a/runner/internal/shim/api/server.go b/runner/internal/shim/api/server.go index a6da3c276..03ed8a7ac 100644 --- a/runner/internal/shim/api/server.go +++ b/runner/internal/shim/api/server.go @@ -10,7 +10,7 @@ import ( ) type TaskRunner interface { - Run(context.Context, shim.DockerImageConfig) error + Run(context.Context, shim.TaskConfig) error GetState() (shim.RunnerStatus, shim.ContainerStatus, string, shim.JobResult) Stop(bool) } diff --git a/runner/internal/shim/authorized_keys.go b/runner/internal/shim/authorized_keys.go new file mode 100644 index 000000000..2ca614797 --- /dev/null +++ b/runner/internal/shim/authorized_keys.go @@ -0,0 +1,140 @@ +package shim + +import ( + "bufio" + "fmt" + "io" + "os" + "slices" + + "github.com/ztrue/tracerr" + "golang.org/x/crypto/ssh" +) + +func PublicKeyFingerprint(key string) (string, error) { + pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(key)) + if err != nil { + return "", tracerr.Wrap(err) + } + keyFingerprint := ssh.FingerprintSHA256(pk) + return keyFingerprint, nil +} + +func IsPublicKeysEqual(left string, right string) bool { + leftFingerprint, err := PublicKeyFingerprint(left) + if err != nil { + return false + } + + rightFingerprint, err := PublicKeyFingerprint(right) + if err != nil { + return false + } + + return leftFingerprint == rightFingerprint +} + +func RemovePublicKeys(fileKeys []string, keysToRemove []string) []string { + newKeys := slices.DeleteFunc(fileKeys, func(fileKey string) bool { + delete := slices.ContainsFunc(keysToRemove, func(removeKey string) bool { + return IsPublicKeysEqual(fileKey, removeKey) + }) + return delete + }) + return newKeys +} + +func AppendPublicKeys(fileKeys []string, keysToAppend []string) []string { + newKeys := []string{} + newKeys = append(newKeys, fileKeys...) + newKeys = append(newKeys, keysToAppend...) + return newKeys +} + +type AuthorizedKeys struct { + user string + rootPath string +} + +func (ak AuthorizedKeys) AppendPublicKeys(publicKeys []string) error { + return ak.transformAuthorizedKeys(AppendPublicKeys, publicKeys) +} + +func (ak AuthorizedKeys) RemovePublicKeys(publicKeys []string) error { + return ak.transformAuthorizedKeys(RemovePublicKeys, publicKeys) +} + +func (ak AuthorizedKeys) read(r io.Reader) ([]string, error) { + lines := []string{} + scanner := bufio.NewScanner(r) + for scanner.Scan() { + text := scanner.Text() + lines = append(lines, text) + } + if err := scanner.Err(); err != nil { + return []string{}, tracerr.Wrap(err) + } + return lines, nil +} + +func (ak AuthorizedKeys) write(w io.Writer, lines []string) error { + wr := bufio.NewWriter(w) + for _, line := range lines { + _, err := fmt.Fprintln(wr, line) + if err != nil { + return tracerr.Wrap(err) + } + } + return wr.Flush() +} + +func (ak AuthorizedKeys) GetAuthorizedKeysPath() string { + return fmt.Sprintf("%s/home/%s/.ssh/authorized_keys", ak.rootPath, ak.user) +} + +func (ak AuthorizedKeys) transformAuthorizedKeys(transform func([]string, []string) []string, publicKeys []string) error { + authorizedKeysPath := ak.GetAuthorizedKeysPath() + info, err := os.Stat(authorizedKeysPath) + if err != nil { + return tracerr.Wrap(err) + } + fileMode := info.Mode().Perm() + + authorizedKeysFile, err := os.OpenFile(authorizedKeysPath, os.O_RDWR, fileMode) + if err != nil { + return tracerr.Wrap(err) + } + defer authorizedKeysFile.Close() + + lines, err := ak.read(authorizedKeysFile) + if err != nil { + return tracerr.Wrap(err) + } + + // write backup + authorizedKeysPathBackup := ak.GetAuthorizedKeysPath() + ".bak" + authorizedKeysBackup, err := os.OpenFile(authorizedKeysPathBackup, os.O_RDWR|os.O_CREATE|os.O_TRUNC, fileMode) + if err != nil { + return tracerr.Wrap(err) + } + defer authorizedKeysBackup.Close() + if err := ak.write(authorizedKeysBackup, lines); err != nil { + return tracerr.Wrap(err) + } + + // transform lines + newLines := transform(lines, publicKeys) + + // write authorized_keys + if err := authorizedKeysFile.Truncate(0); err != nil { + return tracerr.Wrap(err) + } + if _, err := authorizedKeysFile.Seek(0, 0); err != nil { + return tracerr.Wrap(err) + } + if err := ak.write(authorizedKeysFile, newLines); err != nil { + return tracerr.Wrap(err) + } + + return nil +} diff --git a/runner/internal/shim/authorized_keys_test.go b/runner/internal/shim/authorized_keys_test.go new file mode 100644 index 000000000..b0011b9be --- /dev/null +++ b/runner/internal/shim/authorized_keys_test.go @@ -0,0 +1,180 @@ +package shim + +import ( + "os" + "path" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPublicKeyFingerprint(t *testing.T) { + key := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + expectedFingerprint := "SHA256:9HymzYAtJKNh8gKufl3EVoRSauL4E7Mbmuzqlcvii50" + fingerprint, err := PublicKeyFingerprint(key) + require.NoError(t, err) + require.Equal(t, expectedFingerprint, fingerprint) +} + +func TestPublicKeyFingerprintError(t *testing.T) { + key := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQ= thebits@barracuda" + fingerprint, err := PublicKeyFingerprint(key) + require.Error(t, err) + require.Empty(t, fingerprint) +} + +func TestIsPublicKeysEqual(t *testing.T) { + keyLeft := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + keyRight := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + + result := IsPublicKeysEqual(keyLeft, keyRight) + require.True(t, result) +} + +func TestIsPublicKeysEqualBrokenKey(t *testing.T) { + keyLeft := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + keyRight := "ssh-rsa AAAAP66um5MadfhB5dSnEM= thebits@barracuda" + + resultFwd := IsPublicKeysEqual(keyLeft, keyRight) + require.False(t, resultFwd) + + resultBck := IsPublicKeysEqual(keyRight, keyLeft) + require.False(t, resultBck) +} + +func TestIsPublicKeysNotEqual(t *testing.T) { + keyLeft := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + keyRight := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCfAFHwfyMKPFbKq+D/vYNaXjqer4uV5+zvlrPY2bvkdRT4GiH4hm2s1Z7+fUEYQBNfw5O9SgxGotqyguUJbuVUc2BCNdD8HC3PxKtEev35ga4G3jjyuVeHcL2T9pn+F8IW1o3SpDGATAHJyFtArPYz31Hwg6PiuggPNdPLMSzZNrwNVuPwT1uDMKFqAh+1ryIVi7389fjZ7aBR9F06VIPpWIVVKqSVD+NbHtwWqCw8AsprJE3bPwVW09OJeQX8GXryKasaX4t4HMXmO/UI8tprnyf05dAl7NQOPY9Iut5PgfzEVY/T0M1RSnZi7i+1x7WBWX3aMM/Hv+NUeX2YtuAN" + + result := IsPublicKeysEqual(keyLeft, keyRight) + require.False(t, result) +} + +func TestRemovePublicKeys(t *testing.T) { + keyLeft := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + keyRight := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCfAFHwfyMKPFbKq+D/vYNaXjqer4uV5+zvlrPY2bvkdRT4GiH4hm2s1Z7+fUEYQBNfw5O9SgxGotqyguUJbuVUc2BCNdD8HC3PxKtEev35ga4G3jjyuVeHcL2T9pn+F8IW1o3SpDGATAHJyFtArPYz31Hwg6PiuggPNdPLMSzZNrwNVuPwT1uDMKFqAh+1ryIVi7389fjZ7aBR9F06VIPpWIVVKqSVD+NbHtwWqCw8AsprJE3bPwVW09OJeQX8GXryKasaX4t4HMXmO/UI8tprnyf05dAl7NQOPY9Iut5PgfzEVY/T0M1RSnZi7i+1x7WBWX3aMM/Hv+NUeX2YtuAN" + + keys := []string{keyLeft, keyRight} + newKeys := RemovePublicKeys(keys, []string{keyRight}) + + require.Len(t, newKeys, 1) + require.Equal(t, newKeys, []string{keyLeft}) +} + +func TestRemovePublicKeysRemoveAll(t *testing.T) { + keyLeft := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + keyRight := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCfAFHwfyMKPFbKq+D/vYNaXjqer4uV5+zvlrPY2bvkdRT4GiH4hm2s1Z7+fUEYQBNfw5O9SgxGotqyguUJbuVUc2BCNdD8HC3PxKtEev35ga4G3jjyuVeHcL2T9pn+F8IW1o3SpDGATAHJyFtArPYz31Hwg6PiuggPNdPLMSzZNrwNVuPwT1uDMKFqAh+1ryIVi7389fjZ7aBR9F06VIPpWIVVKqSVD+NbHtwWqCw8AsprJE3bPwVW09OJeQX8GXryKasaX4t4HMXmO/UI8tprnyf05dAl7NQOPY9Iut5PgfzEVY/T0M1RSnZi7i+1x7WBWX3aMM/Hv+NUeX2YtuAN" + + keys := []string{keyLeft, keyRight} + newKeys := RemovePublicKeys(keys, []string{keyRight, keyLeft}) + + require.Empty(t, newKeys) +} + +func TestRemovePublicKeysRemoveNotContained(t *testing.T) { + keyLeft := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + keyRight := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCfAFHwfyMKPFbKq+D/vYNaXjqer4uV5+zvlrPY2bvkdRT4GiH4hm2s1Z7+fUEYQBNfw5O9SgxGotqyguUJbuVUc2BCNdD8HC3PxKtEev35ga4G3jjyuVeHcL2T9pn+F8IW1o3SpDGATAHJyFtArPYz31Hwg6PiuggPNdPLMSzZNrwNVuPwT1uDMKFqAh+1ryIVi7389fjZ7aBR9F06VIPpWIVVKqSVD+NbHtwWqCw8AsprJE3bPwVW09OJeQX8GXryKasaX4t4HMXmO/UI8tprnyf05dAl7NQOPY9Iut5PgfzEVY/T0M1RSnZi7i+1x7WBWX3aMM/Hv+NUeX2YtuAN" + + keys := []string{keyLeft, keyRight} + newKeys := RemovePublicKeys(keys, []string{"# line with comment"}) + + require.Equal(t, keys, newKeys) +} + +func TestAppendPublicKeys(t *testing.T) { + keyLeft := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + keyRight := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCfAFHwfyMKPFbKq+D/vYNaXjqer4uV5+zvlrPY2bvkdRT4GiH4hm2s1Z7+fUEYQBNfw5O9SgxGotqyguUJbuVUc2BCNdD8HC3PxKtEev35ga4G3jjyuVeHcL2T9pn+F8IW1o3SpDGATAHJyFtArPYz31Hwg6PiuggPNdPLMSzZNrwNVuPwT1uDMKFqAh+1ryIVi7389fjZ7aBR9F06VIPpWIVVKqSVD+NbHtwWqCw8AsprJE3bPwVW09OJeQX8GXryKasaX4t4HMXmO/UI8tprnyf05dAl7NQOPY9Iut5PgfzEVY/T0M1RSnZi7i+1x7WBWX3aMM/Hv+NUeX2YtuAN" + comment := "# line with coment" + + keys := []string{keyLeft, keyRight} + newKeys := AppendPublicKeys(keys, []string{comment}) + + require.Equal(t, []string{keyLeft, keyRight, comment}, newKeys) +} + +func TestAppendKey(t *testing.T) { + ak := AuthorizedKeys{user: "test_user", rootPath: t.TempDir()} + filePath := ak.GetAuthorizedKeysPath() + err := os.MkdirAll(path.Dir(filePath), os.ModePerm) + require.NoError(t, err) + + key := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + err = os.WriteFile(filePath, []byte(key), os.ModePerm) + require.NoError(t, err) + + commentLine := "# comment line" + err = ak.AppendPublicKeys([]string{commentLine}) + require.NoError(t, err) + + b, err := os.ReadFile(filePath) + require.NoError(t, err) + require.Contains(t, string(b), commentLine) +} + +func TestRemoveKey(t *testing.T) { + ak := AuthorizedKeys{user: "test_user", rootPath: t.TempDir()} + filePath := ak.GetAuthorizedKeysPath() + err := os.MkdirAll(path.Dir(filePath), os.ModePerm) + require.NoError(t, err) + + key := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + err = os.WriteFile(filePath, []byte(key), os.ModePerm) + require.NoError(t, err) + + err = ak.RemovePublicKeys([]string{key}) + require.NoError(t, err) + + b, err := os.ReadFile(filePath) + require.NoError(t, err) + require.Empty(t, string(b)) + + back, err := os.ReadFile(filePath + ".bak") + require.NoError(t, err) + require.Contains(t, string(back), key) +} + +func TestRemoveTwoKey(t *testing.T) { + ak := AuthorizedKeys{user: "test_user", rootPath: t.TempDir()} + filePath := ak.GetAuthorizedKeysPath() + err := os.MkdirAll(path.Dir(filePath), os.ModePerm) + require.NoError(t, err) + + first := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + second := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCfAFHwfyMKPFbKq+D/vYNaXjqer4uV5+zvlrPY2bvkdRT4GiH4hm2s1Z7+fUEYQBNfw5O9SgxGotqyguUJbuVUc2BCNdD8HC3PxKtEev35ga4G3jjyuVeHcL2T9pn+F8IW1o3SpDGATAHJyFtArPYz31Hwg6PiuggPNdPLMSzZNrwNVuPwT1uDMKFqAh+1ryIVi7389fjZ7aBR9F06VIPpWIVVKqSVD+NbHtwWqCw8AsprJE3bPwVW09OJeQX8GXryKasaX4t4HMXmO/UI8tprnyf05dAl7NQOPY9Iut5PgfzEVY/T0M1RSnZi7i+1x7WBWX3aMM/Hv+NUeX2YtuAN" + third := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDIAGg0prDVeane6xLvMPBKQHxNUpt4q/hmuAAxjOUW0GWMPS2qE3l8YkmWeK80nKvio4M/IYWe67HIVeibdvKPoFJTtgm93WeJT9KD6h7MCschAf78mAIBhzUMK+9UYl5pE2jpfqc0SXkUsXDxMVN+ST9lN7fXUVsCPXO6qJG+0hLA3vs5r0aY1Td72vI4h45DhwjdpYkY1KTNJwfSwyvZpoN9n85JjaqXsjLG/NhieDBKu0VJE1a44aWuFwmULmpDZcUcWtk074pPMMvuh/Go5gbTaIf1gsniBKNLrfTeGjIHE/Hu9o1G3GGpq6CDqOjb0ykukWZbD2qfV0gERwIR dstack" + + err = os.WriteFile(filePath, []byte(first+"\n"+second+"\n"+third), os.ModePerm) + require.NoError(t, err) + + err = ak.RemovePublicKeys([]string{first, third}) + require.NoError(t, err) + + b, err := os.ReadFile(filePath) + require.NoError(t, err) + require.NotContains(t, string(b), first) + require.NotContains(t, string(b), third) + require.Contains(t, string(b), second) +} + +func TestAppendTwoKey(t *testing.T) { + ak := AuthorizedKeys{user: "test_user", rootPath: t.TempDir()} + filePath := ak.GetAuthorizedKeysPath() + err := os.MkdirAll(path.Dir(filePath), os.ModePerm) + require.NoError(t, err) + + first := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCdqa9VimGtCppxtz6T0kXfA6csnRlGS0zmTNvH2XCIYYbNFcymjL1SpFXfYQvXrnoK7nR+4dHP66um5Mi4OWHC1pB4t2OPYNnEYuYJ/VFpPv0/ykGAijV+IZjh6wS5r1o/EfiG8kMlv2TGhDb/jjsJXl9zb3i0urTrG0Sk6iw7F7QL/pXUe1cKuhdxOUzw/ddNZ5fBCikAr2cYfI0kiqe4U/pRSV5mPNAuQvBFK+K7UDdKfKIf4YxTFjXFbcgD7XUC5nInhIdSvGFYLdHSuafwWz8Q5ds/EyAPCyMU2wsA+AIP5XpdIraJLDTQT1J4PjcYwecNibWU2rkobl9FDVcflZq+0s0HbmJRlB4uExTNRZP7ykMKp9MtJsQGB6uA41KYNsvV5a+7SX39syNDHGTB13gHQHmYEHgSmHIcyEE2tEh7Zb6OAFCsytUKzBl51FIS3V70ve9kqJUcldBEkGJh6PeFOvYQZ95Gl2Uob0ujKCVDrzMylepnadfhB5dSnEM= thebits@barracuda" + second := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCfAFHwfyMKPFbKq+D/vYNaXjqer4uV5+zvlrPY2bvkdRT4GiH4hm2s1Z7+fUEYQBNfw5O9SgxGotqyguUJbuVUc2BCNdD8HC3PxKtEev35ga4G3jjyuVeHcL2T9pn+F8IW1o3SpDGATAHJyFtArPYz31Hwg6PiuggPNdPLMSzZNrwNVuPwT1uDMKFqAh+1ryIVi7389fjZ7aBR9F06VIPpWIVVKqSVD+NbHtwWqCw8AsprJE3bPwVW09OJeQX8GXryKasaX4t4HMXmO/UI8tprnyf05dAl7NQOPY9Iut5PgfzEVY/T0M1RSnZi7i+1x7WBWX3aMM/Hv+NUeX2YtuAN" + third := "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDIAGg0prDVeane6xLvMPBKQHxNUpt4q/hmuAAxjOUW0GWMPS2qE3l8YkmWeK80nKvio4M/IYWe67HIVeibdvKPoFJTtgm93WeJT9KD6h7MCschAf78mAIBhzUMK+9UYl5pE2jpfqc0SXkUsXDxMVN+ST9lN7fXUVsCPXO6qJG+0hLA3vs5r0aY1Td72vI4h45DhwjdpYkY1KTNJwfSwyvZpoN9n85JjaqXsjLG/NhieDBKu0VJE1a44aWuFwmULmpDZcUcWtk074pPMMvuh/Go5gbTaIf1gsniBKNLrfTeGjIHE/Hu9o1G3GGpq6CDqOjb0ykukWZbD2qfV0gERwIR dstack" + + err = os.WriteFile(filePath, []byte(first), os.ModePerm) + require.NoError(t, err) + + err = ak.AppendPublicKeys([]string{second, third}) + require.NoError(t, err) + + b, err := os.ReadFile(filePath) + require.NoError(t, err) + require.Contains(t, string(b), first) + require.Contains(t, string(b), second) + require.Contains(t, string(b), third) +} diff --git a/runner/internal/shim/docker.go b/runner/internal/shim/docker.go index 6334d893d..b45795382 100644 --- a/runner/internal/shim/docker.go +++ b/runner/internal/shim/docker.go @@ -74,9 +74,22 @@ func NewDockerRunner(dockerParams DockerParameters) (*DockerRunner, error) { return runner, nil } -func (d *DockerRunner) Run(ctx context.Context, cfg DockerImageConfig) error { +func (d *DockerRunner) Run(ctx context.Context, cfg TaskConfig) error { var err error + if cfg.SshKey != "" { + ak := AuthorizedKeys{user: cfg.SshUser} + if err := ak.AppendPublicKeys([]string{cfg.SshKey}); err != nil { + return tracerr.Wrap(err) + } + defer func(cfg TaskConfig) { + err := ak.RemovePublicKeys([]string{cfg.SshKey}) + if err != nil { + log.Printf("Error RemovePublicKeys: %s\n", err.Error()) + } + }(cfg) + } + d.containerStatus = ContainerStatus{ ContainerName: cfg.ContainerName, } @@ -155,7 +168,7 @@ func (d *DockerRunner) Run(ctx context.Context, cfg DockerImageConfig) error { d.state = Pending d.currentContainer = "" - var jobResult = JobResult{Reason: "DONE_BY_RUNNER"} + jobResult := JobResult{Reason: "DONE_BY_RUNNER"} if d.containerStatus.ExitCode != 0 { jobResult = JobResult{Reason: "CONTAINER_EXITED_WITH_ERROR", ReasonMessage: d.containerStatus.Error} } @@ -186,30 +199,30 @@ func (d DockerRunner) GetState() (RunnerStatus, ContainerStatus, string, JobResu return d.state, d.containerStatus, d.executorError, d.jobResult } -func pullImage(ctx context.Context, client docker.APIClient, taskParams DockerImageConfig) error { - if !strings.Contains(taskParams.ImageName, ":") { - taskParams.ImageName += ":latest" +func pullImage(ctx context.Context, client docker.APIClient, taskConfig TaskConfig) error { + if !strings.Contains(taskConfig.ImageName, ":") { + taskConfig.ImageName += ":latest" } images, err := client.ImageList(ctx, image.ListOptions{ - Filters: filters.NewArgs(filters.Arg("reference", taskParams.ImageName)), + Filters: filters.NewArgs(filters.Arg("reference", taskConfig.ImageName)), }) if err != nil { return tracerr.Wrap(err) } // TODO: force pull latset - if len(images) > 0 && !strings.Contains(taskParams.ImageName, ":latest") { + if len(images) > 0 && !strings.Contains(taskConfig.ImageName, ":latest") { return nil } opts := image.PullOptions{} - regAuth, _ := taskParams.EncodeRegistryAuth() + regAuth, _ := taskConfig.EncodeRegistryAuth() if regAuth != "" { opts.RegistryAuth = regAuth } startTime := time.Now() - reader, err := client.ImagePull(ctx, taskParams.ImageName, opts) + reader, err := client.ImagePull(ctx, taskConfig.ImageName, opts) if err != nil { return tracerr.Wrap(err) } @@ -275,16 +288,16 @@ func pullImage(ctx context.Context, client docker.APIClient, taskParams DockerIm return nil } -func createContainer(ctx context.Context, client docker.APIClient, runnerDir string, dockerParams DockerParameters, taskParams DockerImageConfig) (string, error) { +func createContainer(ctx context.Context, client docker.APIClient, runnerDir string, dockerParams DockerParameters, taskConfig TaskConfig) (string, error) { timeout := int(0) stopOptions := container.StopOptions{Timeout: &timeout} - err := client.ContainerStop(ctx, taskParams.ContainerName, stopOptions) + err := client.ContainerStop(ctx, taskConfig.ContainerName, stopOptions) if err != nil { log.Printf("Cleanup routine: Cannot stop container: %s", err) } removeOptions := container.RemoveOptions{Force: true} - err = client.ContainerRemove(ctx, taskParams.ContainerName, removeOptions) + err = client.ContainerRemove(ctx, taskConfig.ContainerName, removeOptions) if err != nil { log.Printf("Cleanup routine: Cannot remove container: %s", err) } @@ -299,8 +312,8 @@ func createContainer(ctx context.Context, client docker.APIClient, runnerDir str } containerConfig := &container.Config{ - Image: taskParams.ImageName, - Cmd: []string{strings.Join(dockerParams.DockerShellCommands(), " && ")}, + Image: taskConfig.ImageName, + Cmd: []string{strings.Join(dockerParams.DockerShellCommands(taskConfig.PublicKeys), " && ")}, Entrypoint: []string{"/bin/sh", "-c"}, ExposedPorts: exposePorts(dockerParams.DockerPorts()...), } @@ -313,9 +326,9 @@ func createContainer(ctx context.Context, client docker.APIClient, runnerDir str DeviceRequests: gpuRequest, }, Mounts: mounts, - ShmSize: taskParams.ShmSize, + ShmSize: taskConfig.ShmSize, } - resp, err := client.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, taskParams.ContainerName) + resp, err := client.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, taskConfig.ContainerName) if err != nil { return "", tracerr.Wrap(err) } @@ -412,8 +425,12 @@ func (c CLIArgs) DockerKeepContainer() bool { return c.Docker.KeepContainer } -func (c CLIArgs) DockerShellCommands() []string { - commands := getSSHShellCommands(c.Docker.SSHPort, c.Docker.PublicSSHKey) +func (c CLIArgs) DockerShellCommands(publicKeys []string) []string { + concatinatedPublicKeys := c.Docker.ConcatinatedPublicSSHKeys + if len(publicKeys) > 0 { + concatinatedPublicKeys = strings.Join(publicKeys, "\n") + } + commands := getSSHShellCommands(c.Docker.SSHPort, concatinatedPublicKeys) commands = append(commands, fmt.Sprintf("%s %s", DstackRunnerBinaryName, strings.Join(c.getRunnerArgs(), " "))) return commands } @@ -439,7 +456,7 @@ func (c CLIArgs) DockerPorts() []int { func (c CLIArgs) MakeRunnerDir() (string, error) { runnerTemp := filepath.Join(c.Shim.HomeDir, "runners", time.Now().Format("20060102-150405")) - if err := os.MkdirAll(runnerTemp, 0755); err != nil { + if err := os.MkdirAll(runnerTemp, 0o755); err != nil { return "", tracerr.Wrap(err) } return runnerTemp, nil diff --git a/runner/internal/shim/docker_test.go b/runner/internal/shim/docker_test.go index bdd0a32db..baead505a 100644 --- a/runner/internal/shim/docker_test.go +++ b/runner/internal/shim/docker_test.go @@ -5,6 +5,7 @@ import ( "os" "os/exec" "strconv" + "strings" "sync" "sync/atomic" "testing" @@ -32,7 +33,7 @@ func TestDocker_SSHServer(t *testing.T) { defer cancel() dockerRunner, _ := NewDockerRunner(params) - assert.NoError(t, dockerRunner.Run(ctx, DockerImageConfig{ImageName: "ubuntu"})) + assert.NoError(t, dockerRunner.Run(ctx, TaskConfig{ImageName: "ubuntu"})) } // TestDocker_SSHServerConnect pulls ubuntu image (without sshd), installs openssh-server and tries to connect via SSH @@ -63,7 +64,7 @@ func TestDocker_SSHServerConnect(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - assert.NoError(t, dockerRunner.Run(ctx, DockerImageConfig{ImageName: "ubuntu"})) + assert.NoError(t, dockerRunner.Run(ctx, TaskConfig{ImageName: "ubuntu"})) }() for i := 0; i < timeout; i++ { @@ -97,9 +98,13 @@ func (c *dockerParametersMock) DockerKeepContainer() bool { return false } -func (c *dockerParametersMock) DockerShellCommands() []string { +func (c *dockerParametersMock) DockerShellCommands(publicKeys []string) []string { + userPublicKey := c.publicSSHKey + if len(publicKeys) > 0 { + userPublicKey = strings.Join(publicKeys, "\n") + } commands := make([]string, 0) - commands = append(commands, getSSHShellCommands(c.sshPort, c.publicSSHKey)...) + commands = append(commands, getSSHShellCommands(c.sshPort, userPublicKey)...) commands = append(commands, c.commands...) return commands } diff --git a/runner/internal/shim/models.go b/runner/internal/shim/models.go index f766f83df..9de45d863 100644 --- a/runner/internal/shim/models.go +++ b/runner/internal/shim/models.go @@ -11,7 +11,7 @@ import ( type DockerParameters interface { DockerKeepContainer() bool - DockerShellCommands() []string + DockerShellCommands([]string) []string DockerMounts(string) ([]mount.Mount, error) DockerPorts() []int MakeRunnerDir() (string, error) @@ -35,21 +35,24 @@ type CLIArgs struct { } Docker struct { - SSHPort int - KeepContainer bool - PublicSSHKey string + SSHPort int + KeepContainer bool + ConcatinatedPublicSSHKeys string } } -type DockerImageConfig struct { +type TaskConfig struct { Username string Password string ImageName string ContainerName string ShmSize int64 + PublicKeys []string + SshUser string + SshKey string } -func (ra DockerImageConfig) EncodeRegistryAuth() (string, error) { +func (ra TaskConfig) EncodeRegistryAuth() (string, error) { if ra.Username == "" && ra.Password == "" { return "", nil } diff --git a/runner/internal/shim/runner.go b/runner/internal/shim/runner.go index afbf98031..f96e9afaf 100644 --- a/runner/internal/shim/runner.go +++ b/runner/internal/shim/runner.go @@ -118,7 +118,7 @@ func downloadRunner(url string) (string, error) { log.Printf("The runner was downloaded successfully (%d bytes)", written) } - if err := tempFile.Chmod(0755); err != nil { + if err := tempFile.Chmod(0o755); err != nil { return "", gerrors.Wrap(err) } diff --git a/src/dstack/_internal/cli/commands/pool.py b/src/dstack/_internal/cli/commands/pool.py index cda3a4e45..4c570fd15 100644 --- a/src/dstack/_internal/cli/commands/pool.py +++ b/src/dstack/_internal/cli/commands/pool.py @@ -273,19 +273,11 @@ def _add(self, args: argparse.Namespace) -> None: console.print("\nExiting...") return - # TODO(egor-s): user key must be added during the `run`, not `pool add` - user_priv_key = Path("~/.dstack/ssh/id_rsa").expanduser().read_text().strip() - try: - user_pub_key = Path("~/.dstack/ssh/id_rsa.pub").expanduser().read_text().strip() - except FileNotFoundError: - user_pub_key = generate_public_key(rsa_pkey_from_str(user_priv_key)) - user_ssh_key = SSHKey(public=user_pub_key, private=user_priv_key) - try: with console.status("Creating instance..."): # TODO: Instance name is not passed, so --instance does not work. # There is profile.instance_name but it makes sense for `dstack run` only. - instance = self.api.runs.create_instance(profile, requirements, user_ssh_key) + instance = self.api.runs.create_instance(profile, requirements) except ServerClientError as e: raise CLIError(e.msg) console.print() @@ -295,21 +287,6 @@ def _add_ssh(self, args: argparse.Namespace) -> None: super()._command(args) ssh_keys = [] - - try: - # TODO: user key must be added during the `run`, not `pool add` - user_priv_key = convert_pkcs8_to_pem( - Path("~/.dstack/ssh/id_rsa").expanduser().read_text().strip() - ) - try: - user_pub_key = Path("~/.dstack/ssh/id_rsa.pub").expanduser().read_text().strip() - except FileNotFoundError: - user_pub_key = generate_public_key(rsa_pkey_from_str(user_priv_key)) - user_ssh_key = SSHKey(public=user_pub_key, private=user_priv_key) - ssh_keys.append(user_ssh_key) - except OSError: - pass - if args.ssh_identity_file: try: private_key = convert_pkcs8_to_pem(args.ssh_identity_file.read_text()) diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py index 225439105..265c32b2e 100644 --- a/src/dstack/_internal/core/backends/aws/compute.py +++ b/src/dstack/_internal/core/backends/aws/compute.py @@ -187,7 +187,6 @@ def run_job( project_name=run.project_name, instance_name=get_instance_name(run, job), # TODO: generate name ssh_keys=[ - SSHKey(public=run.run_spec.ssh_key_pub.strip()), SSHKey(public=project_ssh_public_key.strip()), ], job_docker_config=None, diff --git a/src/dstack/_internal/core/backends/azure/compute.py b/src/dstack/_internal/core/backends/azure/compute.py index 1dd70624f..5b7f4b1b3 100644 --- a/src/dstack/_internal/core/backends/azure/compute.py +++ b/src/dstack/_internal/core/backends/azure/compute.py @@ -165,7 +165,6 @@ def run_job( project_name=run.project_name, instance_name=get_instance_name(run, job), # TODO: generate name ssh_keys=[ - SSHKey(public=run.run_spec.ssh_key_pub.strip()), SSHKey(public=project_ssh_public_key.strip()), ], job_docker_config=None, diff --git a/src/dstack/_internal/core/backends/cudo/compute.py b/src/dstack/_internal/core/backends/cudo/compute.py index 1b6df307b..de850306f 100644 --- a/src/dstack/_internal/core/backends/cudo/compute.py +++ b/src/dstack/_internal/core/backends/cudo/compute.py @@ -57,7 +57,6 @@ def run_job( project_name=run.project_name, instance_name=get_instance_name(run, job), ssh_keys=[ - SSHKey(public=run.run_spec.ssh_key_pub.strip()), SSHKey(public=project_ssh_public_key.strip()), ], job_docker_config=None, diff --git a/src/dstack/_internal/core/backends/datacrunch/compute.py b/src/dstack/_internal/core/backends/datacrunch/compute.py index 637c4dc62..1dbb7490b 100644 --- a/src/dstack/_internal/core/backends/datacrunch/compute.py +++ b/src/dstack/_internal/core/backends/datacrunch/compute.py @@ -147,7 +147,6 @@ def run_job( project_name=run.project_name, instance_name=job.job_spec.job_name, # TODO: generate name ssh_keys=[ - SSHKey(public=run.run_spec.ssh_key_pub.strip()), SSHKey(public=project_ssh_public_key.strip()), ], job_docker_config=None, diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index d47bc790b..dfcb798e8 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -167,7 +167,6 @@ def run_job( project_name=run.project_name, instance_name=get_instance_name(run, job), # TODO: generate name ssh_keys=[ - SSHKey(public=run.run_spec.ssh_key_pub.strip()), SSHKey(public=project_ssh_public_key.strip()), ], job_docker_config=None, diff --git a/src/dstack/_internal/core/backends/nebius/compute.py b/src/dstack/_internal/core/backends/nebius/compute.py index 054f6c643..614ec644f 100644 --- a/src/dstack/_internal/core/backends/nebius/compute.py +++ b/src/dstack/_internal/core/backends/nebius/compute.py @@ -125,7 +125,6 @@ def run_job( project_name=run.project_name, instance_name=get_instance_name(run, job), # TODO: generate name ssh_keys=[ - SSHKey(public=run.run_spec.ssh_key_pub.strip()), SSHKey(public=project_ssh_public_key.strip()), ], job_docker_config=None, diff --git a/src/dstack/_internal/server/background/tasks/process_running_jobs.py b/src/dstack/_internal/server/background/tasks/process_running_jobs.py index 3e7b8974c..c724f9a16 100644 --- a/src/dstack/_internal/server/background/tasks/process_running_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_running_jobs.py @@ -1,5 +1,5 @@ from datetime import timedelta -from typing import Dict, Optional +from typing import Dict, List, Optional from uuid import UUID from sqlalchemy import select @@ -152,6 +152,9 @@ async def _process_job(job_id: UUID): fmt(job_model), job_submission.age, ) + ssh_user = job_provisioning_data.username + user_ssh_key = run.run_spec.ssh_key_pub.strip() + public_keys = [project.ssh_public_key.strip(), user_ssh_key] success = await run_async( _process_provisioning_with_shim, server_ssh_private_key, @@ -159,6 +162,9 @@ async def _process_job(job_id: UUID): job_model, secrets, job.job_spec.registry_auth, + public_keys, + ssh_user, + user_ssh_key, ) else: logger.debug( @@ -339,6 +345,9 @@ def _process_provisioning_with_shim( job_model: JobModel, secrets: Dict[str, str], registry_auth: Optional[RegistryAuth], + public_keys: List[str], + ssh_user: str, + ssh_key: str, *, ports: Dict[int, int], ) -> bool: @@ -359,24 +368,24 @@ def _process_provisioning_with_shim( logger.debug("%s: shim is not available yet", fmt(job_model)) return False # shim is not available yet + username = "" + password = "" if registry_auth is not None: logger.debug("%s: authenticating to the registry...", fmt(job_model)) interpolate = VariablesInterpolator({"secrets": secrets}).interpolate - shim_client.submit( - username=interpolate(registry_auth.username), - password=interpolate(registry_auth.password), - image_name=job_spec.image_name, - container_name=job_model.job_name, - shm_size=job_spec.requirements.resources.shm_size, - ) - else: - shim_client.submit( - username="", - password="", - image_name=job_spec.image_name, - container_name=job_model.job_name, - shm_size=job_spec.requirements.resources.shm_size, - ) + username = interpolate(registry_auth.username) + password = interpolate(registry_auth.password) + + shim_client.submit( + username=username, + password=password, + image_name=job_spec.image_name, + container_name=job_model.job_name, + shm_size=job_spec.requirements.resources.shm_size, + public_keys=public_keys, + ssh_user=ssh_user, + ssh_key=ssh_key, + ) job_model.status = JobStatus.PULLING logger.info("%s: now is %s", fmt(job_model), job_model.status.name) diff --git a/src/dstack/_internal/server/routers/runs.py b/src/dstack/_internal/server/routers/runs.py index 49ad06e73..2a8af76e8 100644 --- a/src/dstack/_internal/server/routers/runs.py +++ b/src/dstack/_internal/server/routers/runs.py @@ -166,7 +166,6 @@ async def create_instance( session=session, project=project, user=user, - ssh_key=body.ssh_key, profile=body.profile, requirements=body.requirements, ) diff --git a/src/dstack/_internal/server/schemas/runner.py b/src/dstack/_internal/server/schemas/runner.py index 5aebb59ff..1e282aa8d 100644 --- a/src/dstack/_internal/server/schemas/runner.py +++ b/src/dstack/_internal/server/schemas/runner.py @@ -71,12 +71,15 @@ class HealthcheckResponse(CoreModel): version: str -class DockerImageBody(CoreModel): +class TaskConfigBody(CoreModel): username: str password: str image_name: str container_name: str shm_size: int + public_keys: List[str] + ssh_user: str + ssh_key: str class StopBody(CoreModel): diff --git a/src/dstack/_internal/server/schemas/runs.py b/src/dstack/_internal/server/schemas/runs.py index e8ac13d70..c6e634889 100644 --- a/src/dstack/_internal/server/schemas/runs.py +++ b/src/dstack/_internal/server/schemas/runs.py @@ -35,7 +35,6 @@ class GetOffersRequest(CoreModel): class CreateInstanceRequest(CoreModel): profile: Profile requirements: Requirements - ssh_key: SSHKey class AddRemoteInstanceRequest(CoreModel): diff --git a/src/dstack/_internal/server/services/runner/client.py b/src/dstack/_internal/server/services/runner/client.py index b36044604..9ed12ed9f 100644 --- a/src/dstack/_internal/server/services/runner/client.py +++ b/src/dstack/_internal/server/services/runner/client.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import BinaryIO, Dict, Optional, Union +from typing import BinaryIO, Dict, List, Optional, Union import requests import requests.exceptions @@ -8,12 +8,12 @@ from dstack._internal.core.models.resources import Memory from dstack._internal.core.models.runs import ClusterInfo, JobSpec, RunSpec from dstack._internal.server.schemas.runner import ( - DockerImageBody, HealthcheckResponse, PullBody, PullResponse, StopBody, SubmitBody, + TaskConfigBody, ) REMOTE_SHIM_PORT = 10998 @@ -125,14 +125,20 @@ def submit( image_name: str, container_name: str, shm_size: Optional[Memory], + public_keys: List[str], + ssh_user: str, + ssh_key: str, ): _shm_size = int(shm_size * 1024 * 1024 * 1014) if shm_size else 0 - post_body = DockerImageBody( + post_body = TaskConfigBody( username=username, password=password, image_name=image_name, container_name=container_name, shm_size=_shm_size, + public_keys=public_keys, + ssh_user=ssh_user, + ssh_key=ssh_key, ).dict() resp = requests.post( self._url("/api/submit"), diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 4d4562803..f593be87c 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -553,7 +553,6 @@ async def create_instance( session: AsyncSession, project: ProjectModel, user: UserModel, - ssh_key: SSHKey, profile: Profile, requirements: Requirements, ) -> Instance: @@ -584,7 +583,6 @@ async def create_instance( instance_name = await generate_instance_name( session=session, project=project, pool_name=pool.name ) - user_ssh_key = ssh_key project_ssh_key = SSHKey( public=project.ssh_public_key.strip(), private=project.ssh_private_key.strip(), @@ -593,7 +591,7 @@ async def create_instance( instance_config = InstanceConfiguration( project_name=project.name, instance_name=instance_name, - ssh_keys=[user_ssh_key, project_ssh_key], + ssh_keys=[project_ssh_key], job_docker_config=DockerConfig( image=dstack_default_image, registry_auth=None, diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 7f6e0ef39..5f3e17539 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -168,7 +168,7 @@ def get_run_spec( configuration_path="dstack.yaml", configuration=configuration or DevEnvironmentConfiguration(ide="vscode"), profile=profile, - ssh_key_pub="", + ssh_key_pub="user_ssh_key", ) diff --git a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index d38dee49a..99ce0b0b2 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -17,7 +17,6 @@ from dstack._internal.core.errors import ConfigurationError, ResourceNotExistsError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import AnyRunConfiguration -from dstack._internal.core.models.instances import SSHKey from dstack._internal.core.models.pools import Instance from dstack._internal.core.models.profiles import ( DEFAULT_RUN_TERMINATION_IDLE_TIME, @@ -258,7 +257,6 @@ def attach( if control_sock_path_and_port_locks is None: if self._ports_lock is None: self._ports_lock = _reserve_ports(job.job_spec) - logger.debug( "Attaching to %s (%s: %s)", self.name, @@ -267,7 +265,6 @@ def attach( ) else: self._ports_lock = control_sock_path_and_port_locks[1] - logger.debug( "Reusing the existing tunnel to %s (%s: %s)", self.name, @@ -391,10 +388,8 @@ def submit( def get_offers(self, profile: Profile, requirements: Requirements) -> PoolInstanceOffers: return self._api_client.runs.get_offers(self._project, profile, requirements) - def create_instance( - self, profile: Profile, requirements: Requirements, ssh_key: SSHKey - ) -> Instance: - return self._api_client.runs.create_instance(self._project, profile, requirements, ssh_key) + def create_instance(self, profile: Profile, requirements: Requirements) -> Instance: + return self._api_client.runs.create_instance(self._project, profile, requirements) def get_plan( self, diff --git a/src/dstack/api/server/_runs.py b/src/dstack/api/server/_runs.py index 07c121a7a..e189c5164 100644 --- a/src/dstack/api/server/_runs.py +++ b/src/dstack/api/server/_runs.py @@ -2,7 +2,6 @@ from pydantic import parse_obj_as -from dstack._internal.core.models.instances import SSHKey from dstack._internal.core.models.pools import Instance from dstack._internal.core.models.profiles import Profile from dstack._internal.core.models.runs import ( @@ -68,8 +67,7 @@ def create_instance( project_name: str, profile: Profile, requirements: Requirements, - ssh_key: SSHKey, ) -> Instance: - body = CreateInstanceRequest(profile=profile, requirements=requirements, ssh_key=ssh_key) + body = CreateInstanceRequest(profile=profile, requirements=requirements) resp = self._request(f"/api/project/{project_name}/runs/create_instance", body=body.json()) return parse_obj_as(Instance.__response__, resp.json()) diff --git a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py index 48342884d..cc832587f 100644 --- a/src/tests/_internal/server/background/tasks/test_process_running_jobs.py +++ b/src/tests/_internal/server/background/tasks/test_process_running_jobs.py @@ -191,7 +191,8 @@ async def test_updates_running_job(self, test_db, session: AsyncSession, tmp_pat @pytest.mark.asyncio async def test_provisioning_shim(self, test_db, session: AsyncSession): - project = await create_project(session=session) + project_ssh_pub_key = "__project_ssh_pub_key__" + project = await create_project(session=session, ssh_public_key=project_ssh_pub_key) user = await create_user(session=session) repo = await create_repo( session=session, @@ -232,6 +233,9 @@ async def test_provisioning_shim(self, test_db, session: AsyncSession): image_name="dstackai/base:py3.11-0.4rc4-cuda-12.1", container_name="test-run-0-0", shm_size=None, + public_keys=[project_ssh_pub_key, "user_ssh_key"], + ssh_user="ubuntu", + ssh_key="user_ssh_key", ) await session.refresh(job) assert job is not None diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 9c21c1d07..26f07fbdb 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -15,7 +15,6 @@ InstanceOfferWithAvailability, InstanceType, Resources, - SSHKey, ) from dstack._internal.core.models.profiles import DEFAULT_POOL_NAME, Profile from dstack._internal.core.models.resources import ResourcesSpec @@ -900,7 +899,6 @@ async def test_creates_instance(self, test_db, session: AsyncSession): request = CreateInstanceRequest( profile=Profile(name="test_profile"), requirements=Requirements(resources=ResourcesSpec(cpu=1)), - ssh_key=SSHKey(public="test_public_key"), ) with patch( "dstack._internal.server.services.runs.get_offers_by_requirements" @@ -966,7 +964,6 @@ async def test_error_if_backends_do_not_support_create_instance( request = CreateInstanceRequest( profile=Profile(name="test_profile"), requirements=Requirements(resources=ResourcesSpec(cpu=1)), - ssh_key=SSHKey(public="test_public_key"), ) with patch( "dstack._internal.server.services.runs.get_offers_by_requirements" @@ -1004,7 +1001,6 @@ async def test_backend_does_not_support_create_instance(self, test_db, session: request = CreateInstanceRequest( profile=Profile(name="test_profile"), requirements=Requirements(resources=ResourcesSpec(cpu=1)), - ssh_key=SSHKey(public="test_public_key"), ) with patch(