Skip to content

Commit

Permalink
new push provider: nanopush (#78)
Browse files Browse the repository at this point in the history
  • Loading branch information
jessepeterson committed Nov 29, 2023
1 parent 2142a7e commit 012ad61
Show file tree
Hide file tree
Showing 8 changed files with 404 additions and 22 deletions.
4 changes: 2 additions & 2 deletions cmd/nanomdm/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
"github.com/micromdm/nanomdm/http/authproxy"
httpmdm "github.com/micromdm/nanomdm/http/mdm"
"github.com/micromdm/nanomdm/log/stdlogfmt"
"github.com/micromdm/nanomdm/push/buford"
"github.com/micromdm/nanomdm/push/nanopush"
pushsvc "github.com/micromdm/nanomdm/push/service"
"github.com/micromdm/nanomdm/service"
"github.com/micromdm/nanomdm/service/certauth"
Expand Down Expand Up @@ -191,7 +191,7 @@ func main() {
const apiUsername = "nanomdm"

// create our push provider and push service
pushProviderFactory := buford.NewPushProviderFactory()
pushProviderFactory := nanopush.NewFactory()
pushService := pushsvc.New(mdmStorage, mdmStorage, pushProviderFactory, logger.With("service", "push"))

// register API handler for push cert storage/upload.
Expand Down
6 changes: 2 additions & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ require (
github.com/groob/plist v0.0.0-20220217120414-63fa881b19a5
github.com/lib/pq v1.10.9
go.mozilla.org/pkcs7 v0.0.0-20210826202110-33d05740a352
golang.org/x/net v0.0.0-20191009170851-d66e71096ffb
)

require (
golang.org/x/net v0.0.0-20191009170851-d66e71096ffb // indirect
golang.org/x/text v0.3.0 // indirect
)
require golang.org/x/text v0.3.0 // indirect
3 changes: 2 additions & 1 deletion push/buford/buford.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package buford

import (
"context"
"crypto/tls"
"errors"
"net/http"
Expand Down Expand Up @@ -125,7 +126,7 @@ func (c *bufordPushProvider) pushMulti(pushInfos []*mdm.Push) map[string]*push.R
}

// Push sends 'raw' MDM APNs push notifications to service in c.
func (c *bufordPushProvider) Push(pushInfos []*mdm.Push) (map[string]*push.Response, error) {
func (c *bufordPushProvider) Push(_ context.Context, pushInfos []*mdm.Push) (map[string]*push.Response, error) {
if len(pushInfos) < 1 {
return nil, errors.New("no push data provided")
}
Expand Down
97 changes: 97 additions & 0 deletions push/nanopush/nanopush.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// Pacakge nanopush implements an Apple APNs HTTP/2 service for MDM.
// It implements the PushProvider and PushProviderFactory interfaces.
package nanopush

import (
"crypto/tls"
"errors"
"net/http"
"time"

"github.com/micromdm/nanomdm/push"
"golang.org/x/net/http2"
)

// NewClient describes a callback for setting up an HTTP client for Push notifications.
type NewClient func(*tls.Certificate) (*http.Client, error)

// ClientWithCert configures an mTLS client cert on the HTTP client.
func ClientWithCert(client *http.Client, cert *tls.Certificate) (*http.Client, error) {
if cert == nil {
return client, errors.New("no cert provided")
}
if client == nil {
clone := *http.DefaultClient
client = &clone
}
config := &tls.Config{
Certificates: []tls.Certificate{*cert},
}
if client.Transport == nil {
client.Transport = &http.Transport{}
}
transport := client.Transport.(*http.Transport)
transport.TLSClientConfig = config
// force HTTP/2
err := http2.ConfigureTransport(transport)
return client, err
}

func defaultNewClient(cert *tls.Certificate) (*http.Client, error) {
return ClientWithCert(nil, cert)
}

// Factory instantiates new PushProviders.
type Factory struct {
newClient NewClient
expiration time.Duration
workers int
}

type Option func(*Factory)

// WithNewClient sets a callback to setup an HTTP client for each
// new Push provider.
func WithNewClient(newClient NewClient) Option {
return func(f *Factory) {
f.newClient = newClient
}
}

// WithExpiration sets the APNs expiration time for the push notifications.
func WithExpiration(expiration time.Duration) Option {
return func(f *Factory) {
f.expiration = expiration
}
}

// WithWorkers sets how many worker goroutines to use when sending pushes.
func WithWorkers(workers int) Option {
return func(f *Factory) {
f.workers = workers
}
}

// NewFactory creates a new Factory.
func NewFactory(opts ...Option) *Factory {
f := &Factory{
newClient: defaultNewClient,
workers: 5,
}
for _, opt := range opts {
opt(f)
}
return f
}

// NewPushProvider generates a new PushProvider given a tls keypair.
func (f *Factory) NewPushProvider(cert *tls.Certificate) (push.PushProvider, error) {
p := &Provider{
expiration: f.expiration,
workers: f.workers,
baseURL: Production,
}
var err error
p.client, err = f.newClient(cert)
return p, err
}
180 changes: 180 additions & 0 deletions push/nanopush/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
package nanopush

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync"
"time"

"github.com/micromdm/nanomdm/mdm"
"github.com/micromdm/nanomdm/push"
"golang.org/x/net/http2"
)

// Doer is ostensibly an *http.Client
type Doer interface {
Do(req *http.Request) (*http.Response, error)
}

const (
Development = "https://api.development.push.apple.com"
Development2197 = "https://api.development.push.apple.com:2197"
Production = "https://api.push.apple.com"
Production2197 = "https://api.push.apple.com:2197"
)

// Provider sends pushes to Apple's APNs servers.
type Provider struct {
client Doer
expiration time.Duration
workers int
baseURL string
}

// JSONPushError is a JSON error returned from the APNs service.
type JSONPushError struct {
Reason string `json:"reason"`
Timestamp int64 `json:"timestamp"`
}

func (e *JSONPushError) Error() string {
s := "APNs push error"
if e == nil {
return s + ": nil"
}
if e.Reason != "" {
s += ": " + e.Reason
}
if e.Timestamp > 0 {
s += ": timestamp " + strconv.FormatInt(e.Timestamp, 10)
}
return s
}

func newError(body io.Reader, statusCode int) error {
var err error = new(JSONPushError)
if decodeErr := json.NewDecoder(body).Decode(err); decodeErr != nil {
err = fmt.Errorf("decoding JSON push error: %w", decodeErr)
}
return fmt.Errorf("push HTTP status: %d: %w", statusCode, err)
}

// do performs the HTTP push request
func (p *Provider) do(ctx context.Context, pushInfo *mdm.Push) *push.Response {
jsonPayload := []byte(`{"mdm":"` + pushInfo.PushMagic + `"}`)

url := p.baseURL + "/3/device/" + pushInfo.Token.String()
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonPayload))

if err != nil {
return &push.Response{Err: err}
}

req.Header.Set("Content-Type", "application/json")
if p.expiration > 0 {
exp := time.Now().Add(p.expiration)
req.Header.Set("apns-expiration", strconv.FormatInt(exp.Unix(), 10))
}
r, err := p.client.Do(req)
var goAwayErr http2.GoAwayError
if errors.As(err, &goAwayErr) {
body := strings.NewReader(goAwayErr.DebugData)
return &push.Response{Err: newError(body, r.StatusCode)}
} else if err != nil {
return &push.Response{Err: err}
}

defer r.Body.Close()
response := &push.Response{Id: r.Header.Get("apns-id")}
if r.StatusCode != http.StatusOK {
response.Err = newError(r.Body, r.StatusCode)
}
return response
}

// pushSerial performs APNs pushes serially.
func (p *Provider) pushSerial(ctx context.Context, pushInfos []*mdm.Push) (map[string]*push.Response, error) {
ret := make(map[string]*push.Response)
for _, pushInfo := range pushInfos {
if pushInfo == nil {
continue
}
ret[pushInfo.Token.String()] = p.do(ctx, pushInfo)
}
return ret, nil
}

// pushConcurrent performs APNs pushes concurrently.
// It spawns worker goroutines and feeds them from the list of pushInfos.
func (p *Provider) pushConcurrent(ctx context.Context, pushInfos []*mdm.Push) (map[string]*push.Response, error) {
// don't start more workers than we have pushes to send
workers := p.workers
if len(pushInfos) > workers {
workers = len(pushInfos)
}

// response associates push.Response with token
type response struct {
token string
response *push.Response
}

jobs := make(chan *mdm.Push)
results := make(chan response)
var wg sync.WaitGroup

// start our workers
wg.Add(workers)
for i := 0; i < workers; i++ {
go func() {
defer wg.Done()
for pushInfo := range jobs {
results <- response{
token: pushInfo.Token.String(),
response: p.do(ctx, pushInfo),
}
}
}()
}

// start the "feeder" (queue source)
go func() {
for _, pushInfo := range pushInfos {
jobs <- pushInfo
}
close(jobs)
}()

// watch for our workers finishing (they should after feeding is done)
// stop the collector when the workers have finished.
go func() {
wg.Wait()
close(results)
}()

// collect our results
ret := make(map[string]*push.Response)
for r := range results {
ret[r.token] = r.response
}

return ret, nil
}

// Push sends APNs pushes to MDM enrollments.
func (p *Provider) Push(ctx context.Context, pushInfos []*mdm.Push) (map[string]*push.Response, error) {
if len(pushInfos) < 1 {
return nil, errors.New("no push data provided")
} else if len(pushInfos) == 1 {
return p.pushSerial(ctx, pushInfos)
} else {
return p.pushConcurrent(ctx, pushInfos)
}
}
Loading

0 comments on commit 012ad61

Please sign in to comment.