Skip to content

Commit

Permalink
Wait for driver to be bound
Browse files Browse the repository at this point in the history
Signed-off-by: Vladimir Popov <[email protected]>
  • Loading branch information
Vladimir Popov committed Dec 10, 2020
1 parent 9affdbc commit 2f515d2
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 12 deletions.
4 changes: 2 additions & 2 deletions pkg/networkservice/common/resourcepool/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ const (
// PCIPool is a pci.Pool interface
type PCIPool interface {
GetPCIFunction(pciAddr string) (sriov.PCIFunction, error)
BindDriver(iommuGroup uint, driverType sriov.DriverType) error
BindDriver(ctx context.Context, iommuGroup uint, driverType sriov.DriverType) error
}

// ResourcePool is a resource.Pool interface
Expand Down Expand Up @@ -103,7 +103,7 @@ func (s *resourcePoolServer) Request(ctx context.Context, request *networkservic
return errors.Wrapf(err, "failed to get VF IOMMU group: %v", vf.GetPCIAddress())
}

if err := s.pciPool.BindDriver(iommuGroup, s.driverType); err != nil {
if err := s.pciPool.BindDriver(ctx, iommuGroup, s.driverType); err != nil {
return err
}

Expand Down
74 changes: 64 additions & 10 deletions pkg/sriov/pci/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
package pci

import (
"context"
"fmt"
"os"
"time"

"github.com/pkg/errors"

"github.com/networkservicemesh/sdk-sriov/pkg/sriov"
Expand All @@ -27,7 +32,9 @@ import (
)

const (
vfioDriver = "vfio-pci"
vfioDriver = "vfio-pci"
driverBindTimeout = time.Second
driverBindCheck = driverBindTimeout / 10
)

type pciFunction interface {
Expand All @@ -41,6 +48,7 @@ type pciFunction interface {
type Pool struct {
functions map[string]*function // pciAddr -> *function
functionsByIOMMUGroup map[uint][]*function // iommuGroup -> []*function
test bool
}

type function struct {
Expand Down Expand Up @@ -80,6 +88,7 @@ func NewTestPool(physicalFunctions map[string]*sriovtest.PCIPhysicalFunction, cf
p := &Pool{
functions: map[string]*function{},
functionsByIOMMUGroup: map[uint][]*function{},
test: true,
}

for pfPCIAddr, pfCfg := range cfg.PhysicalFunctions {
Expand Down Expand Up @@ -117,15 +126,15 @@ func (p *Pool) addFunction(pcif pciFunction, kernelDriver string) (err error) {

// GetPCIFunction returns PCI function for the given PCI address
func (p *Pool) GetPCIFunction(pciAddr string) (sriov.PCIFunction, error) {
f, err := p.find(pciAddr)
if err != nil {
return nil, err
f, ok := p.functions[pciAddr]
if !ok {
return nil, errors.Errorf("PCI function doesn't exist: %v", pciAddr)
}
return f.function, nil
}

// BindDriver binds selected IOMMU group to the given driver type
func (p *Pool) BindDriver(iommuGroup uint, driverType sriov.DriverType) error {
func (p *Pool) BindDriver(ctx context.Context, iommuGroup uint, driverType sriov.DriverType) error {
for _, f := range p.functionsByIOMMUGroup[iommuGroup] {
switch driverType {
case sriov.KernelDriver:
Expand All @@ -140,13 +149,58 @@ func (p *Pool) BindDriver(iommuGroup uint, driverType sriov.DriverType) error {
return errors.Errorf("driver type is not supported: %v", driverType)
}
}

for _, f := range p.functionsByIOMMUGroup[iommuGroup] {
if err := p.waitDriverGettingBound(ctx, f.function, driverType); err != nil {
return err
}
}

return nil
}

func (p *Pool) find(pciAddr string) (*function, error) {
f, ok := p.functions[pciAddr]
if !ok {
return nil, errors.Errorf("PCI function doesn't exist: %v", pciAddr)
func (p *Pool) waitDriverGettingBound(ctx context.Context, pcif pciFunction, driverType sriov.DriverType) error {
if p.test {
return nil
}

timeoutCh := time.After(driverBindTimeout)
for {
var driverCheck func(pciFunction) error
switch driverType {
case sriov.KernelDriver:
driverCheck = kernelDriverCheck
case sriov.VFIOPCIDriver:
driverCheck = vfioDriverCheck
default:
return errors.Errorf("driver type is not supported: %v", driverType)
}

if driverCheck(pcif) == nil {
return nil
}

select {
case <-ctx.Done():
return ctx.Err()
case <-timeoutCh:
return errors.Errorf("time for binding kernel driver exceeded: %s", pcif.GetPCIAddress())
case <-time.After(driverBindCheck):
}
}
return f, nil
}

func kernelDriverCheck(pcif pciFunction) error {
_, err := pcif.GetNetInterfaceName()
return err
}

func vfioDriverCheck(pcif pciFunction) error {
iommuGroup, err := pcif.GetIOMMUGroup()
if err != nil {
return err
}

_, err = os.Stat(fmt.Sprintf("/dev/vfio/%d", iommuGroup))
return err
}

0 comments on commit 2f515d2

Please sign in to comment.