diff --git a/.travis.yml b/.travis.yml index 16bfc8eb..e4c63799 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,6 +16,10 @@ go: go_import_path: aahframework.org/aah.v0 +install: + - git config --global http.https://aahframework.org.followRedirects true + - go get -t -v ./... + script: - bash <(curl -s https://aahframework.org/go-test) diff --git a/README.md b/README.md index ae8948b0..32dcc642 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ # aah web framework for Go [![Build Status](https://travis-ci.org/go-aah/aah.svg?branch=master)](https://travis-ci.org/go-aah/aah) [![codecov](https://codecov.io/gh/go-aah/aah/branch/master/graph/badge.svg)](https://codecov.io/gh/go-aah/aah/branch/master) [![Go Report Card](https://goreportcard.com/badge/aahframework.org/aah.v0)](https://goreportcard.com/report/aahframework.org/aah.v0) [![Powered by Go](https://img.shields.io/badge/powered_by-go-blue.svg)](https://golang.org) -[![Version](https://img.shields.io/badge/version-0.6-blue.svg)](https://github.com/go-aah/aah/releases/latest) [![GoDoc](https://godoc.org/aahframework.org/aah.v0?status.svg)](https://godoc.org/aahframework.org/aah.v0) -[![License](https://img.shields.io/github/license/go-aah/aah.svg)](LICENSE) +[![Version](https://img.shields.io/badge/version-0.7-blue.svg)](https://github.com/go-aah/aah/releases/latest) [![GoDoc](https://godoc.org/aahframework.org/aah.v0?status.svg)](https://godoc.org/aahframework.org/aah.v0) +[![License](https://img.shields.io/github/license/go-aah/aah.svg)](LICENSE) [![Twitter](https://img.shields.io/badge/twitter-@aahframework-55acee.svg)](https://twitter.com/aahframework) -***Release [v0.6](https://github.com/go-aah/aah/releases/latest) tagged on Jun 07, 2017*** +***Release [v0.7](https://github.com/go-aah/aah/releases/latest) tagged on Aug 01, 2017*** aah framework - A scalable, performant, rapid development Web framework for Go. diff --git a/aah.go b/aah.go index 7afe078d..a8148167 100644 --- a/aah.go +++ b/aah.go @@ -21,7 +21,7 @@ import ( ) // Version no. of aah framework -const Version = "0.6" +const Version = "0.7" // aah application variables var ( @@ -49,6 +49,7 @@ var ( appDefaultHTTPPort = "8080" appDefaultDateFormat = "2006-01-02" appDefaultDateTimeFormat = "2006-01-02 15:04:05" + appLogFatal = log.Fatal goPath string goSrcDir string @@ -63,7 +64,7 @@ type BuildInfo struct { } //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Global methods +// Package methods //___________________________________ // AppName method returns aah application name from app config otherwise app name @@ -105,16 +106,9 @@ func AppHTTPAddress() string { // AppHTTPPort method returns aah application HTTP port number based on `server.port` // value. Possible outcomes are user-defined port, `80`, `443` and `8080`. func AppHTTPPort() string { - port := AppConfig().StringDefault("server.port", appDefaultHTTPPort) - if !ess.IsStrEmpty(port) { - return port - } - - if AppIsSSLEnabled() { - return "443" - } - - return "80" + port := firstNonEmpty(AppConfig().StringDefault("server.proxyport", ""), + AppConfig().StringDefault("server.port", appDefaultHTTPPort)) + return parsePort(port) } // AppDateFormat method returns aah application date format @@ -200,8 +194,11 @@ func Init(importPath string) { logAsFatal(initLogs(appLogsDir(), AppConfig())) logAsFatal(initI18n(appI18nDir())) logAsFatal(initRoutes(appConfigDir(), AppConfig())) - logAsFatal(initSecurity(appConfigDir(), AppConfig())) + logAsFatal(initSecurity(AppConfig())) logAsFatal(initViewEngine(appViewsDir(), AppConfig())) + if AppConfig().BoolDefault("server.access_log.enable", false) { + logAsFatal(initAccessLog(appLogsDir(), AppConfig())) + } } appInitialized = true @@ -228,7 +225,7 @@ func appLogsDir() string { func logAsFatal(err error) { if err != nil { - log.Fatal(err) + appLogFatal(err) } } diff --git a/aah_test.go b/aah_test.go index b02a213b..996a07b5 100644 --- a/aah_test.go +++ b/aah_test.go @@ -5,6 +5,7 @@ package aah import ( + "errors" "fmt" "io/ioutil" "path/filepath" @@ -172,6 +173,35 @@ func TestAahLogDir(t *testing.T) { cfg, _ := config.ParseString("") logger, _ := log.New(cfg) log.SetDefaultLogger(logger) + + // relative path filename + cfgRelativeFile, _ := config.ParseString(` + log { + receiver = "file" + file = "my-test-file.log" + } + `) + err = initLogs(logsDir, cfgRelativeFile) + assert.Nil(t, err) + + // no filename mentioned + cfgNoFile, _ := config.ParseString(` + log { + receiver = "file" + } + `) + SetAppBuildInfo(&BuildInfo{ + BinaryName: "testapp", + Date: time.Now().Format(time.RFC3339), + Version: "1.0.0", + }) + err = initLogs(logsDir, cfgNoFile) + assert.Nil(t, err) + appBuildInfo = nil + + appLogFatal = func(v ...interface{}) { fmt.Println(v) } + logAsFatal(errors.New("test msg")) + } func TestWritePID(t *testing.T) { diff --git a/access_log.go b/access_log.go new file mode 100644 index 00000000..383a3a10 --- /dev/null +++ b/access_log.go @@ -0,0 +1,240 @@ +// Copyright (c) Jeevanandam M. (https://github.com/jeevatkm) +// go-aah/aah source code and usage is governed by a MIT style +// license that can be found in the LICENSE file. + +package aah + +import ( + "fmt" + "net/http" + "path/filepath" + "strings" + "sync" + "time" + + "aahframework.org/ahttp.v0" + "aahframework.org/config.v0" + "aahframework.org/essentials.v0" + "aahframework.org/log.v0" +) + +const ( + fmtFlagClientIP ess.FmtFlag = iota + fmtFlagRequestTime + fmtFlagRequestURL + fmtFlagRequestMethod + fmtFlagRequestProto + fmtFlagRequestID + fmtFlagRequestHeader + fmtFlagQueryString + fmtFlagResponseStatus + fmtFlagResponseSize + fmtFlagResponseHeader + fmtFlagResponseTime + fmtFlagCustom +) + +var ( + accessLogFmtFlags = map[string]ess.FmtFlag{ + "clientip": fmtFlagClientIP, + "reqtime": fmtFlagRequestTime, + "requrl": fmtFlagRequestURL, + "reqmethod": fmtFlagRequestMethod, + "reqproto": fmtFlagRequestProto, + "reqid": fmtFlagRequestID, + "reqhdr": fmtFlagRequestHeader, + "querystr": fmtFlagQueryString, + "resstatus": fmtFlagResponseStatus, + "ressize": fmtFlagResponseSize, + "reshdr": fmtFlagResponseHeader, + "restime": fmtFlagResponseTime, + "custom": fmtFlagCustom, + } + + appDefaultAccessLogPattern = "%clientip %custom:- %reqtime %reqmethod %requrl %reqproto %resstatus %ressize %restime %reqhdr:referer" + appReqStartTimeKey = "_appReqStartTimeKey" + appReqIDHdrKey = ahttp.HeaderXRequestID + appAccessLog *log.Logger + appAccessLogFmtFlags []ess.FmtFlagPart + appAccessLogChan chan *accessLog + accessLogPool = &sync.Pool{New: func() interface{} { return &accessLog{} }} +) + +type ( + //accessLog contains data about the current request + accessLog struct { + StartTime time.Time + ElapsedDuration time.Duration + Request *ahttp.Request + ResStatus int + ResBytes int + ResHdr http.Header + } +) + +//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ +// accessLog methods +//___________________________________ + +// FmtRequestTime method returns the formatted request time. There are three +// possibilities to handle, `%reqtime`, `%reqtime:` and `%reqtime:`. +func (al *accessLog) FmtRequestTime(format string) string { + if format == "%v" || ess.IsStrEmpty(format) { + return al.StartTime.Format(time.RFC3339) + } + return al.StartTime.Format(format) +} + +func (al *accessLog) GetRequestHdr(hdrKey string) string { + hdrValues := al.Request.Header[http.CanonicalHeaderKey(hdrKey)] + if len(hdrValues) == 0 { + return "-" + } + return `"` + strings.Join(hdrValues, ", ") + `"` +} + +func (al *accessLog) GetResponseHdr(hdrKey string) string { + hdrValues := al.ResHdr[http.CanonicalHeaderKey(hdrKey)] + if len(hdrValues) == 0 { + return "-" + } + return `"` + strings.Join(hdrValues, ", ") + `"` +} + +func (al *accessLog) GetQueryString() string { + queryStr := al.Request.Raw.URL.Query().Encode() + if ess.IsStrEmpty(queryStr) { + return "-" + } + return `"` + queryStr + `"` +} + +func (al *accessLog) Reset() { + al.StartTime = time.Time{} + al.ElapsedDuration = 0 + al.Request = nil + al.ResStatus = 0 + al.ResBytes = 0 + al.ResHdr = nil +} + +//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ +// Unexported methods +//___________________________________ + +func initAccessLog(logsDir string, appCfg *config.Config) error { + // log file configuration + cfg, _ := config.ParseString("") + file := appCfg.StringDefault("server.access_log.file", "") + + cfg.SetString("log.receiver", "file") + if ess.IsStrEmpty(file) { + cfg.SetString("log.file", filepath.Join(logsDir, getBinaryFileName()+"-access.log")) + } else { + abspath, err := filepath.Abs(file) + if err != nil { + return err + } + cfg.SetString("log.file", abspath) + } + + cfg.SetString("log.pattern", "%message") + + var err error + + // initialize request access log file + appAccessLog, err = log.New(cfg) + if err != nil { + return err + } + + // parse request access log pattern + pattern := appCfg.StringDefault("server.access_log.pattern", appDefaultAccessLogPattern) + appAccessLogFmtFlags, err = ess.ParseFmtFlag(pattern, accessLogFmtFlags) + if err != nil { + return err + } + + // initialize request access log channel + if appAccessLogChan == nil { + appAccessLogChan = make(chan *accessLog, cfg.IntDefault("server.access_log.channel_buffer_size", 500)) + go listenForAccessLog() + } + + appReqIDHdrKey = cfg.StringDefault("request.id.header", ahttp.HeaderXRequestID) + + return nil +} + +func listenForAccessLog() { + for { + appAccessLog.Print(accessLogFormatter(<-appAccessLogChan)) + } +} + +func sendToAccessLog(ctx *Context) { + al := acquireAccessLog() + al.StartTime = ctx.values[appReqStartTimeKey].(time.Time) + + // All the bytes have been written on the wire + // so calculate elapsed time + al.ElapsedDuration = time.Since(al.StartTime) + + req := *ctx.Req + al.Request = &req + al.ResStatus = ctx.Res.Status() + al.ResBytes = ctx.Res.BytesWritten() + al.ResHdr = ctx.Res.Header() + + appAccessLogChan <- al +} + +func accessLogFormatter(al *accessLog) string { + defer releaseAccessLog(al) + buf := acquireBuffer() + defer releaseBuffer(buf) + + for _, part := range appAccessLogFmtFlags { + switch part.Flag { + case fmtFlagClientIP: + buf.WriteString(al.Request.ClientIP) + case fmtFlagRequestTime: + buf.WriteString(al.FmtRequestTime(part.Format)) + case fmtFlagRequestURL: + buf.WriteString(al.Request.Path) + case fmtFlagRequestMethod: + buf.WriteString(al.Request.Method) + case fmtFlagRequestProto: + buf.WriteString(al.Request.Unwrap().Proto) + case fmtFlagRequestID: + buf.WriteString(al.GetRequestHdr(appReqIDHdrKey)) + case fmtFlagRequestHeader: + buf.WriteString(al.GetRequestHdr(part.Format)) + case fmtFlagQueryString: + buf.WriteString(al.GetQueryString()) + case fmtFlagResponseStatus: + buf.WriteString(fmt.Sprintf(part.Format, al.ResStatus)) + case fmtFlagResponseSize: + buf.WriteString(fmt.Sprintf(part.Format, al.ResBytes)) + case fmtFlagResponseHeader: + buf.WriteString(al.GetResponseHdr(part.Format)) + case fmtFlagResponseTime: + buf.WriteString(fmt.Sprintf("%.4f", al.ElapsedDuration.Seconds()*1e3)) + case fmtFlagCustom: + buf.WriteString(part.Format) + } + buf.WriteByte(' ') + } + return strings.TrimSpace(buf.String()) +} + +func acquireAccessLog() *accessLog { + return accessLogPool.Get().(*accessLog) +} + +func releaseAccessLog(al *accessLog) { + if al != nil { + al.Reset() + accessLogPool.Put(al) + } +} diff --git a/access_log_test.go b/access_log_test.go new file mode 100644 index 00000000..96cb1fad --- /dev/null +++ b/access_log_test.go @@ -0,0 +1,180 @@ +// Copyright (c) Jeevanandam M. (https://github.com/jeevatkm) +// go-aah/aah source code and usage is governed by a MIT style +// license that can be found in the LICENSE file. + +package aah + +import ( + "fmt" + "net/http" + "net/http/httptest" + "path/filepath" + "testing" + "time" + + "aahframework.org/ahttp.v0" + "aahframework.org/config.v0" + "aahframework.org/essentials.v0" + "aahframework.org/test.v0/assert" +) + +func TestAccessLogFormatter(t *testing.T) { + al := createTestAccessLog() + + // Since we are not bootstrapping the framework's engine, + // We need to manually set this + al.Request.Path = "/oops" + al.Request.Header.Set(ahttp.HeaderXRequestID, "5946ed129bf23409520736de") + + // Testing for the default access log pattern first + expectedDefaultFormat := fmt.Sprintf("%v - %v %v %v %v %v %v %v %v", + "[::1]", al.StartTime.Format(time.RFC3339), + al.Request.Method, al.Request.Path, al.Request.Raw.Proto, al.ResStatus, + al.ResBytes, fmt.Sprintf("%.4f", al.ElapsedDuration.Seconds()*1e3), "-") + + testFormatter(t, al, appDefaultAccessLogPattern, expectedDefaultFormat) + + // Testing custom access log pattern + al = createTestAccessLog() + al.ResHdr.Add("content-type", "application/json") + pattern := "%reqtime:2016-05-16 %reqhdr %querystr %reshdr:content-type" + expected := fmt.Sprintf(`%s %s "%s" "%s"`, al.StartTime.Format("2016-05-16"), "-", "me=human", al.ResHdr.Get("Content-Type")) + + testFormatter(t, al, pattern, expected) + + // Testing all available access log pattern + al = createTestAccessLog() + al.Request.Header = al.Request.Raw.Header + al.Request.Header.Add(ahttp.HeaderAccept, "text/html") + al.Request.Header.Set(ahttp.HeaderXRequestID, "5946ed129bf23409520736de") + al.Request.ClientIP = "127.0.0.1" + al.ResHdr.Add("content-type", "application/json") + allAvailablePatterns := "%clientip %reqid %reqtime %restime %resstatus %ressize %reqmethod %requrl %reqhdr:accept %querystr %reshdr" + expectedForAllAvailablePatterns := fmt.Sprintf(`%s "%s" %s %v %d %d %s %s "%s" "%s" %s`, + al.Request.ClientIP, al.Request.Header.Get(ahttp.HeaderXRequestID), + al.StartTime.Format(time.RFC3339), fmt.Sprintf("%.4f", al.ElapsedDuration.Seconds()*1e3), + al.ResStatus, al.ResBytes, al.Request.Method, + al.Request.Path, "text/html", "me=human", "-") + + testFormatter(t, al, allAvailablePatterns, expectedForAllAvailablePatterns) +} + +func TestAccessLogFormatterInvalidPattern(t *testing.T) { + _, err := ess.ParseFmtFlag("%oops", accessLogFmtFlags) + + assert.NotNil(t, err) +} + +func TestAccessLogInitDefault(t *testing.T) { + testAccessInit(t, ` + server { + access_log { + # Default value is false + enable = true + } + } + `) + + testAccessInit(t, ` + server { + access_log { + # Default value is false + enable = true + + file = "testdata/test-access.log" + } + } + `) + + testAccessInit(t, ` + server { + access_log { + # Default value is false + enable = true + + file = "/tmp/test-access.log" + } + } + `) +} + +func TestEngineAccessLog(t *testing.T) { + // App Config + cfgDir := filepath.Join(getTestdataPath(), appConfigDir()) + err := initConfig(cfgDir) + assert.Nil(t, err) + assert.NotNil(t, AppConfig()) + + AppConfig().SetString("server.port", "8080") + + // Router + err = initRoutes(cfgDir, AppConfig()) + assert.Nil(t, err) + assert.NotNil(t, AppRouter()) + + // Security + err = initSecurity(AppConfig()) + assert.Nil(t, err) + assert.True(t, AppSessionManager().IsStateful()) + + // Controllers + cRegistry = controllerRegistry{} + + AddController((*Site)(nil), []*MethodInfo{ + { + Name: "GetInvolved", + Parameters: []*ParameterInfo{}, + }, + }) + + AppConfig().SetBool("server.access_log.enable", true) + + e := newEngine(AppConfig()) + req := httptest.NewRequest("GET", "localhost:8080/get-involved.html", nil) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.True(t, e.isAccessLogEnabled) +} + +func testFormatter(t *testing.T, al *accessLog, pattern, expected string) { + var err error + appAccessLogFmtFlags, err = ess.ParseFmtFlag(pattern, accessLogFmtFlags) + + assert.Nil(t, err) + assert.Equal(t, expected, accessLogFormatter(al)) +} + +func testAccessInit(t *testing.T, cfgStr string) { + buildTime := time.Now().Format(time.RFC3339) + SetAppBuildInfo(&BuildInfo{ + BinaryName: "testapp", + Date: buildTime, + Version: "1.0.0", + }) + + cfg, _ := config.ParseString(cfgStr) + logsDir := filepath.Join(getTestdataPath(), appLogsDir()) + err := initAccessLog(logsDir, cfg) + + assert.Nil(t, err) + assert.NotNil(t, appAccessLog) +} + +func createTestAccessLog() *accessLog { + startTime := time.Now() + req := httptest.NewRequest("GET", "/oops?me=human", nil) + req.Header = http.Header{} + + w := httptest.NewRecorder() + + al := acquireAccessLog() + al.StartTime = startTime + al.ElapsedDuration = time.Now().Add(2 * time.Second).Sub(startTime) + al.Request = &ahttp.Request{Raw: req, Header: req.Header, ClientIP: "[::1]"} + al.ResStatus = 200 + al.ResBytes = 63 + al.ResHdr = w.HeaderMap + + return al +} diff --git a/config.go b/config.go index d150f5c8..355727e0 100644 --- a/config.go +++ b/config.go @@ -15,7 +15,7 @@ import ( var appConfig *config.Config //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Global methods +// Package methods //___________________________________ // AppConfig method returns aah application configuration instance. diff --git a/context.go b/context.go index 14e6c0bb..48de7deb 100644 --- a/context.go +++ b/context.go @@ -14,6 +14,7 @@ import ( "aahframework.org/essentials.v0" "aahframework.org/log.v0" "aahframework.org/router.v0" + "aahframework.org/security.v0" "aahframework.org/security.v0/session" ) @@ -44,9 +45,10 @@ type ( target interface{} domain *router.Domain route *router.Route - session *session.Session + subject *security.Subject reply *Reply viewArgs map[string]interface{} + values map[string]interface{} abort bool decorated bool } @@ -109,14 +111,19 @@ func (ctx *Context) Subdomain() string { return "" } +// Subject method the subject (aka application user) of current request. +func (ctx *Context) Subject() *security.Subject { + return ctx.subject +} + // Session method always returns `session.Session` object. Use `Session.IsNew` // to identify whether sesison is newly created or restored from the request // which was already created. func (ctx *Context) Session() *session.Session { - if ctx.session == nil { - ctx.session = AppSessionManager().NewSession() + if ctx.subject.Session == nil { + ctx.subject.Session = AppSessionManager().NewSession() } - return ctx.session + return ctx.subject.Session } // Abort method sets the abort to true. It means framework will not proceed with @@ -156,7 +163,7 @@ func (ctx *Context) SetURL(pathURL string) { return } - rawReq := ctx.Req.Raw + rawReq := ctx.Req.Unwrap() if !ess.IsStrEmpty(u.Host) { log.Debugf("Host have been updated from '%s' to '%s'", ctx.Req.Host, u.Host) rawReq.Host = u.Host @@ -187,7 +194,7 @@ func (ctx *Context) SetMethod(method string) { } log.Debugf("Request method have been updated from '%s' to '%s'", ctx.Req.Method, method) - ctx.Req.Raw.Method = method + ctx.Req.Unwrap().Method = method ctx.Req.Method = method } @@ -200,9 +207,10 @@ func (ctx *Context) Reset() { ctx.target = nil ctx.domain = nil ctx.route = nil - ctx.session = nil + ctx.subject = nil ctx.reply = nil - ctx.viewArgs = nil + ctx.viewArgs = make(map[string]interface{}) + ctx.values = make(map[string]interface{}) ctx.abort = false ctx.decorated = false } diff --git a/context_test.go b/context_test.go index 6bc9e160..405999e6 100644 --- a/context_test.go +++ b/context_test.go @@ -14,6 +14,7 @@ import ( "aahframework.org/ahttp.v0" "aahframework.org/config.v0" "aahframework.org/router.v0" + "aahframework.org/security.v0" "aahframework.org/test.v0/assert" ) @@ -119,6 +120,9 @@ func TestContextSetTarget(t *testing.T) { assert.NotNil(t, ctx.action.Parameters) assert.Equal(t, "userId", ctx.action.Parameters[0].Name) + ctx.controller.Namespace = "" + assert.Equal(t, "Level3", resolveControllerName(ctx)) + err2 := ctx.setTarget(&router.Route{Controller: "NoController"}) assert.Equal(t, errTargetNotFound, err2) @@ -131,10 +135,10 @@ func TestContextSession(t *testing.T) { err := initConfig(cfgDir) assert.Nil(t, err) - err = initSecurity(cfgDir, AppConfig()) + err = initSecurity(AppConfig()) assert.Nil(t, err) - ctx := &Context{viewArgs: make(map[string]interface{})} + ctx := &Context{viewArgs: make(map[string]interface{}), subject: &security.Subject{}} s1 := ctx.Session() assert.NotNil(t, s1) assert.True(t, s1.IsNew) diff --git a/controller.go b/controller.go index bf82dc66..eccb5737 100644 --- a/controller.go +++ b/controller.go @@ -55,7 +55,7 @@ type ( ) //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Global methods +// Package methods //___________________________________ // AddController method adds given controller into controller registory. diff --git a/engine.go b/engine.go index c3a0ce67..a4db4d3c 100644 --- a/engine.go +++ b/engine.go @@ -5,19 +5,19 @@ package aah import ( - "bytes" "errors" "fmt" "io" "net/http" - "os" + "sync" + "time" "aahframework.org/ahttp.v0" "aahframework.org/aruntime.v0" "aahframework.org/config.v0" "aahframework.org/essentials.v0" "aahframework.org/log.v0" - "aahframework.org/pool.v0" + "aahframework.org/security.v0" ) const ( @@ -26,17 +26,17 @@ const ( ) const ( - aahServerName = "aah-go-server" - gzipContentEncoding = "gzip" - hstsHeaderValue = "max-age=31536000; includeSubDomains" - defaultGlobalPoolSize = 500 - defaultBufPoolSize = 200 + aahServerName = "aah-go-server" + gzipContentEncoding = "gzip" + hstsHeaderValue = "max-age=31536000; includeSubDomains" ) var ( - minifier MinifierFunc - errFileNotFound = errors.New("file not found") - noGzipStatusCodes = []int{http.StatusNotModified, http.StatusNoContent} + errFileNotFound = errors.New("file not found") + ctHTML = ahttp.ContentTypeHTML + + minifier MinifierFunc + ctxPool *sync.Pool ) type ( @@ -50,16 +50,12 @@ type ( // Engine is the aah framework application server handler for request and response. // Implements `http.Handler` interface. engine struct { - isRequestIDEnabled bool - requestIDHeader string - isGzipEnabled bool - ctxPool *pool.Pool - reqPool *pool.Pool - replyPool *pool.Pool - bufPool *pool.Pool + isRequestIDEnabled bool + requestIDHeader string + isGzipEnabled bool + isAccessLogEnabled bool + isStaticAccessLogEnabled bool } - - byName []os.FileInfo ) //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ @@ -68,8 +64,13 @@ type ( // ServeHTTP method implementation of http.Handler interface. func (e *engine) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // Capture the startTime earlier. + // This value is as accurate as could be. + startTime := time.Now() + ctx := e.prepareContext(w, r) - defer e.putContext(ctx) + ctx.values[appReqStartTimeKey] = startTime + defer releaseContext(ctx) // Recovery handling, capture every possible panic(s) defer e.handleRecovery(ctx) @@ -83,14 +84,21 @@ func (e *engine) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Handling route if e.handleRoute(ctx) == flowStop { - return + goto wReply } // Load session e.loadSession(ctx) + // Authentication and Authorization + if e.handleAuthcAndAuthz(ctx) == flowStop { + goto wReply + } + // Parsing request params - e.parseRequestParams(ctx) + if e.parseRequestParams(ctx) == flowStop { + goto wReply + } // Set defaults when actual value not found e.setDefaults(ctx) @@ -98,6 +106,7 @@ func (e *engine) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Middlewares, interceptors, targeted controller e.executeMiddlewares(ctx) +wReply: // Write Reply on the wire e.writeReply(ctx) } @@ -109,22 +118,13 @@ func (e *engine) handleRecovery(ctx *Context) { log.Errorf("Internal Server Error on %s", ctx.Req.Path) st := aruntime.NewStacktrace(r, AppConfig()) - buf := e.getBuffer() - defer e.putBuffer(buf) + buf := acquireBuffer() + defer releaseBuffer(buf) st.Print(buf) log.Error(buf.String()) - ctx.Reply().InternalServerError() - e.negotiateContentType(ctx) - if ahttp.ContentTypeJSON.IsEqual(ctx.Reply().ContType) { - ctx.Reply().JSON(Data{"code": "500", "message": "Internal Server Error"}) - } else if ahttp.ContentTypeXML.IsEqual(ctx.Reply().ContType) { - ctx.Reply().XML(Data{"code": "500", "message": "Internal Server Error"}) - } else { - ctx.Reply().Text("500 Internal Server Error") - } - + writeErrorInfo(ctx, http.StatusInternalServerError, "Internal Server Error") e.writeReply(ctx) } } @@ -133,9 +133,7 @@ func (e *engine) handleRecovery(ctx *Context) { // It won't set new request id header already present. func (e *engine) setRequestID(ctx *Context) { if ess.IsStrEmpty(ctx.Req.Header.Get(e.requestIDHeader)) { - guid := ess.NewGUID() - log.Debugf("Request ID: %v", guid) - ctx.Req.Header.Set(e.requestIDHeader, guid) + ctx.Req.Header.Set(e.requestIDHeader, ess.NewGUID()) } else { log.Debugf("Request already has ID: %v", ctx.Req.Header.Get(e.requestIDHeader)) } @@ -144,13 +142,12 @@ func (e *engine) setRequestID(ctx *Context) { // prepareContext method gets controller, request from pool, set the targeted // controller, parses the request and returns the controller. -func (e *engine) prepareContext(w http.ResponseWriter, req *http.Request) *Context { - ctx, r := e.getContext(), e.getRequest() - ctx.Req = ahttp.ParseRequest(req, r) - ctx.Res = ahttp.GetResponseWriter(w) - ctx.reply = e.getReply() - ctx.viewArgs = make(map[string]interface{}) - +func (e *engine) prepareContext(w http.ResponseWriter, r *http.Request) *Context { + ctx := acquireContext() + ctx.Req = ahttp.AcquireRequest(r) + ctx.Res = ahttp.AcquireResponseWriter(w) + ctx.reply = acquireReply() + ctx.subject = security.AcquireSubject() return ctx } @@ -171,27 +168,29 @@ func (e *engine) prepareContext(w http.ResponseWriter, req *http.Request) *Conte func (e *engine) handleRoute(ctx *Context) flowResult { domain := AppRouter().FindDomain(ctx.Req) if domain == nil { - ctx.Reply().NotFound().Text("404 Not Found") - e.writeReply(ctx) + writeErrorInfo(ctx, http.StatusNotFound, "Not Found") return flowStop } route, pathParams, rts := domain.Lookup(ctx.Req) if route == nil { // route not found if err := handleRtsOptionsMna(ctx, domain, rts); err == nil { - e.writeReply(ctx) return flowStop } ctx.route = domain.NotFoundRoute handleRouteNotFound(ctx, domain, domain.NotFoundRoute) - e.writeReply(ctx) return flowStop } ctx.route = route ctx.domain = domain + // security form auth case + if isFormAuthLoginRoute(ctx) { + return flowCont + } + // Path parameters if pathParams.Len() > 0 { ctx.Req.Params.Path = make(map[string]string, pathParams.Len()) @@ -204,7 +203,7 @@ func (e *engine) handleRoute(ctx *Context) flowResult { if route.IsStatic { if err := e.serveStatic(ctx); err == errFileNotFound { handleRouteNotFound(ctx, domain, route) - e.writeReply(ctx) + ctx.Reply().done = false // override } return flowStop } @@ -212,7 +211,6 @@ func (e *engine) handleRoute(ctx *Context) flowResult { // No controller or action found for the route if err := ctx.setTarget(route); err == errTargetNotFound { handleRouteNotFound(ctx, domain, route) - e.writeReply(ctx) return flowStop } @@ -222,7 +220,7 @@ func (e *engine) handleRoute(ctx *Context) flowResult { // loadSession method loads session from request for `stateful` session. func (e *engine) loadSession(ctx *Context) { if AppSessionManager().IsStateful() { - ctx.session = AppSessionManager().GetSession(ctx.Req.Raw) + ctx.subject.Session = AppSessionManager().GetSession(ctx.Req.Unwrap()) } } @@ -241,28 +239,34 @@ func (e *engine) executeMiddlewares(ctx *Context) { // writeReply method writes the response on the wire based on `Reply` instance. func (e *engine) writeReply(ctx *Context) { - reply := ctx.Reply() - // Response already written on the wire, don't go forward. - // refer `ctx.Abort()` method. - if reply.done { + // refer to `Reply().Done()` method. + if ctx.Reply().done { return } + // 'OnPreReply' server extension point + publishOnPreReplyEvent(ctx) + // HTTP headers e.writeHeaders(ctx) // Set Cookies e.setCookies(ctx) + reply := ctx.Reply() if reply.redirect { // handle redirects log.Debugf("Redirecting to '%s' with status '%d'", reply.path, reply.Code) - http.Redirect(ctx.Res, ctx.Req.Raw, reply.path, reply.Code) + http.Redirect(ctx.Res, ctx.Req.Unwrap(), reply.path, reply.Code) return } // ContentType - e.negotiateContentType(ctx) + if !reply.IsContentTypeSet() { + if ct := identifyContentType(ctx); ct != nil { + reply.ContentType(ct.String()) + } + } // resolving view template e.resolveView(ctx) @@ -271,45 +275,38 @@ func (e *engine) writeReply(ctx *Context) { // error info without messing with response on the wire. e.doRender(ctx) - // Gzip - if !isNoGzipStatusCode(reply.Code) && reply.body.Len() != 0 { + isBodyAllowed := isResponseBodyAllowed(reply.Code) + // Gzip, 1kb above TODO make it configurable for bytes size value + if isBodyAllowed && reply.body.Len() > 1024 { e.wrapGzipWriter(ctx) } // ContentType, if it's not set then auto detect later in the writer - if ctx.Reply().IsContentTypeSet() { + if reply.IsContentTypeSet() { ctx.Res.Header().Set(ahttp.HeaderContentType, reply.ContType) } - // 'OnPreReply' server extension point - publishOnPreReplyEvent(ctx) - // HTTP status ctx.Res.WriteHeader(reply.Code) - // Write response buffer on the wire - if minifier == nil || !appIsProfileProd || - isNoGzipStatusCode(reply.Code) || - !ahttp.ContentTypeHTML.IsEqual(reply.ContType) { - _, _ = reply.body.WriteTo(ctx.Res) - } else if err := minifier(reply.ContType, ctx.Res, reply.body); err != nil { - log.Errorf("Minifier error: %v", err) + if isBodyAllowed { + // Write response on the wire + var err error + if minifier == nil || !appIsProfileProd || !ctHTML.IsEqual(reply.ContType) { + if _, err = reply.body.WriteTo(ctx.Res); err != nil { + log.Error(err) + } + } else if err = minifier(reply.ContType, ctx.Res, reply.body); err != nil { + log.Errorf("Minifier error: %s", err.Error()) + } } // 'OnAfterReply' server extension point publishOnAfterReplyEvent(ctx) -} -// negotiateContentType method tries to identify if reply.ContType is empty. -// Not necessarily it will set one. -func (e *engine) negotiateContentType(ctx *Context) { - if !ctx.Reply().IsContentTypeSet() { - if !ess.IsStrEmpty(ctx.Req.AcceptContentType.Mime) && - ctx.Req.AcceptContentType.Mime != "*/*" { // based on 'Accept' Header - ctx.Reply().ContentType(ctx.Req.AcceptContentType.String()) - } else if ct := defaultContentType(); ct != nil { // as per 'render.default' in aah.conf - ctx.Reply().ContentType(ct.String()) - } + // Send data to access log channel + if e.isAccessLogEnabled { + sendToAccessLog(ctx) } } @@ -320,7 +317,7 @@ func (e *engine) wrapGzipWriter(ctx *Context) { ctx.Res.Header().Add(ahttp.HeaderVary, ahttp.HeaderAcceptEncoding) ctx.Res.Header().Add(ahttp.HeaderContentEncoding, gzipContentEncoding) ctx.Res.Header().Del(ahttp.HeaderContentLength) - ctx.Res = ahttp.GetGzipResponseWriter(ctx.Res) + ctx.Res = ahttp.WrapGzipWriter(ctx.Res) } } @@ -348,74 +345,13 @@ func (e *engine) setCookies(ctx *Context) { http.SetCookie(ctx.Res, c) } - if AppSessionManager().IsStateful() && ctx.session != nil { - // Pass it to view args before saving cookie - session := *ctx.session - ctx.AddViewArg(keySessionValues, &session) - if err := AppSessionManager().SaveSession(ctx.Res, ctx.session); err != nil { + if AppSessionManager().IsStateful() && ctx.subject.Session != nil { + if err := AppSessionManager().SaveSession(ctx.Res, ctx.subject.Session); err != nil { log.Error(err) } } } -// getContext method gets context instance from the pool -func (e *engine) getContext() *Context { - return e.ctxPool.Get().(*Context) -} - -// getRequest method gets request instance from the pool -func (e *engine) getRequest() *ahttp.Request { - return e.reqPool.Get().(*ahttp.Request) -} - -// getReply method gets reply instance from the pool -func (e *engine) getReply() *Reply { - return e.replyPool.Get().(*Reply) -} - -// putContext method puts context back to pool -func (e *engine) putContext(ctx *Context) { - // Close the writer and Put back to pool - if ctx.Res != nil { - if _, ok := ctx.Res.(*ahttp.GzipResponse); ok { - ahttp.PutGzipResponseWiriter(ctx.Res) - } else { - ahttp.PutResponseWriter(ctx.Res) - } - } - - // clear and put `ahttp.Request` into pool - if ctx.Req != nil { - ctx.Req.Reset() - e.reqPool.Put(ctx.Req) - } - - // clear and put `Reply` into pool - if ctx.reply != nil { - e.putBuffer(ctx.reply.body) - ctx.reply.Reset() - e.replyPool.Put(ctx.reply) - } - - // clear and put `aah.Context` into pool - ctx.Reset() - e.ctxPool.Put(ctx) -} - -// getBuffer method gets buffer from pool -func (e *engine) getBuffer() *bytes.Buffer { - return e.bufPool.Get().(*bytes.Buffer) -} - -// putBPool puts buffer into pool -func (e *engine) putBuffer(b *bytes.Buffer) { - if b == nil { - return - } - b.Reset() - e.bufPool.Put(b) -} - //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // Unexported methods //___________________________________ @@ -427,32 +363,33 @@ func newEngine(cfg *config.Config) *engine { } return &engine{ - isRequestIDEnabled: cfg.BoolDefault("request.id.enable", true), - requestIDHeader: cfg.StringDefault("request.id.header", ahttp.HeaderXRequestID), - isGzipEnabled: cfg.BoolDefault("render.gzip.enable", true), - ctxPool: pool.NewPool( - cfg.IntDefault("runtime.pooling.global", defaultGlobalPoolSize), - func() interface{} { - return &Context{} - }, - ), - reqPool: pool.NewPool( - cfg.IntDefault("runtime.pooling.global", defaultGlobalPoolSize), - func() interface{} { - return &ahttp.Request{} - }, - ), - replyPool: pool.NewPool( - cfg.IntDefault("runtime.pooling.global", defaultGlobalPoolSize), - func() interface{} { - return NewReply() - }, - ), - bufPool: pool.NewPool( - cfg.IntDefault("runtime.pooling.buffer", defaultBufPoolSize), - func() interface{} { - return &bytes.Buffer{} - }, - ), + isRequestIDEnabled: cfg.BoolDefault("request.id.enable", true), + requestIDHeader: cfg.StringDefault("request.id.header", ahttp.HeaderXRequestID), + isGzipEnabled: cfg.BoolDefault("render.gzip.enable", true), + isAccessLogEnabled: cfg.BoolDefault("server.access_log.enable", false), + isStaticAccessLogEnabled: cfg.BoolDefault("server.access_log.static_file", true), } } + +func acquireContext() *Context { + return ctxPool.Get().(*Context) +} + +func releaseContext(ctx *Context) { + ahttp.ReleaseResponseWriter(ctx.Res) + ahttp.ReleaseRequest(ctx.Req) + security.ReleaseSubject(ctx.subject) + releaseReply(ctx.reply) + + ctx.Reset() + ctxPool.Put(ctx) +} + +func init() { + ctxPool = &sync.Pool{New: func() interface{} { + return &Context{ + viewArgs: make(map[string]interface{}), + values: make(map[string]interface{}), + } + }} +} diff --git a/engine_test.go b/engine_test.go index aaa76d6f..67ddeac7 100644 --- a/engine_test.go +++ b/engine_test.go @@ -6,6 +6,7 @@ package aah import ( "compress/gzip" + "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -16,6 +17,7 @@ import ( "aahframework.org/ahttp.v0" "aahframework.org/config.v0" + "aahframework.org/essentials.v0" "aahframework.org/log.v0" "aahframework.org/test.v0/assert" ) @@ -77,21 +79,21 @@ func TestEngineNew(t *testing.T) { assert.Equal(t, "X-Test-Request-Id", e.requestIDHeader) assert.True(t, e.isRequestIDEnabled) assert.True(t, e.isGzipEnabled) - assert.NotNil(t, e.ctxPool) - assert.NotNil(t, e.bufPool) - assert.NotNil(t, e.reqPool) - req := e.getRequest() - ctx := e.getContext() - ctx.Req = req + ctx := acquireContext() + ctx.Req = &ahttp.Request{} assert.NotNil(t, ctx) - assert.NotNil(t, req) assert.NotNil(t, ctx.Req) - e.putContext(ctx) + releaseContext(ctx) - buf := e.getBuffer() + buf := acquireBuffer() assert.NotNil(t, buf) - e.putBuffer(buf) + releaseBuffer(buf) + + appLogFatal = func(v ...interface{}) { fmt.Println(v) } + AppConfig().SetInt("render.gzip.level", 10) + e = newEngine(AppConfig()) + fmt.Println(e) } func TestEngineServeHTTP(t *testing.T) { @@ -109,7 +111,7 @@ func TestEngineServeHTTP(t *testing.T) { assert.NotNil(t, AppRouter()) // Security - err = initSecurity(cfgDir, AppConfig()) + err = initSecurity(AppConfig()) assert.Nil(t, err) assert.True(t, AppSessionManager().IsStateful()) @@ -154,11 +156,10 @@ func TestEngineServeHTTP(t *testing.T) { e.ServeHTTP(w2, r2) resp2 := w2.Result() - gr2, _ := gzip.NewReader(resp2.Body) - body2, _ := ioutil.ReadAll(gr2) + body2 := getResponseBody(resp2) assert.Equal(t, 200, resp2.StatusCode) assert.True(t, strings.Contains(resp2.Status, "OK")) - assert.Equal(t, "GetInvolved action", string(body2)) + assert.Equal(t, "GetInvolved action", body2) assert.Equal(t, "test engine middleware", resp2.Header.Get("X-Custom-Name")) // Request 3 @@ -167,10 +168,10 @@ func TestEngineServeHTTP(t *testing.T) { e.ServeHTTP(w3, r3) resp3 := w3.Result() - body3, _ := ioutil.ReadAll(resp3.Body) + body3 := getResponseBody(resp3) assert.Equal(t, 500, resp3.StatusCode) assert.True(t, strings.Contains(resp3.Status, "Internal Server Error")) - assert.True(t, strings.Contains(string(body3), "Internal Server Error")) + assert.True(t, strings.Contains(body3, "Internal Server Error")) // Request 4 static r4 := httptest.NewRequest("GET", "http://localhost:8080/assets/logo.png", nil) @@ -201,10 +202,9 @@ func TestEngineServeHTTP(t *testing.T) { e.ServeHTTP(w6, r6) resp6 := w6.Result() - body6, _ := ioutil.ReadAll(resp6.Body) - body6Str := string(body6) - assert.True(t, strings.Contains(body6Str, "Listing of /testdata/")) - assert.True(t, strings.Contains(body6Str, "config/")) + body6 := getResponseBody(resp6) + assert.True(t, strings.Contains(body6, "Listing of /testdata/")) + assert.True(t, strings.Contains(body6, "config/")) // Custom Headers r7 := httptest.NewRequest("GET", "http://localhost:8080/credits", nil) @@ -213,9 +213,8 @@ func TestEngineServeHTTP(t *testing.T) { e.ServeHTTP(w7, r7) resp7 := w7.Result() - body7, _ := ioutil.ReadAll(resp7.Body) - body7Str := string(body7) - assert.Equal(t, `{"code":1000001,"message":"This is credits page"}`, body7Str) + body7 := getResponseBody(resp7) + assert.Equal(t, `{"code":1000001,"message":"This is credits page"}`, body7) assert.Equal(t, "custom value", resp7.Header.Get("X-Custom-Header")) r8 := httptest.NewRequest("POST", "http://localhost:8080/credits", nil) @@ -225,11 +224,21 @@ func TestEngineServeHTTP(t *testing.T) { // Method Not Allowed 405 response resp8 := w8.Result() - reader8, _ := gzip.NewReader(resp8.Body) - body8, _ := ioutil.ReadAll(reader8) - assert.Equal(t, "405 Method Not Allowed", string(body8)) + body8 := getResponseBody(resp8) + assert.Equal(t, "405 Method Not Allowed", body8) assert.Equal(t, "GET, OPTIONS", resp8.Header.Get("Allow")) + // Auto Options + r9 := httptest.NewRequest("OPTIONS", "http://localhost:8080/credits", nil) + r9.Header.Add(ahttp.HeaderAcceptEncoding, "gzip, deflate, sdch, br") + w9 := httptest.NewRecorder() + e.ServeHTTP(w9, r9) + + resp9 := w9.Result() + body9 := getResponseBody(resp9) + assert.Equal(t, "200 'OPTIONS' allowed HTTP Methods", body9) + assert.Equal(t, "GET, OPTIONS", resp9.Header.Get("Allow")) + appBaseDir = "" } @@ -245,4 +254,17 @@ func TestEngineGzipHeaders(t *testing.T) { assert.True(t, ctx.Req.IsGzipAccepted) assert.Equal(t, "gzip", ctx.Res.Header().Get(ahttp.HeaderContentEncoding)) assert.Equal(t, "Accept-Encoding", ctx.Res.Header().Get(ahttp.HeaderVary)) + assert.False(t, isResponseBodyAllowed(199)) + assert.False(t, isResponseBodyAllowed(304)) + assert.False(t, isResponseBodyAllowed(100)) +} + +func getResponseBody(res *http.Response) string { + r := res.Body + defer ess.CloseQuietly(r) + if strings.Contains(res.Header.Get("Content-Encoding"), "gzip") { + r, _ = gzip.NewReader(r) + } + body, _ := ioutil.ReadAll(r) + return string(body) } diff --git a/error.go b/error.go new file mode 100644 index 00000000..1de8ce34 --- /dev/null +++ b/error.go @@ -0,0 +1,40 @@ +// Copyright (c) Jeevanandam M. (https://github.com/jeevatkm) +// go-aah/aah source code and usage is governed by a MIT style +// license that can be found in the LICENSE file. + +package aah + +import ( + "strings" + + "aahframework.org/ahttp.v0" + "aahframework.org/essentials.v0" +) + +// writeError method writes the server error response based content type. +// It's handy internal method. +func writeErrorInfo(ctx *Context, code int, msg string) { + ct := ctx.Reply().ContType + if ess.IsStrEmpty(ct) { + if ict := identifyContentType(ctx); ict != nil { + ct = ict.Mime + } + } else if idx := strings.IndexByte(ct, ';'); idx > 0 { + ct = ct[:idx] + } + + switch ct { + case ahttp.ContentTypeJSON.Mime, ahttp.ContentTypeJSONText.Mime: + ctx.Reply().Status(code).JSON(Data{ + "code": code, + "message": msg, + }) + case ahttp.ContentTypeXML.Mime, ahttp.ContentTypeXMLText.Mime: + ctx.Reply().Status(code).XML(Data{ + "code": code, + "message": msg, + }) + default: + ctx.Reply().Status(code).Text("%d %s", code, msg) + } +} diff --git a/error_test.go b/error_test.go new file mode 100644 index 00000000..21db6a65 --- /dev/null +++ b/error_test.go @@ -0,0 +1,35 @@ +// Copyright (c) Jeevanandam M. (https://github.com/jeevatkm) +// go-aah/aah source code and usage is governed by a MIT style +// license that can be found in the LICENSE file. + +package aah + +import ( + "testing" + + "aahframework.org/test.v0/assert" +) + +func TestErrorWriteInfo(t *testing.T) { + ctx1 := &Context{reply: acquireReply()} + ctx1.Reply().ContentType("application/json") + writeErrorInfo(ctx1, 400, "Bad Request") + + assert.NotNil(t, ctx1.Reply().Rdr) + jsonr := ctx1.Reply().Rdr.(*JSON) + assert.NotNil(t, jsonr) + assert.NotNil(t, jsonr.Data) + assert.Equal(t, 400, jsonr.Data.(Data)["code"]) + assert.Equal(t, "Bad Request", jsonr.Data.(Data)["message"]) + + ctx2 := &Context{reply: acquireReply()} + ctx2.Reply().ContentType("application/xml") + writeErrorInfo(ctx2, 500, "Internal Server Error") + + assert.NotNil(t, ctx2.Reply().Rdr) + xmlr := ctx2.Reply().Rdr.(*XML) + assert.NotNil(t, xmlr) + assert.NotNil(t, xmlr.Data) + assert.Equal(t, 500, xmlr.Data.(Data)["code"]) + assert.Equal(t, "Internal Server Error", xmlr.Data.(Data)["message"]) +} diff --git a/event.go b/event.go index b88529e2..0da5a0c9 100644 --- a/event.go +++ b/event.go @@ -20,10 +20,10 @@ const ( // EventOnStart event is fired before HTTP/Unix listener starts EventOnStart = "OnStart" - // EventOnShutdown event is fired when server recevies interrupt or kill command. + // EventOnShutdown event is fired when server recevies an interrupt or kill command. EventOnShutdown = "OnShutdown" - // EventOnRequest event is fired when server recevies incoming request. + // EventOnRequest event is fired when server recevies an incoming request. EventOnRequest = "OnRequest" // EventOnPreReply event is fired when before server writes the reply on the wire. @@ -39,6 +39,12 @@ const ( // 2) `Reply().Redirect(...)` is called. // Refer `aah.Reply.Done()` godoc for more info. EventOnAfterReply = "OnAfterReply" + + // EventOnPreAuth event is fired before server Authenticates & Authorizes an incoming request. + EventOnPreAuth = "OnPreAuth" + + // EventOnPostAuth event is fired after server Authenticates & Authorizes an incoming request. + EventOnPostAuth = "OnPostAuth" ) var ( @@ -46,6 +52,8 @@ var ( onRequestFunc EventCallbackFunc onPreReplyFunc EventCallbackFunc onAfterReplyFunc EventCallbackFunc + onPreAuthFunc EventCallbackFunc + onPostAuthFunc EventCallbackFunc ) type ( @@ -77,7 +85,7 @@ type ( ) //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Global methods +// Package methods //___________________________________ // AppEventStore method returns aah application event store. @@ -119,7 +127,7 @@ func UnsubscribeEventf(eventName string, ecf EventCallbackFunc) { } //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Global methods - Server events +// Package methods - Server events //___________________________________ // OnInit method is to subscribe to aah application `OnInit` event. `OnInit` event @@ -199,6 +207,28 @@ func OnAfterReply(sef EventCallbackFunc) { log.Warn("'OnAfterReply' aah server extension point is already subscribed.") } +// OnPreAuth method is to subscribe to aah application `OnPreAuth` event. +// `OnPreAuth` event pubished right before the aah server is authenticates & +// authorizes an incoming request. +func OnPreAuth(sef EventCallbackFunc) { + if onPreAuthFunc == nil { + onPreAuthFunc = sef + return + } + log.Warn("'OnPreAuth' aah server extension point is already subscribed.") +} + +// OnPostAuth method is to subscribe to aah application `OnPreAuth` event. +// `OnPostAuth` event pubished right after the aah server is authenticates & +// authorizes an incoming request. +func OnPostAuth(sef EventCallbackFunc) { + if onPostAuthFunc == nil { + onPostAuthFunc = sef + return + } + log.Warn("'OnPostAuth' aah server extension point is already subscribed.") +} + //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // EventStore methods //___________________________________ @@ -341,6 +371,18 @@ func publishOnAfterReplyEvent(ctx *Context) { } } +func publishOnPreAuthEvent(ctx *Context) { + if onPreAuthFunc != nil { + onPreAuthFunc(&Event{Name: EventOnPreAuth, Data: ctx}) + } +} + +func publishOnPostAuthEvent(ctx *Context) { + if onPostAuthFunc != nil { + onPostAuthFunc(&Event{Name: EventOnPostAuth, Data: ctx}) + } +} + // funcEqual method to compare to function callback interface data. In effect // comparing the pointers of the indirect layer. Read more about the // representation of functions here: http://golang.org/s/go11func diff --git a/event_test.go b/event_test.go index f5b70fe7..f37e755b 100644 --- a/event_test.go +++ b/event_test.go @@ -213,6 +213,34 @@ func TestServerExtensionEvent(t *testing.T) { OnAfterReply(func(e *Event) { t.Log("OnAfterReply event func called 2") }) + + // OnPreAuth + assert.Nil(t, onPreAuthFunc) + publishOnPreAuthEvent(&Context{}) + OnPreAuth(func(e *Event) { + t.Log("OnPreAuth event func called") + }) + assert.NotNil(t, onPreAuthFunc) + + onPreAuthFunc(&Event{Name: EventOnPreAuth, Data: "Context Data OnPreAuth"}) + publishOnPreAuthEvent(&Context{}) + OnPreAuth(func(e *Event) { + t.Log("OnPreAuth event func called 2") + }) + + // OnPostAuth + assert.Nil(t, onPostAuthFunc) + publishOnPostAuthEvent(&Context{}) + OnPostAuth(func(e *Event) { + t.Log("OnPostAuth event func called") + }) + assert.NotNil(t, onPostAuthFunc) + + onPostAuthFunc(&Event{Name: EventOnPostAuth, Data: "Context Data OnPostAuth"}) + publishOnPostAuthEvent(&Context{}) + OnPostAuth(func(e *Event) { + t.Log("OnPostAuth event func called 2") + }) } func TestSubscribeAndUnsubscribeAndPublish(t *testing.T) { diff --git a/i18n.go b/i18n.go index 7a2e404f..c4a5488c 100644 --- a/i18n.go +++ b/i18n.go @@ -17,7 +17,7 @@ const keyLocale = "Locale" var appI18n *i18n.I18n //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Global methods +// Package methods //___________________________________ // AppDefaultI18nLang method returns aah application i18n default language if diff --git a/middleware.go b/middleware.go index 9cc3a4fa..8d3aae21 100644 --- a/middleware.go +++ b/middleware.go @@ -28,7 +28,7 @@ type ( ) //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Global methods +// Package methods //___________________________________ // Middlewares method adds given middleware into middleware stack @@ -64,7 +64,7 @@ func ToMiddleware(handler interface{}) MiddlewareFunc { case http.Handler: h := handler.(http.Handler) return func(ctx *Context, m *Middleware) { - h.ServeHTTP(ctx.Res, ctx.Req.Raw) + h.ServeHTTP(ctx.Res, ctx.Req.Unwrap()) m.Next(ctx) } case func(http.ResponseWriter, *http.Request): @@ -102,12 +102,8 @@ func interceptorMiddleware(ctx *Context, m *Middleware) { target := reflect.ValueOf(ctx.target) controller := resolveControllerName(ctx) - // Finally action and method + // Finally action and method. Always executed if present defer func() { - if ctx.abort { - return - } - if finallyActionMethod := target.MethodByName(incpFinallyActionName + ctx.action.Name); finallyActionMethod.IsValid() { log.Debugf("Calling interceptor: %s.%s", controller, incpFinallyActionName+ctx.action.Name) finallyActionMethod.Call(emptyArg) diff --git a/param.go b/param.go index 2a7947f3..135cb91d 100644 --- a/param.go +++ b/param.go @@ -14,19 +14,27 @@ import ( ) const ( - keyRequestParams = "RequestParams" - keyOverrideLocale = "lang" + // KeyViewArgRequestParams key name is used to store HTTP Request Params instance + // into `ViewArgs`. + KeyViewArgRequestParams = "_aahRequestParams" + + keyOverrideI18nName = "lang" +) + +var ( + keyQueryParamName = keyOverrideI18nName + keyPathParamName = keyOverrideI18nName ) //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Params method +// Params Unexported method //___________________________________ // parseRequestParams method parses the incoming HTTP request to collects request // parameters (Payload, Form, Query, Multi-part) stores into context. Request // params are made available in View via template functions. -func (e *engine) parseRequestParams(ctx *Context) { - req := ctx.Req.Raw +func (e *engine) parseRequestParams(ctx *Context) flowResult { + req := ctx.Req.Unwrap() if ctx.Req.Method != ahttp.MethodGet { contentType := ctx.Req.ContentType.Mime @@ -36,11 +44,14 @@ func (e *engine) parseRequestParams(ctx *Context) { // TODO HTML sanitizer for Form and Multipart Form switch contentType { - case ahttp.ContentTypeJSON.Mime, ahttp.ContentTypeXML.Mime, ahttp.ContentTypeXMLText.Mime: + case ahttp.ContentTypeJSON.Mime, ahttp.ContentTypeJSONText.Mime, + ahttp.ContentTypeXML.Mime, ahttp.ContentTypeXMLText.Mime: if payloadBytes, err := ioutil.ReadAll(req.Body); err == nil { ctx.Req.Payload = payloadBytes } else { log.Errorf("unable to read request body for '%s': %s", contentType, err) + writeErrorInfo(ctx, http.StatusBadRequest, "unable to read request body") + return flowStop } case ahttp.ContentTypeForm.Mime: if err := req.ParseForm(); err == nil { @@ -68,13 +79,19 @@ func (e *engine) parseRequestParams(ctx *Context) { }(req) } - // i18n option override by Query parameter `lang` - if lang := ctx.Req.QueryValue(keyOverrideLocale); !ess.IsStrEmpty(lang) { - ctx.Req.Locale = ahttp.NewLocale(lang) + // i18n locale HTTP header `Accept-Language` value override via + // Path Variable and URL Query Param (config i18n { param_name { ... } }). + // Note: Query parameter takes precedence of all. + // Default parameter name is `lang` + pathValue := ctx.Req.PathValue(keyPathParamName) + queryValue := ctx.Req.QueryValue(keyQueryParamName) + if locale := firstNonEmpty(queryValue, pathValue); !ess.IsStrEmpty(locale) { + ctx.Req.Locale = ahttp.NewLocale(locale) } // All the request parameters made available to templates via funcs. - ctx.AddViewArg(keyRequestParams, ctx.Req.Params) + ctx.AddViewArg(KeyViewArgRequestParams, ctx.Req.Params) + return flowCont } //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ @@ -83,18 +100,27 @@ func (e *engine) parseRequestParams(ctx *Context) { // tmplPathParam method returns Request Path Param value for the given key. func tmplPathParam(viewArgs map[string]interface{}, key string) interface{} { - params := viewArgs[keyRequestParams].(*ahttp.Params) + params := viewArgs[KeyViewArgRequestParams].(*ahttp.Params) return sanatizeValue(params.PathValue(key)) } // tmplFormParam method returns Request Form value for the given key. func tmplFormParam(viewArgs map[string]interface{}, key string) interface{} { - params := viewArgs[keyRequestParams].(*ahttp.Params) + params := viewArgs[KeyViewArgRequestParams].(*ahttp.Params) return sanatizeValue(params.FormValue(key)) } // tmplQueryParam method returns Request Query String value for the given key. func tmplQueryParam(viewArgs map[string]interface{}, key string) interface{} { - params := viewArgs[keyRequestParams].(*ahttp.Params) + params := viewArgs[KeyViewArgRequestParams].(*ahttp.Params) return sanatizeValue(params.QueryValue(key)) } + +func paramInitialize(e *Event) { + keyPathParamName = AppConfig().StringDefault("i18n.param_name.path", keyOverrideI18nName) + keyQueryParamName = AppConfig().StringDefault("i18n.param_name.query", keyOverrideI18nName) +} + +func init() { + OnStart(paramInitialize) +} diff --git a/param_test.go b/param_test.go index 018ae77c..9e40e1ee 100644 --- a/param_test.go +++ b/param_test.go @@ -12,6 +12,7 @@ import ( "testing" "aahframework.org/ahttp.v0" + "aahframework.org/config.v0" "aahframework.org/essentials.v0" "aahframework.org/test.v0/assert" ) @@ -32,7 +33,7 @@ func TestParamTemplateFuncs(t *testing.T) { aahReq1.Params.Path["userId"] = "100001" viewArgs := map[string]interface{}{} - viewArgs[keyRequestParams] = aahReq1.Params + viewArgs[KeyViewArgRequestParams] = aahReq1.Params v1 := tmplQueryParam(viewArgs, "_ref") assert.Equal(t, "true", v1) @@ -79,3 +80,34 @@ func TestParamParse(t *testing.T) { e.parseRequestParams(ctx2) assert.NotNil(t, ctx2.Req.Params.Form) } + +func TestParamParseLocaleFromAppConfiguration(t *testing.T) { + defer ess.DeleteFiles("testapp.pid") + + cfg, err := config.ParseString(` + i18n { + param_name { + query = "language" + } + } + `) + appConfig = cfg + paramInitialize(&Event{}) + + assert.Nil(t, err) + + r := httptest.NewRequest("GET", "http://localhost:8080/index.html?language=en-CA", nil) + ctx1 := &Context{ + Req: ahttp.ParseRequest(r, &ahttp.Request{}), + viewArgs: make(map[string]interface{}), + } + + e := &engine{} + + assert.Nil(t, ctx1.Req.Locale) + e.parseRequestParams(ctx1) + assert.NotNil(t, ctx1.Req.Locale) + assert.Equal(t, "en", ctx1.Req.Locale.Language) + assert.Equal(t, "CA", ctx1.Req.Locale.Region) + assert.Equal(t, "en-CA", ctx1.Req.Locale.String()) +} diff --git a/render.go b/render.go index c4335e7d..ffde4e7d 100644 --- a/render.go +++ b/render.go @@ -13,11 +13,20 @@ import ( "io" "os" "path/filepath" + "strings" + "sync" "aahframework.org/essentials.v0" "aahframework.org/log.v0" ) +var ( + xmlHeaderBytes = []byte(xml.Header) + rdrHTMLPool = &sync.Pool{New: func() interface{} { return &HTML{} }} + rdrJSONPool = &sync.Pool{New: func() interface{} { return &JSON{} }} + rdrXMLPool = &sync.Pool{New: func() interface{} { return &XML{} }} +) + type ( // Data type used for convenient data type of map[string]interface{} Data map[string]interface{} @@ -115,6 +124,12 @@ func (j *JSON) Render(w io.Writer) error { return nil } +func (j *JSON) reset() { + j.Callback = "" + j.IsJSONP = false + j.Data = nil +} + //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // XML Render methods //___________________________________ @@ -136,6 +151,10 @@ func (x *XML) Render(w io.Writer) error { return err } + if _, err = w.Write(xmlHeaderBytes); err != nil { + return err + } + if _, err = w.Write(bytes); err != nil { return err } @@ -143,6 +162,33 @@ func (x *XML) Render(w io.Writer) error { return nil } +func (x *XML) reset() { + x.Data = nil +} + +// MarshalXML method is to marshal `aah.Data` into XML. +func (d Data) MarshalXML(e *xml.Encoder, start xml.StartElement) error { + tokens := []xml.Token{start} + for k, v := range d { + token := xml.StartElement{Name: xml.Name{Local: strings.Title(k)}} + tokens = append(tokens, token, + xml.CharData(fmt.Sprintf("%v", v)), + xml.EndElement{Name: token.Name}) + } + + tokens = append(tokens, xml.EndElement{Name: start.Name}) + + var err error + for _, t := range tokens { + if err = e.EncodeToken(t); err != nil { + return err + } + } + + // flush to ensure tokens are written + return e.Flush() +} + //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // File and Reader Render methods //___________________________________ @@ -195,12 +241,19 @@ func (h *HTML) Render(w io.Writer) error { return h.Template.ExecuteTemplate(w, h.Layout, h.ViewArgs) } +func (h *HTML) reset() { + h.Template = nil + h.Filename = "" + h.Layout = "" + h.ViewArgs = make(Data) +} + // doRender method renders and detects the errors earlier. Writes the // error info if any. func (e *engine) doRender(ctx *Context) { if ctx.Reply().Rdr != nil { reply := ctx.Reply() - reply.body = e.getBuffer() + reply.body = acquireBuffer() if jsonp, ok := reply.Rdr.(*JSON); ok && jsonp.IsJSONP { if ess.IsStrEmpty(jsonp.Callback) { jsonp.Callback = ctx.Req.QueryValue("callback") @@ -209,9 +262,37 @@ func (e *engine) doRender(ctx *Context) { if err := reply.Rdr.Render(reply.body); err != nil { log.Error("Render response body error: ", err) + // TODO handle response based on content type reply.InternalServerError() reply.body.Reset() reply.body.WriteString("500 Internal Server Error\n") } } } + +func acquireHTML() *HTML { + return rdrHTMLPool.Get().(*HTML) +} + +func acquireJSON() *JSON { + return rdrJSONPool.Get().(*JSON) +} + +func acquireXML() *XML { + return rdrXMLPool.Get().(*XML) +} + +func releaseRender(r Render) { + if r != nil { + if t, ok := r.(*JSON); ok { + t.reset() + rdrJSONPool.Put(t) + } else if t, ok := r.(*HTML); ok { + t.reset() + rdrHTMLPool.Put(t) + } else if t, ok := r.(*XML); ok { + t.reset() + rdrXMLPool.Put(t) + } + } +} diff --git a/render_test.go b/render_test.go index 9c600018..21420e86 100644 --- a/render_test.go +++ b/render_test.go @@ -119,7 +119,8 @@ func TestRenderXML(t *testing.T) { xml1 := XML{Data: data} err := xml1.Render(buf) assert.FailOnError(t, err, "") - assert.Equal(t, ` + assert.Equal(t, ` + John 28
this is my street
@@ -131,7 +132,8 @@ func TestRenderXML(t *testing.T) { err = xml1.Render(buf) assert.FailOnError(t, err, "") - assert.Equal(t, `John28
this is my street
`, + assert.Equal(t, ` +John28
this is my street
`, buf.String()) } diff --git a/reply.go b/reply.go index 5e5a8572..064a45c9 100644 --- a/reply.go +++ b/reply.go @@ -9,11 +9,17 @@ import ( "io" "net/http" "strings" + "sync" "aahframework.org/ahttp.v0" "aahframework.org/essentials.v0" ) +var ( + bufPool = &sync.Pool{New: func() interface{} { return &bytes.Buffer{} }} + replyPool = &sync.Pool{New: func() interface{} { return NewReply() }} +) + // Reply gives you control and convenient way to write a response effectively. type Reply struct { Code int @@ -29,7 +35,7 @@ type Reply struct { } //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Global methods +// Package methods //___________________________________ // NewReply method returns the new instance on reply builder. @@ -155,7 +161,9 @@ func (r *Reply) ContentType(contentType string) *Reply { // Also it sets HTTP 'Content-Type' as 'application/json; charset=utf-8'. // Response rendered pretty if 'render.pretty' is true. func (r *Reply) JSON(data interface{}) *Reply { - r.Rdr = &JSON{Data: data} + j := acquireJSON() + j.Data = data + r.Rdr = j r.ContentType(ahttp.ContentTypeJSON.Raw()) return r } @@ -166,7 +174,11 @@ func (r *Reply) JSON(data interface{}) *Reply { // Note: If `callback` param is empty and `callback` query param is exists then // query param value will be used. func (r *Reply) JSONP(data interface{}, callback string) *Reply { - r.Rdr = &JSON{Data: data, IsJSONP: true, Callback: callback} + j := acquireJSON() + j.Data = data + j.IsJSONP = true + j.Callback = callback + r.Rdr = j r.ContentType(ahttp.ContentTypeJSON.Raw()) return r } @@ -175,7 +187,9 @@ func (r *Reply) JSONP(data interface{}, callback string) *Reply { // HTTP Content-Type as 'application/xml; charset=utf-8'. // Response rendered pretty if 'render.pretty' is true. func (r *Reply) XML(data interface{}) *Reply { - r.Rdr = &XML{Data: data} + x := acquireXML() + x.Data = data + r.Rdr = x r.ContentType(ahttp.ContentTypeXML.Raw()) return r } @@ -262,11 +276,11 @@ func (r *Reply) HTMLf(filename string, data Data) *Reply { // HTMLlf method renders based on given layout, filename and data. Refer `Reply.HTML(...)` // method. func (r *Reply) HTMLlf(layout, filename string, data Data) *Reply { - r.Rdr = &HTML{ - Layout: layout, - Filename: filename, - ViewArgs: data, - } + html := acquireHTML() + html.Layout = layout + html.Filename = filename + html.ViewArgs = data + r.Rdr = html r.ContentType(ahttp.ContentTypeHTML.String()) return r } @@ -361,7 +375,7 @@ func (r *Reply) Body() *bytes.Buffer { return r.body } -// Reset method resets the values into initialized state. +// Reset method resets the instance values for repurpose. func (r *Reply) Reset() { r.Code = http.StatusOK r.ContType = "" @@ -374,3 +388,31 @@ func (r *Reply) Reset() { r.done = false r.gzip = true } + +//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ +// Unexported methods +//___________________________________ + +func acquireReply() *Reply { + return replyPool.Get().(*Reply) +} + +func releaseReply(r *Reply) { + if r != nil { + releaseBuffer(r.body) + releaseRender(r.Rdr) + r.Reset() + replyPool.Put(r) + } +} + +func acquireBuffer() *bytes.Buffer { + return bufPool.Get().(*bytes.Buffer) +} + +func releaseBuffer(b *bytes.Buffer) { + if b != nil { + b.Reset() + bufPool.Put(b) + } +} diff --git a/reply_test.go b/reply_test.go index 4bf2e6a0..4c10c159 100644 --- a/reply_test.go +++ b/reply_test.go @@ -5,7 +5,6 @@ package aah import ( - "bytes" "html/template" "net/http" "os" @@ -71,7 +70,7 @@ func TestReplyStatusCodes(t *testing.T) { } func TestReplyText(t *testing.T) { - buf, re1 := getBufferAndReply() + buf, re1 := acquireBuffer(), acquireReply() re1.Text("welcome to %s %s", "aah", "framework") assert.True(t, re1.IsContentTypeSet()) @@ -91,7 +90,7 @@ func TestReplyText(t *testing.T) { } func TestReplyJSON(t *testing.T) { - buf, re1 := getBufferAndReply() + buf, re1 := acquireBuffer(), acquireReply() appConfig = getReplyRenderCfg() data := struct { @@ -132,7 +131,7 @@ func TestReplyJSON(t *testing.T) { } func TestReplyJSONP(t *testing.T) { - buf, re1 := getBufferAndReply() + buf, re1 := acquireBuffer(), acquireReply() re1.body = buf appConfig = getReplyRenderCfg() @@ -172,7 +171,7 @@ func TestReplyJSONP(t *testing.T) { } func TestReplyXML(t *testing.T) { - buf, re1 := getBufferAndReply() + buf, re1 := acquireBuffer(), acquireReply() appConfig = getReplyRenderCfg() type Sample struct { @@ -192,7 +191,8 @@ func TestReplyXML(t *testing.T) { err := re1.Rdr.Render(buf) assert.FailOnError(t, err, "") - assert.Equal(t, ` + assert.Equal(t, ` + John 28
this is my street
@@ -204,12 +204,30 @@ func TestReplyXML(t *testing.T) { err = re1.Rdr.Render(buf) assert.FailOnError(t, err, "") - assert.Equal(t, `John28
this is my street
`, + assert.Equal(t, ` +John28
this is my street
`, buf.String()) + + buf.Reset() + + data2 := Data{ + "Name": "John", + "Age": 28, + "Address": "this is my street", + } + + re1.Rdr.(*XML).reset() + re1.XML(data2) + assert.True(t, re1.IsContentTypeSet()) + + err = re1.Rdr.Render(buf) + assert.FailOnError(t, err, "") + assert.True(t, strings.HasPrefix(buf.String(), ``)) + re1.Rdr.(*XML).reset() } func TestReplyReadfrom(t *testing.T) { - buf, re1 := getBufferAndReply() + buf, re1 := acquireBuffer(), acquireReply() re1.ContentType(ahttp.ContentTypeOctetStream.Raw()). Binary([]byte(`John28
this is my street
`)) @@ -225,7 +243,7 @@ func TestReplyReadfrom(t *testing.T) { } func TestReplyFileDownload(t *testing.T) { - buf, re1 := getBufferAndReply() + buf, re1 := acquireBuffer(), acquireReply() re1.FileDownload(getReplyFilepath("file1.txt"), "sample.txt") assert.Equal(t, http.StatusOK, re1.Code) @@ -253,7 +271,7 @@ func TestReplyHTML(t *testing.T) { {{ define "body" }}

This is test body

{{ end }} ` - buf, re1 := getBufferAndReply() + buf, re1 := acquireBuffer(), acquireReply() tmpl := template.Must(template.New("test").Parse(tmplStr)) assert.NotNil(t, tmpl) @@ -285,42 +303,47 @@ func TestReplyHTML(t *testing.T) { err = re1.Rdr.Render(buf) assert.NotNil(t, err) assert.Equal(t, "template is nil", err.Error()) + releaseReply(re1) // HTMLlf - relf := NewReply() + relf := acquireReply() relf.HTMLlf("docs.html", "Filename.html", nil) assert.Equal(t, "text/html; charset=utf-8", relf.ContType) htmllf := relf.Rdr.(*HTML) assert.Equal(t, "docs.html", htmllf.Layout) assert.Equal(t, "Filename.html", htmllf.Filename) + releaseRender(htmllf) // HTMLf - ref := NewReply() + ref := acquireReply() ref.HTMLf("Filename1.html", nil) assert.Equal(t, "text/html; charset=utf-8", ref.ContType) htmlf := ref.Rdr.(*HTML) assert.True(t, ess.IsStrEmpty(htmlf.Layout)) assert.Equal(t, "Filename1.html", htmlf.Filename) + releaseRender(htmlf) } func TestReplyRedirect(t *testing.T) { - redirect1 := NewReply() + redirect1 := acquireReply() redirect1.Redirect("/go-to-see.page") assert.Equal(t, http.StatusFound, redirect1.Code) assert.True(t, redirect1.redirect) assert.Equal(t, "/go-to-see.page", redirect1.path) + releaseReply(redirect1) - redirect2 := NewReply() + redirect2 := acquireReply() redirect2.RedirectSts("/go-to-see-gone-premanent.page", http.StatusMovedPermanently) assert.Equal(t, http.StatusMovedPermanently, redirect2.Code) assert.True(t, redirect2.redirect) assert.Equal(t, "/go-to-see-gone-premanent.page", redirect2.path) + releaseReply(redirect2) } func TestReplyDone(t *testing.T) { - re1 := NewReply() + re1 := acquireReply() assert.False(t, re1.done) re1.Done() @@ -328,9 +351,8 @@ func TestReplyDone(t *testing.T) { } func TestReplyCookie(t *testing.T) { - re1 := NewReply() + re1 := acquireReply() - assert.Nil(t, re1.cookies) re1.Cookie(&http.Cookie{ Name: "aah-test-cookie", Value: "This is reply cookie interface test value", @@ -343,6 +365,7 @@ func TestReplyCookie(t *testing.T) { cookie := re1.cookies[0] assert.Equal(t, "aah-test-cookie", cookie.Name) + releaseReply(re1) } func getReplyRenderCfg() *config.Config { @@ -354,10 +377,6 @@ func getReplyRenderCfg() *config.Config { return cfg } -func getBufferAndReply() (*bytes.Buffer, *Reply) { - return &bytes.Buffer{}, NewReply() -} - func getReplyFilepath(name string) string { wd, _ := os.Getwd() return filepath.Join(wd, "testdata", "reply", name) diff --git a/router.go b/router.go index 78145858..280ed188 100644 --- a/router.go +++ b/router.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "html/template" + "net/http" "path/filepath" "reflect" "strings" @@ -22,7 +23,7 @@ import ( var appRouter *router.Router //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Global methods +// Package methods //___________________________________ // AppRouter method returns aah application router instance. @@ -133,12 +134,12 @@ func handleRtsOptionsMna(ctx *Context, domain *router.Domain, rts bool) error { } if len(reqPath) > 1 && reqPath[len(reqPath)-1] == '/' { - ctx.Req.Raw.URL.Path = reqPath[:len(reqPath)-1] + ctx.Req.Unwrap().URL.Path = reqPath[:len(reqPath)-1] } else { - ctx.Req.Raw.URL.Path = reqPath + "/" + ctx.Req.Unwrap().URL.Path = reqPath + "/" } - reply.Redirect(ctx.Req.Raw.URL.String()) + reply.Redirect(ctx.Req.Unwrap().URL.String()) log.Debugf("RedirectTrailingSlash: %d, %s ==> %s", reply.Code, reqPath, reply.path) return nil } @@ -147,43 +148,43 @@ func handleRtsOptionsMna(ctx *Context, domain *router.Domain, rts bool) error { // HTTP: OPTIONS if reqMethod == ahttp.MethodOptions { if domain.AutoOptions { - if allowed := domain.Allowed(reqMethod, reqPath); !ess.IsStrEmpty(allowed) { - allowed += ", " + ahttp.MethodOptions - log.Debugf("Auto 'OPTIONS' allowed HTTP Methods: %s", allowed) - reply.Header(ahttp.HeaderAllow, allowed) - return nil - } + processAllowedMethods(reply, domain.Allowed(reqMethod, reqPath), "Auto 'OPTIONS', ") + writeErrorInfo(ctx, http.StatusOK, "'OPTIONS' allowed HTTP Methods") + return nil } } // 405 Method Not Allowed if domain.MethodNotAllowed { - if allowed := domain.Allowed(reqMethod, reqPath); !ess.IsStrEmpty(allowed) { - allowed += ", " + ahttp.MethodOptions - log.Debugf("Allowed HTTP Methods for 405 response: %s", allowed) - reply.MethodNotAllowed(). - Header(ahttp.HeaderAllow, allowed). - Text("405 Method Not Allowed") - return nil - } + processAllowedMethods(reply, domain.Allowed(reqMethod, reqPath), "405 response, ") + writeErrorInfo(ctx, http.StatusMethodNotAllowed, "Method Not Allowed") + return nil } return errors.New("route not found") } +func processAllowedMethods(reply *Reply, allowed, prefix string) { + if !ess.IsStrEmpty(allowed) { + allowed += ", " + ahttp.MethodOptions + reply.Header(ahttp.HeaderAllow, allowed) + log.Debugf("%sAllowed HTTP Methods: %s", prefix, allowed) + } +} + // handleRouteNotFound method is used for 1. route action not found, 2. route is // not found and 3. static file/directory. func handleRouteNotFound(ctx *Context, domain *router.Domain, route *router.Route) { // handle effectively to reduce heap allocation if domain.NotFoundRoute == nil { log.Warnf("Route not found: %s, isStaticRoute: false", ctx.Req.Path) - ctx.Reply().NotFound().Text("404 Not Found") + writeErrorInfo(ctx, http.StatusNotFound, "Not Found") return } log.Warnf("Route not found: %s, isStaticRoute: %v", ctx.Req.Path, route.IsStatic) if err := ctx.setTarget(route); err == errTargetNotFound { - ctx.Reply().NotFound().Text("404 Not Found") + writeErrorInfo(ctx, http.StatusNotFound, "Not Found") return } @@ -205,15 +206,11 @@ func tmplURL(viewArgs map[string]interface{}, args ...interface{}) template.URL log.Errorf("router: template 'rurl' - route name is empty: %v", args) return template.URL("#") } - - host := viewArgs["Host"].(string) - routeName := args[0].(string) - return template.URL(createReverseURL(host, routeName, nil, args[1:]...)) + return template.URL(createReverseURL(viewArgs["Host"].(string), args[0].(string), nil, args[1:]...)) } // tmplURLm method returns reverse URL by given route name and // map[string]interface{}. Mapped to Go template func. func tmplURLm(viewArgs map[string]interface{}, routeName string, args map[string]interface{}) template.URL { - host := viewArgs["Host"].(string) - return template.URL(createReverseURL(host, routeName, args)) + return template.URL(createReverseURL(viewArgs["Host"].(string), routeName, args)) } diff --git a/security.go b/security.go index 3b018ecd..4491dd70 100644 --- a/security.go +++ b/security.go @@ -6,31 +6,42 @@ package aah import ( "fmt" - "path/filepath" + "net/http" + "aahframework.org/ahttp.v0" "aahframework.org/config.v0" + "aahframework.org/essentials.v0" + "aahframework.org/log.v0" "aahframework.org/security.v0" + "aahframework.org/security.v0/authc" + "aahframework.org/security.v0/scheme" "aahframework.org/security.v0/session" ) -const keySessionValues = "SessionValues" +const ( + // KeyViewArgAuthcInfo key name is used to store `AuthenticationInfo` instance into `ViewArgs`. + KeyViewArgAuthcInfo = "_aahAuthcInfo" -var appSecurity *security.Security + // KeyViewArgSubject key name is used to store `Subject` instance into `ViewArgs`. + KeyViewArgSubject = "_aahSubject" +) + +var appSecurityManager = security.New() //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Global methods +// Package methods //___________________________________ -// AppSecurity method returns the application security instance, +// AppSecurityManager method returns the application security instance, // which manages the Session, CORS, CSRF, Security Headers, etc. -func AppSecurity() *security.Security { - return appSecurity +func AppSecurityManager() *security.Manager { + return appSecurityManager } // AppSessionManager method returns the application session manager. // By default session is stateless. func AppSessionManager() *session.Manager { - return AppSecurity().SessionManager + return AppSecurityManager().SessionManager } // AddSessionStore method allows you to add custom session store which @@ -40,15 +51,149 @@ func AddSessionStore(name string, store session.Storer) error { return session.AddStore(name, store) } +//‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ +// Authentication and Authorization methods +//__________________________________________ + +func (e engine) handleAuthcAndAuthz(ctx *Context) flowResult { + // If route auth is `anonymous` then continue the request flow + // No authentication or authorization is required for that route. + if ctx.route.Auth == "anonymous" { + log.Debugf("Route auth is anonymous: %v", ctx.Req.Path) + return flowCont + } + + authScheme := AppSecurityManager().GetAuthScheme(ctx.route.Auth) + if authScheme == nil { + // If one or more auth schemes are defined in `security.auth_schemes { ... }` + // and routes `auth` attribute is not defined then framework treats that route as `403 Forbidden`. + if AppSecurityManager().IsAuthSchemesConfigured() { + log.Warnf("Auth schemes are configured in security.conf, however attribute 'auth' or 'default_auth' is not defined in routes.conf, so treat it as 403 forbidden: %v", ctx.Req.Path) + writeErrorInfo(ctx, http.StatusForbidden, "Forbidden") + return flowStop + } + + // If auth scheme is not configured in security.conf then treat it as `anonymous`. + log.Tracef("Route auth scheme is not configured, so treat it as anonymous: %v", ctx.Req.Path) + return flowCont + } + + log.Debugf("Route auth scheme: %s", authScheme.Scheme()) + switch authScheme.Scheme() { + case "form": + return e.doFormAuthcAndAuthz(authScheme, ctx) + default: + return e.doAuthcAndAuthz(authScheme, ctx) + } +} + +// doFormAuthcAndAuthz method does Form Authentication and Authorization. +func (e *engine) doFormAuthcAndAuthz(ascheme scheme.Schemer, ctx *Context) flowResult { + formAuth := ascheme.(*scheme.FormAuth) + + // In Form authentication check session is already authentication if yes + // then continue the request flow immediately. + if ctx.Subject().IsAuthenticated() { + if ctx.Session().IsKeyExists(KeyViewArgAuthcInfo) { + ctx.Subject().AuthenticationInfo = ctx.Session().Get(KeyViewArgAuthcInfo).(*authc.AuthenticationInfo) + ctx.Subject().AuthorizationInfo = formAuth.DoAuthorizationInfo(ctx.Subject().AuthenticationInfo) + } else { + log.Warn("It seems there is an issue with session data of AuthenticationInfo") + } + + return flowCont + } + + // Check route is login submit URL otherwise send it login URL. + // Since session is not authenticated. + if formAuth.LoginSubmitURL != ctx.route.Path && ctx.Req.Method != ahttp.MethodPost { + loginURL := formAuth.LoginURL + if formAuth.LoginURL != ctx.Req.Path { + loginURL = fmt.Sprintf("%s?_rt=%s", loginURL, ctx.Req.Unwrap().RequestURI) + } + ctx.Reply().Redirect(loginURL) + return flowStop + } + + publishOnPreAuthEvent(ctx) + + // Do Authentication + authcInfo, err := formAuth.DoAuthenticate(formAuth.ExtractAuthenticationToken(ctx.Req)) + if err != nil || authcInfo == nil { + log.Info("Authentication is failed, sending to login failure URL") + + redirectURL := formAuth.LoginFailureURL + redirectTarget := ctx.Req.Unwrap().FormValue("_rt") + if !ess.IsStrEmpty(redirectTarget) { + redirectURL = redirectURL + "&_rt=" + redirectTarget + } + + ctx.Reply().Redirect(redirectURL) + return flowStop + } + + log.Info("Authentication successful") + + ctx.Subject().AuthenticationInfo = authcInfo + ctx.Subject().AuthorizationInfo = formAuth.DoAuthorizationInfo(authcInfo) + ctx.Session().IsAuthenticated = true + + // Remove the credential + ctx.Subject().AuthenticationInfo.Credential = nil + ctx.Session().Set(KeyViewArgAuthcInfo, ctx.Subject().AuthenticationInfo) + + publishOnPostAuthEvent(ctx) + + rt := ctx.Req.Unwrap().FormValue("_rt") + if formAuth.IsAlwaysToDefaultTarget || ess.IsStrEmpty(rt) { + ctx.Reply().Redirect(formAuth.DefaultTargetURL) + } else { + log.Debugf("Redirect to URL found: %v", rt) + ctx.Reply().Redirect(rt) + } + + return flowStop +} + +// doAuthcAndAuthz method does Authentication and Authorization. +func (e *engine) doAuthcAndAuthz(ascheme scheme.Schemer, ctx *Context) flowResult { + publishOnPreAuthEvent(ctx) + + // Do Authentication + authcInfo, err := ascheme.DoAuthenticate(ascheme.ExtractAuthenticationToken(ctx.Req)) + if err != nil || authcInfo == nil { + log.Info("Authentication is failed") + + if ascheme.Scheme() == "basic" { + basicAuth := ascheme.(*scheme.BasicAuth) + ctx.Reply().Header(ahttp.HeaderWWWAuthenticate, `Basic realm="`+basicAuth.RealmName+`"`) + } + + writeErrorInfo(ctx, http.StatusUnauthorized, "Unauthorized") + return flowStop + } + + log.Info("Authentication successful") + + ctx.Subject().AuthenticationInfo = authcInfo + ctx.Subject().AuthorizationInfo = ascheme.DoAuthorizationInfo(authcInfo) + ctx.Session().IsAuthenticated = true + + // Remove the credential + ctx.Subject().AuthenticationInfo.Credential = nil + + publishOnPostAuthEvent(ctx) + + return flowCont +} + //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // Unexported methods //___________________________________ -func initSecurity(cfgDir string, appCfg *config.Config) error { - var err error - configPath := filepath.Join(cfgDir, "security.conf") - if appSecurity, err = security.New(configPath, appCfg); err != nil { - return fmt.Errorf("security init: %s", err) +func initSecurity(appCfg *config.Config) error { + if err := appSecurityManager.Init(appCfg); err != nil { + return err } // Based on aah server SSL configuration `http.Cookie.Secure` value is set, even @@ -60,6 +205,14 @@ func initSecurity(cfgDir string, appCfg *config.Config) error { return nil } +func isFormAuthLoginRoute(ctx *Context) bool { + authScheme := AppSecurityManager().GetAuthScheme(ctx.route.Auth) + if authScheme != nil && authScheme.Scheme() == "form" { + return authScheme.(*scheme.FormAuth).LoginSubmitURL == ctx.route.Path + } + return false +} + //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ // Template methods //___________________________________ @@ -67,9 +220,11 @@ func initSecurity(cfgDir string, appCfg *config.Config) error { // tmplSessionValue method returns session value for the given key. If session // object unavailable this method returns nil. func tmplSessionValue(viewArgs map[string]interface{}, key string) interface{} { - if sv, found := viewArgs[keySessionValues]; found { - value := sv.(*session.Session).Get(key) - return sanatizeValue(value) + if sub := getSubjectFromViewArgs(viewArgs); sub != nil { + if sub.Session != nil { + value := sub.Session.Get(key) + return sanatizeValue(value) + } } return nil } @@ -77,17 +232,67 @@ func tmplSessionValue(viewArgs map[string]interface{}, key string) interface{} { // tmplFlashValue method returns session value for the given key. If session // object unavailable this method returns nil. func tmplFlashValue(viewArgs map[string]interface{}, key string) interface{} { - if sv, found := viewArgs[keySessionValues]; found { - value := sv.(*session.Session).GetFlash(key) - return sanatizeValue(value) + if sub := getSubjectFromViewArgs(viewArgs); sub != nil { + if sub.Session != nil { + return sanatizeValue(sub.Session.GetFlash(key)) + } } return nil } // tmplIsAuthenticated method returns the value of `Session.IsAuthenticated`. -func tmplIsAuthenticated(viewArgs map[string]interface{}) interface{} { - if sv, found := viewArgs[keySessionValues]; found { - return sv.(*session.Session).IsAuthenticated +func tmplIsAuthenticated(viewArgs map[string]interface{}) bool { + if sub := getSubjectFromViewArgs(viewArgs); sub != nil { + if sub.Session != nil { + return sub.Session.IsAuthenticated + } + } + return false +} + +// tmplHasRole method returns the value of `Subject.HasRole`. +func tmplHasRole(viewArgs map[string]interface{}, role string) bool { + if sub := getSubjectFromViewArgs(viewArgs); sub != nil { + return sub.HasRole(role) + } + return false +} + +// tmplHasAllRoles method returns the value of `Subject.HasAllRoles`. +func tmplHasAllRoles(viewArgs map[string]interface{}, roles ...string) bool { + if sub := getSubjectFromViewArgs(viewArgs); sub != nil { + return sub.HasAllRoles(roles...) + } + return false +} + +// tmplHasAnyRole method returns the value of `Subject.HasAnyRole`. +func tmplHasAnyRole(viewArgs map[string]interface{}, roles ...string) bool { + if sub := getSubjectFromViewArgs(viewArgs); sub != nil { + return sub.HasAnyRole(roles...) + } + return false +} + +// tmplIsPermitted method returns the value of `Subject.IsPermitted`. +func tmplIsPermitted(viewArgs map[string]interface{}, permission string) bool { + if sub := getSubjectFromViewArgs(viewArgs); sub != nil { + return sub.IsPermitted(permission) } return false } + +// tmplIsPermittedAll method returns the value of `Subject.IsPermittedAll`. +func tmplIsPermittedAll(viewArgs map[string]interface{}, permissions ...string) bool { + if sub := getSubjectFromViewArgs(viewArgs); sub != nil { + return sub.IsPermittedAll(permissions...) + } + return false +} + +func getSubjectFromViewArgs(viewArgs map[string]interface{}) *security.Subject { + if sv, found := viewArgs[KeyViewArgSubject]; found { + return sv.(*security.Subject) + } + return nil +} diff --git a/security_test.go b/security_test.go index 73629e6a..f659025c 100644 --- a/security_test.go +++ b/security_test.go @@ -5,8 +5,17 @@ package aah import ( + "net/http/httptest" + "strings" "testing" + "aahframework.org/ahttp.v0" + "aahframework.org/config.v0" + "aahframework.org/router.v0" + "aahframework.org/security.v0" + "aahframework.org/security.v0/authc" + "aahframework.org/security.v0/authz" + "aahframework.org/security.v0/scheme" "aahframework.org/security.v0/session" "aahframework.org/test.v0/assert" ) @@ -24,7 +33,7 @@ func TestSecuritySessionStore(t *testing.T) { func TestSecuritySessionTemplateFuns(t *testing.T) { viewArgs := make(map[string]interface{}) - assert.Nil(t, viewArgs[keySessionValues]) + assert.Nil(t, viewArgs[KeyViewArgSubject]) bv1 := tmplSessionValue(viewArgs, "my-testvalue") assert.Nil(t, bv1) @@ -36,8 +45,18 @@ func TestSecuritySessionTemplateFuns(t *testing.T) { session.Set("my-testvalue", 38458473684763) session.SetFlash("my-flashvalue", "user not found") - viewArgs[keySessionValues] = session - assert.NotNil(t, viewArgs[keySessionValues]) + assert.False(t, tmplHasRole(viewArgs, "role1")) + assert.False(t, tmplHasAllRoles(viewArgs, "role1", "role2", "role3")) + assert.False(t, tmplHasAnyRole(viewArgs, "role1", "role2", "role3")) + assert.False(t, tmplIsPermitted(viewArgs, "*")) + assert.False(t, tmplIsPermittedAll(viewArgs, "news:read,write", "manage:*")) + + viewArgs[KeyViewArgSubject] = &security.Subject{ + Session: session, + AuthenticationInfo: authc.NewAuthenticationInfo(), + AuthorizationInfo: authz.NewAuthorizationInfo(), + } + assert.NotNil(t, viewArgs[KeyViewArgSubject]) v1 := tmplSessionValue(viewArgs, "my-testvalue") assert.Equal(t, 38458473684763, v1) @@ -47,4 +66,179 @@ func TestSecuritySessionTemplateFuns(t *testing.T) { v3 := tmplIsAuthenticated(viewArgs) assert.False(t, v3) + + assert.False(t, tmplHasRole(viewArgs, "role1")) + assert.False(t, tmplHasAllRoles(viewArgs, "role1", "role2", "role3")) + assert.False(t, tmplHasAnyRole(viewArgs, "role1", "role2", "role3")) + assert.False(t, tmplIsPermitted(viewArgs, "*")) + assert.False(t, tmplIsPermittedAll(viewArgs, "news:read,write", "manage:*")) + + delete(viewArgs, KeyViewArgSubject) + v4 := tmplIsAuthenticated(viewArgs) + assert.False(t, v4) +} + +type testFormAuthentication struct { +} + +var ( + _ authc.Authenticator = (*testFormAuthentication)(nil) + _ authz.Authorizer = (*testFormAuthentication)(nil) +) + +func (tfa *testFormAuthentication) Init(cfg *config.Config) error { + return nil +} + +func (tfa *testFormAuthentication) GetAuthenticationInfo(authcToken *authc.AuthenticationToken) (*authc.AuthenticationInfo, error) { + return testGetAuthenticationInfo(), nil +} + +func (tfa *testFormAuthentication) GetAuthorizationInfo(authcInfo *authc.AuthenticationInfo) *authz.AuthorizationInfo { + return nil +} + +func testGetAuthenticationInfo() *authc.AuthenticationInfo { + authcInfo := authc.NewAuthenticationInfo() + authcInfo.Principals = append(authcInfo.Principals, &authc.Principal{Realm: "database", Value: "jeeva", IsPrimary: true}) + authcInfo.Credential = []byte("$2y$10$2A4GsJ6SmLAMvDe8XmTam.MSkKojdobBVJfIU7GiyoM.lWt.XV3H6") // welcome123 + return authcInfo +} + +func TestSecurityHandleFormAuthcAndAuthz(t *testing.T) { + e := engine{} + + // anonymous + r1 := httptest.NewRequest("GET", "http://localhost:8080/doc/v0.3/mydoc.html", nil) + ctx1 := &Context{ + Req: ahttp.ParseRequest(r1, &ahttp.Request{}), + route: &router.Route{Auth: "anonymous"}, + } + result1 := e.handleAuthcAndAuthz(ctx1) + assert.True(t, result1 == flowCont) + + // form auth scheme + cfg, _ := config.ParseString(` + security { + auth_schemes { + # HTTP Form Auth Scheme + form_auth { + scheme = "form" + + # Authenticator is used to validate the subject (aka User) + authenticator = "security/Authentication" + + # Authorizer is used to get Subject authorization information, + # such as Roles and Permissions + authorizer = "security/Authorization" + } + } + } + `) + err := initSecurity(cfg) + assert.Nil(t, err) + r2 := httptest.NewRequest("GET", "http://localhost:8080/doc/v0.3/mydoc.html", nil) + w2 := httptest.NewRecorder() + ctx2 := &Context{ + Req: ahttp.ParseRequest(r2, &ahttp.Request{}), + Res: ahttp.GetResponseWriter(w2), + route: &router.Route{Auth: "form_auth"}, + subject: &security.Subject{}, + reply: NewReply(), + } + result2 := e.handleAuthcAndAuthz(ctx2) + assert.True(t, result2 == flowStop) + + // session is authenticated + ctx2.Session().IsAuthenticated = true + result3 := e.handleAuthcAndAuthz(ctx2) + assert.True(t, result3 == flowCont) + + // form auth + testFormAuth := &testFormAuthentication{} + formAuth := AppSecurityManager().GetAuthScheme("form_auth").(*scheme.FormAuth) + err = formAuth.SetAuthenticator(testFormAuth) + assert.Nil(t, err) + err = formAuth.SetAuthorizer(testFormAuth) + assert.Nil(t, err) + r3 := httptest.NewRequest("POST", "http://localhost:8080/login", nil) + ctx2.Req = ahttp.ParseRequest(r3, &ahttp.Request{}) + ctx2.Session().Set(KeyViewArgAuthcInfo, testGetAuthenticationInfo()) + result4 := e.handleAuthcAndAuthz(ctx2) + assert.True(t, result4 == flowCont) + + // form auth not authenticated and no credentials + ctx2.Session().IsAuthenticated = false + delete(ctx2.Session().Values, KeyViewArgAuthcInfo) + result5 := e.handleAuthcAndAuthz(ctx2) + assert.True(t, result5 == flowStop) + + // form auth not authenticated and with credentials + r4 := httptest.NewRequest("POST", "http://localhost:8080/login", strings.NewReader("username=jeeva&password=welcome123")) + r4.Header.Set(ahttp.HeaderContentType, "application/x-www-form-urlencoded") + ctx2.Req = ahttp.ParseRequest(r4, &ahttp.Request{}) + result6 := e.handleAuthcAndAuthz(ctx2) + assert.True(t, result6 == flowStop) +} + +type testBasicAuthentication struct { +} + +var ( + _ authc.Authenticator = (*testBasicAuthentication)(nil) + _ authz.Authorizer = (*testBasicAuthentication)(nil) +) + +func (tba *testBasicAuthentication) Init(cfg *config.Config) error { + return nil +} + +func (tba *testBasicAuthentication) GetAuthenticationInfo(authcToken *authc.AuthenticationToken) (*authc.AuthenticationInfo, error) { + return testGetAuthenticationInfo(), nil +} + +func (tba *testBasicAuthentication) GetAuthorizationInfo(authcInfo *authc.AuthenticationInfo) *authz.AuthorizationInfo { + return nil +} + +func TestSecurityHandleBasicAuthcAndAuthz(t *testing.T) { + e := engine{} + + // basic auth scheme + cfg, _ := config.ParseString(` + security { + auth_schemes { + # HTTP Basic Auth Scheme + basic_auth { + scheme = "basic" + authenticator = "security/Authentication" + authorizer = "security/Authorization" + } + } + } + `) + err := initSecurity(cfg) + assert.Nil(t, err) + r1 := httptest.NewRequest("GET", "http://localhost:8080/doc/v0.3/mydoc.html", nil) + w1 := httptest.NewRecorder() + ctx1 := &Context{ + Req: ahttp.ParseRequest(r1, &ahttp.Request{}), + Res: ahttp.GetResponseWriter(w1), + route: &router.Route{Auth: "basic_auth"}, + subject: &security.Subject{}, + reply: NewReply(), + } + result1 := e.handleAuthcAndAuthz(ctx1) + assert.True(t, result1 == flowStop) + + testBasicAuth := &testBasicAuthentication{} + basicAuth := AppSecurityManager().GetAuthScheme("basic_auth").(*scheme.BasicAuth) + err = basicAuth.SetAuthenticator(testBasicAuth) + assert.Nil(t, err) + err = basicAuth.SetAuthorizer(testBasicAuth) + assert.Nil(t, err) + r2 := httptest.NewRequest("GET", "http://localhost:8080/doc/v0.3/mydoc.html", nil) + ctx1.Req = ahttp.ParseRequest(r2, &ahttp.Request{}) + result2 := e.handleAuthcAndAuthz(ctx1) + assert.True(t, result2 == flowStop) } diff --git a/server.go b/server.go index 5e153354..fc18ab21 100644 --- a/server.go +++ b/server.go @@ -31,7 +31,7 @@ var ( ) //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Global methods +// Package methods //___________________________________ // AddServerTLSConfig method can be used for custom TLS config for aah server. @@ -62,12 +62,15 @@ func Start() { log.Infof("App Profile: %v", AppProfile()) log.Infof("App TLS/SSL Enabled: %v", AppIsSSLEnabled()) log.Infof("App Session Mode: %v", sessionMode) - log.Debugf("App i18n Locales: %v", strings.Join(AppI18n().Locales(), ", ")) - log.Debugf("App Route Domains: %v", strings.Join(AppRouter().DomainAddresses(), ", ")) - for event := range AppEventStore().subscribers { - for _, c := range AppEventStore().subscribers[event] { - log.Debugf("Callback: %s, subscribed to event: %s", funcName(c.Callback), event) + if log.IsLevelDebug() { + log.Debugf("App i18n Locales: %v", strings.Join(AppI18n().Locales(), ", ")) + log.Debugf("App Route Domains: %v", strings.Join(AppRouter().DomainAddresses(), ", ")) + + for event := range AppEventStore().subscribers { + for _, c := range AppEventStore().subscribers[event] { + log.Debugf("Callback: %s, subscribed to event: %s", funcName(c.Callback), event) + } } } @@ -121,6 +124,9 @@ func Shutdown() { graceTimeout, _ := time.ParseDuration(graceTime) ctx, cancel := context.WithTimeout(context.Background(), graceTimeout) + defer cancel() + + log.Trace("aah go server shutdown with timeout: ", graceTime) if err := aahServer.Shutdown(ctx); err != nil && err != http.ErrServerClosed { log.Error(err) } @@ -128,10 +134,9 @@ func Shutdown() { // Publish `OnShutdown` event AppEventStore().sortAndPublishSync(&Event{Name: EventOnShutdown}) - // Exit normally - cancel() log.Infof("'%v' application stopped", AppName()) - os.Exit(0) + + // bye bye } //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ @@ -182,14 +187,14 @@ func startHTTPS() { aahServer.TLSConfig.NextProtos = append(aahServer.TLSConfig.NextProtos, "h2") } - log.Infof("aah go server running on %v", aahServer.Addr) + printStartupNote() if err := aahServer.ListenAndServeTLS(appSSLCert, appSSLKey); err != nil && err != http.ErrServerClosed { log.Error(err) } } func startHTTP() { - log.Infof("aah go server running on %v", aahServer.Addr) + printStartupNote() if err := aahServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Error(err) } @@ -197,17 +202,17 @@ func startHTTP() { // listenSignals method listens to OS signals for aah server Shutdown. func listenSignals() { - sc := make(chan os.Signal, 2) + sc := make(chan os.Signal, 1) signal.Notify(sc, os.Interrupt, syscall.SIGTERM) - go func() { - switch <-sc { - case os.Interrupt: - log.Warn("Interrupt signal received") - case syscall.SIGTERM: - log.Warn("Termination signal received") - } - Shutdown() - }() + + sig := <-sc + switch sig { + case os.Interrupt: + log.Warn("Interrupt signal received") + case syscall.SIGTERM: + log.Warn("Termination signal received") + } + Shutdown() } func initAutoCertManager(cfg *config.Config) error { @@ -237,3 +242,8 @@ func initAutoCertManager(cfg *config.Config) error { return nil } + +func printStartupNote() { + port := firstNonEmpty(AppConfig().StringDefault("server.port", appDefaultHTTPPort), AppConfig().StringDefault("server.proxyport", "")) + log.Infof("aah go server running on %s:%s", AppHTTPAddress(), parsePort(port)) +} diff --git a/server_test.go b/server_test.go index 95a2f774..26c820a6 100644 --- a/server_test.go +++ b/server_test.go @@ -50,7 +50,7 @@ func TestServerStart2(t *testing.T) { assert.NotNil(t, AppRouter()) // Security - err = initSecurity(cfgDir, AppConfig()) + err = initSecurity(AppConfig()) assert.Nil(t, err) // i18n diff --git a/static.go b/static.go index fd17370c..e8efe823 100644 --- a/static.go +++ b/static.go @@ -24,8 +24,9 @@ import ( ) const ( - dirStatic = "static" - sniffLen = 512 + dirStatic = "static" + sniffLen = 512 + noCacheHdrValue = "no-cache, no-store, must-revalidate" ) var ( @@ -34,8 +35,13 @@ var ( errSeeker = errors.New("static: seeker can't seek") ) +type byName []os.FileInfo + // serveStatic method static file/directory delivery. func (e *engine) serveStatic(ctx *Context) error { + // Taking control over for static file delivery + ctx.Reply().Done() + // TODO static assets Dynamic minify for JS and CSS for non-dev profile // Determine route is file or directory as per user defined @@ -70,9 +76,11 @@ func (e *engine) serveStatic(ctx *Context) error { return nil } - // Gzip - ctx.Reply().gzip = checkGzipRequired(filePath) - e.wrapGzipWriter(ctx) + // Gzip, 1kb above + if fi.Size() > 1024 { + ctx.Reply().gzip = checkGzipRequired(filePath) + e.wrapGzipWriter(ctx) + } e.writeHeaders(ctx) // Serve file @@ -84,16 +92,24 @@ func (e *engine) serveStatic(ctx *Context) error { // apply cache header if environment profile is `prod` if appIsProfileProd { ctx.Res.Header().Set(ahttp.HeaderCacheControl, cacheHeader(contentType)) + } else { // for static files hot-reload + ctx.Res.Header().Set(ahttp.HeaderExpires, "0") + ctx.Res.Header().Set(ahttp.HeaderCacheControl, noCacheHdrValue) } } // 'OnPreReply' server extension point publishOnPreReplyEvent(ctx) - http.ServeContent(ctx.Res, ctx.Req.Raw, path.Base(filePath), fi.ModTime(), f) + http.ServeContent(ctx.Res, ctx.Req.Unwrap(), path.Base(filePath), fi.ModTime(), f) // 'OnAfterReply' server extension point publishOnAfterReplyEvent(ctx) + + // Send data to access log channel + if e.isAccessLogEnabled && e.isStaticAccessLogEnabled { + sendToAccessLog(ctx) + } return nil } @@ -113,6 +129,11 @@ func (e *engine) serveStatic(ctx *Context) error { // 'OnAfterReply' server extension point publishOnAfterReplyEvent(ctx) + + // Send data to access log channel + if e.isAccessLogEnabled && e.isStaticAccessLogEnabled { + sendToAccessLog(ctx) + } return nil } @@ -177,9 +198,11 @@ func checkGzipRequired(file string) bool { // Note: `ctx.route.*` values come from application routes configuration. func getHTTPDirAndFilePath(ctx *Context) (http.Dir, string) { if ctx.route.IsFile() { // this is configured value from routes.conf - return http.Dir(filepath.Join(AppBaseDir(), dirStatic)), ctx.route.File + return http.Dir(filepath.Join(AppBaseDir(), dirStatic)), + parseCacheBustPart(ctx.route.File, AppBuildInfo().Version) } - return http.Dir(filepath.Join(AppBaseDir(), ctx.route.Dir)), ctx.Req.PathValue("filepath") + return http.Dir(filepath.Join(AppBaseDir(), ctx.route.Dir)), + parseCacheBustPart(ctx.Req.PathValue("filepath"), AppBuildInfo().Version) } // detectFileContentType method to identify the static file content-type. @@ -234,6 +257,15 @@ func parseStaticMimeCacheMap(e *Event) { } } +func parseCacheBustPart(name, part string) string { + if strings.Contains(name, part) { + name = strings.Replace(name, "-"+part, "", 1) + name = strings.Replace(name, part+"-", "", 1) + return name + } + return name +} + func init() { OnStart(parseStaticMimeCacheMap) } diff --git a/static_test.go b/static_test.go index eb68c24e..84efb10c 100644 --- a/static_test.go +++ b/static_test.go @@ -19,21 +19,21 @@ import ( "aahframework.org/test.v0/assert" ) -func TestStaticDirectoryListing(t *testing.T) { +func TestStaticFileAndDirectoryListing(t *testing.T) { appCfg, _ := config.ParseString("") e := newEngine(appCfg) - appIsProfileProd = true testStaticServe(t, e, "http://localhost:8080/static/css/aah\x00.css", "static", "css/aah\x00.css", "", "500 Internal Server Error", false) - testStaticServe(t, e, "http://localhost:8080/static/test.txt", "static", "test.txt", "", "This is file content of test.txt", false) + testStaticServe(t, e, "http://localhost:8080/static/", "static", "", "", `Listing of /static/`, true) testStaticServe(t, e, "http://localhost:8080/static", "static", "", "", "403 Directory listing not allowed", false) testStaticServe(t, e, "http://localhost:8080/static", "static", "", "", `Moved Permanently`, true) - testStaticServe(t, e, "http://localhost:8080/static/", "static", "", "", `Listing of /static/`, true) + testStaticServe(t, e, "http://localhost:8080/static/test.txt", "static", "test.txt", "", "This is file content of test.txt", false) + appIsProfileProd = true testStaticServe(t, e, "http://localhost:8080/robots.txt", "", "", "test.txt", "This is file content of test.txt", false) appIsProfileProd = false } @@ -54,6 +54,10 @@ func TestStaticMisc(t *testing.T) { directoryList(w1, r1, f) assert.Equal(t, "Error reading directory", w1.Body.String()) + + // cache bust filename parse + filename := parseCacheBustPart("aah-813e524.css", "813e524") + assert.Equal(t, "aah.css", filename) } func TestParseStaticCacheMap(t *testing.T) { diff --git a/testdata/config/aah.conf b/testdata/config/aah.conf index 3dc962f2..0b65ac9f 100644 --- a/testdata/config/aah.conf +++ b/testdata/config/aah.conf @@ -61,6 +61,30 @@ server { host_policy = ["sample.com"] } } + + # To manage aah server effectively it is necessary to know details about the + # request, response, processing time, client IP address, etc. aah framework + # provides the flexible and configurable access log capabilities. + access_log { + # Enabling server access log + # Default value is `false`. + enable = true + + # Absolute path to access log file or relative path. + # Default location is application logs directory + #file = "{{ .AppName }}-access.log" + + # Default server access log pattern + #pattern = "%clientip %custom:- %reqtime %reqmethod %requrl %reqproto %resstatus %ressize %restime %reqhdr:referer" + + # Access Log channel buffer size + # Default value is `500`. + #channel_buffer_size = 500 + + # Include static files access log too. + # Default value is `true`. + #static_file = false + } } # --------------------- @@ -82,3 +106,9 @@ request { # Default value is 32mb, choose your value based on your use case multipart_size = "32mb" } + +# -------------------------------------------------------------- +# Application Security +# Doc: https://docs.aahframework.org/security-config.html +# -------------------------------------------------------------- +include "./security.conf" diff --git a/util.go b/util.go index 457f4df3..b8f8d1c2 100644 --- a/util.go +++ b/util.go @@ -8,12 +8,14 @@ import ( "errors" "fmt" "io/ioutil" + "net/http" "os" "path" "path/filepath" "strconv" "strings" + "aahframework.org/ahttp.v0" "aahframework.org/essentials.v0" "aahframework.org/log.v0" ) @@ -76,13 +78,14 @@ func getBinaryFileName() string { return ess.StripExt(AppBuildInfo().BinaryName) } -func isNoGzipStatusCode(code int) bool { - for _, c := range noGzipStatusCodes { - if c == code { - return true - } +// This method is similar to +// https://golang.org/src/net/http/transfer.go#bodyAllowedForStatus +func isResponseBodyAllowed(code int) bool { + if (code >= http.StatusContinue && code < http.StatusOK) || + code == http.StatusNoContent || code == http.StatusNotModified { + return false } - return false + return true } func resolveControllerName(ctx *Context) string { @@ -95,3 +98,37 @@ func resolveControllerName(ctx *Context) string { func isCharsetExists(value string) bool { return strings.Contains(value, "charset") } + +// this method is candidate for essentials library +// move it when you get a time +func firstNonEmpty(values ...string) string { + for _, v := range values { + if !ess.IsStrEmpty(v) { + return v + } + } + return "" +} + +func identifyContentType(ctx *Context) *ahttp.ContentType { + // based on 'Accept' Header + if !ess.IsStrEmpty(ctx.Req.AcceptContentType.Mime) && + ctx.Req.AcceptContentType.Mime != "*/*" { + return ctx.Req.AcceptContentType + } + + // as per 'render.default' in aah.conf or nil + return defaultContentType() +} + +func parsePort(port string) string { + if !ess.IsStrEmpty(port) { + return port + } + + if AppIsSSLEnabled() { + return "443" + } + + return "80" +} diff --git a/view.go b/view.go index a4e37c3a..eca8453e 100644 --- a/view.go +++ b/view.go @@ -30,7 +30,7 @@ var ( ) //‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾‾ -// Global methods +// Package methods //___________________________________ // AppViewEngine method returns aah application view Engine instance. @@ -49,6 +49,7 @@ func AddViewEngine(name string, engine view.Enginer) error { } // SetMinifier method sets the given minifier func into aah framework. +// Note: currently minifier is called only for HTML contentType. func SetMinifier(fn MinifierFunc) { if minifier == nil { minifier = fn @@ -101,7 +102,7 @@ func (e *engine) resolveView(ctx *Context) { reply := ctx.Reply() // HTML response - if ahttp.ContentTypeHTML.IsEqual(reply.ContType) && appViewEngine != nil { + if ctHTML.IsEqual(reply.ContType) && appViewEngine != nil { if reply.Rdr == nil { reply.Rdr = &HTML{} } @@ -136,6 +137,7 @@ func (e *engine) resolveView(ctx *Context) { htmlRdr.ViewArgs["AahVersion"] = Version htmlRdr.ViewArgs["EnvProfile"] = AppProfile() htmlRdr.ViewArgs["AppBuildInfo"] = AppBuildInfo() + htmlRdr.ViewArgs[KeyViewArgSubject] = ctx.Subject() // find view template by convention if not provided findViewTemplate(ctx) @@ -234,7 +236,12 @@ func init() { "fparam": tmplFormParam, "qparam": tmplQueryParam, "session": tmplSessionValue, - "isauthenticated": tmplIsAuthenticated, "flash": tmplFlashValue, + "isauthenticated": tmplIsAuthenticated, + "hasrole": tmplHasRole, + "hasallroles": tmplHasAllRoles, + "hasanyrole": tmplHasAnyRole, + "ispermitted": tmplIsPermitted, + "ispermittedall": tmplIsPermittedAll, }) }