Skip to content

Commit

Permalink
feat: add gorm data types
Browse files Browse the repository at this point in the history
  • Loading branch information
yashmehrotra committed Jan 9, 2023
1 parent 10c67f7 commit 48daf55
Show file tree
Hide file tree
Showing 3 changed files with 349 additions and 0 deletions.
13 changes: 13 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module github.com/flanksource/duty

go 1.19

require (
github.com/google/uuid v1.3.0
gorm.io/gorm v1.24.3
)

require (
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.4 // indirect
)
8 changes: 8 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
gorm.io/gorm v1.24.3 h1:WL2ifUmzR/SLp85CSURAfybcHnGZ+yLSGSxgYXlFBHg=
gorm.io/gorm v1.24.3/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA=
328 changes: 328 additions & 0 deletions types/types.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,328 @@
package types

import (
"context"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"strings"

"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)

// JSON defined JSON data type, need to implements driver.Valuer, sql.Scanner interface
type JSON json.RawMessage

// Value return json value, implement driver.Valuer interface
func (j JSON) Value() (driver.Value, error) {
if len(j) == 0 {
return nil, nil
}
bytes, err := json.RawMessage(j).MarshalJSON()
return string(bytes), err
}

// Scan scan value into Jsonb, implements sql.Scanner interface
func (j *JSON) Scan(value interface{}) error {
if value == nil {
*j = JSON("null")
return nil
}
var bytes []byte
switch v := value.(type) {
case []byte:
bytes = v
case string:
bytes = []byte(v)
default:
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value))
}

result := json.RawMessage{}
err := json.Unmarshal(bytes, &result)
*j = JSON(result)
return err
}

// MarshalJSON to output non base64 encoded []byte
func (j JSON) MarshalJSON() ([]byte, error) {
return json.RawMessage(j).MarshalJSON()
}

// UnmarshalJSON to deserialize []byte
func (j *JSON) UnmarshalJSON(b []byte) error {
result := json.RawMessage{}
err := result.UnmarshalJSON(b)
*j = JSON(result)
return err
}

func (j JSON) String() string {
return string(j)
}

// GormDataType gorm common data type
func (JSON) GormDataType() string {
return "json"
}

// GormDBDataType gorm db data type
func (JSON) GormDBDataType(db *gorm.DB, field *schema.Field) string {
switch db.Dialector.Name() {
case "sqlite":
return "JSON"
case "mysql":
return "JSON"
case "postgres":
return "JSONB"
}
return ""
}

func (js JSON) GormValue(ctx context.Context, db *gorm.DB) clause.Expr {
if len(js) == 0 {
return gorm.Expr("NULL")
}

data, _ := js.MarshalJSON()
return gorm.Expr("?", string(data))
}

// JSONQueryExpression json query expression, implements clause.Expression interface to use as querier
type JSONQueryExpression struct {
column string
keys []string
hasKeys bool
equals bool
equalsValue interface{}
}

// JSONQuery query column as json
func JSONQuery(column string) *JSONQueryExpression {
return &JSONQueryExpression{column: column}
}

// HasKey returns clause.Expression
func (jsonQuery *JSONQueryExpression) HasKey(keys ...string) *JSONQueryExpression {
jsonQuery.keys = keys
jsonQuery.hasKeys = true
return jsonQuery
}

// Keys returns clause.Expression
func (jsonQuery *JSONQueryExpression) Equals(value interface{}, keys ...string) *JSONQueryExpression {
jsonQuery.keys = keys
jsonQuery.equals = true
jsonQuery.equalsValue = value
return jsonQuery
}

// Build implements clause.Expression
func (jsonQuery *JSONQueryExpression) Build(builder clause.Builder) {
if stmt, ok := builder.(*gorm.Statement); ok {
switch stmt.Dialector.Name() {
case "mysql", "sqlite":
switch {
case jsonQuery.hasKeys:
if len(jsonQuery.keys) > 0 {
_, _ = builder.WriteString("JSON_EXTRACT(" + stmt.Quote(jsonQuery.column) + ",")
builder.AddVar(stmt, "$."+strings.Join(jsonQuery.keys, "."))
_, _ = builder.WriteString(") IS NOT NULL")
}
case jsonQuery.equals:
if len(jsonQuery.keys) > 0 {
_, _ = builder.WriteString("JSON_EXTRACT(" + stmt.Quote(jsonQuery.column) + ",")
builder.AddVar(stmt, "$."+strings.Join(jsonQuery.keys, "."))
_, _ = builder.WriteString(") = ")
if _, ok := jsonQuery.equalsValue.(bool); ok {
_, _ = builder.WriteString(fmt.Sprint(jsonQuery.equalsValue))
} else {
stmt.AddVar(builder, jsonQuery.equalsValue)
}
}
}
case "postgres":
switch {
case jsonQuery.hasKeys:
if len(jsonQuery.keys) > 0 {
stmt.WriteQuoted(jsonQuery.column)
_, _ = stmt.WriteString("::jsonb")
for _, key := range jsonQuery.keys[0 : len(jsonQuery.keys)-1] {
_, _ = stmt.WriteString(" -> ")
stmt.AddVar(builder, key)
}

_, _ = stmt.WriteString(" ? ")
stmt.AddVar(builder, jsonQuery.keys[len(jsonQuery.keys)-1])
}
case jsonQuery.equals:
if len(jsonQuery.keys) > 0 {
_, _ = builder.WriteString(fmt.Sprintf("json_extract_path_text(%v::json,", stmt.Quote(jsonQuery.column)))

for idx, key := range jsonQuery.keys {
if idx > 0 {
_ = builder.WriteByte(',')
}
stmt.AddVar(builder, key)
}
_, _ = builder.WriteString(") = ")

if _, ok := jsonQuery.equalsValue.(string); ok {
stmt.AddVar(builder, jsonQuery.equalsValue)
} else {
stmt.AddVar(builder, fmt.Sprint(jsonQuery.equalsValue))
}
}
}
}
}
}

// JSONStringMap defiend JSON data type, need to implements driver.Valuer, sql.Scanner interface
type JSONStringMap map[string]string

// Value return json value, implement driver.Valuer interface
func (m JSONStringMap) Value() (driver.Value, error) {
if m == nil {
return nil, nil
}
ba, err := m.MarshalJSON()
return string(ba), err
}

// Scan scan value into Jsonb, implements sql.Scanner interface
func (m *JSONStringMap) Scan(val interface{}) error {
if val == nil {
*m = make(JSONStringMap)
return nil
}
var ba []byte
switch v := val.(type) {
case []byte:
ba = v
case string:
ba = []byte(v)
default:
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", val))
}
t := map[string]string{}
err := json.Unmarshal(ba, &t)
*m = t
return err
}

// MarshalJSON to output non base64 encoded []byte
func (m JSONStringMap) MarshalJSON() ([]byte, error) {
if m == nil {
return []byte("{}"), nil
}
t := (map[string]string)(m)
return json.Marshal(t)
}

// UnmarshalJSON to deserialize []byte
func (m *JSONStringMap) UnmarshalJSON(b []byte) error {
t := map[string]string{}
err := json.Unmarshal(b, &t)
*m = JSONStringMap(t)
return err
}

// GormDataType gorm common data type
func (m JSONStringMap) GormDataType() string {
return "jsonstringmap"
}

// GormDBDataType gorm db data type
func (JSONStringMap) GormDBDataType(db *gorm.DB, field *schema.Field) string {
switch db.Dialector.Name() {
case "sqlite":
return "JSON"
case "postgres":
return "JSONB"
case "sqlserver":
return "NVARCHAR(MAX)"
}
return ""
}

func (jm JSONStringMap) GormValue(ctx context.Context, db *gorm.DB) clause.Expr {
data, _ := jm.MarshalJSON()
return gorm.Expr("?", string(data))
}

// JSONMap defiend JSON data type, need to implements driver.Valuer, sql.Scanner interface
type JSONMap map[string]interface{}

// Value return json value, implement driver.Valuer interface
func (m JSONMap) Value() (driver.Value, error) {
if m == nil {
return nil, nil
}
ba, err := m.MarshalJSON()
return string(ba), err
}

// Scan scan value into Jsonb, implements sql.Scanner interface
func (m *JSONMap) Scan(val interface{}) error {
if val == nil {
*m = make(JSONMap)
return nil
}
var ba []byte
switch v := val.(type) {
case []byte:
ba = v
case string:
ba = []byte(v)
default:
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", val))
}
t := map[string]interface{}{}
err := json.Unmarshal(ba, &t)
*m = t
return err
}

// MarshalJSON to output non base64 encoded []byte
func (m JSONMap) MarshalJSON() ([]byte, error) {
if m == nil {
return []byte("{}"), nil
}
t := (map[string]interface{})(m)
return json.Marshal(t)
}

// UnmarshalJSON to deserialize []byte
func (m *JSONMap) UnmarshalJSON(b []byte) error {
t := map[string]interface{}{}
err := json.Unmarshal(b, &t)
*m = JSONMap(t)
return err
}

// GormDataType gorm common data type
func (m JSONMap) GormDataType() string {
return "jsonmap"
}

// GormDBDataType gorm db data type
func (JSONMap) GormDBDataType(db *gorm.DB, field *schema.Field) string {
switch db.Dialector.Name() {
case "sqlite":
return "JSON"
case "postgres":
return "JSONB"
case "sqlserver":
return "NVARCHAR(MAX)"
}
return ""
}

func (jm JSONMap) GormValue(ctx context.Context, db *gorm.DB) clause.Expr {
data, _ := jm.MarshalJSON()
return gorm.Expr("?", string(data))
}

0 comments on commit 48daf55

Please sign in to comment.