Skip to content

Commit

Permalink
Upgrade sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jmichalak committed Jul 30, 2024
1 parent f6064af commit 3ea1e0d
Show file tree
Hide file tree
Showing 22 changed files with 1,009 additions and 354 deletions.
47 changes: 47 additions & 0 deletions pkg/acceptance/helpers/aggregation_policy_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package helpers

import (
"context"
"fmt"
"testing"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
"github.com/stretchr/testify/require"
)

// TODO(SNOW-TODO): change raw sqls to proper client
type AggregationPolicyClient struct {
context *TestClientContext
ids *IdsGenerator
}

func NewAggregationPolicyClient(context *TestClientContext, idsGenerator *IdsGenerator) *AggregationPolicyClient {
return &AggregationPolicyClient{
context: context,
ids: idsGenerator,
}
}

func (c *AggregationPolicyClient) client() *sdk.Client {
return c.context.client
}

func (c *AggregationPolicyClient) CreateAggregationPolicy(t *testing.T) (sdk.SchemaObjectIdentifier, func()) {
t.Helper()
ctx := context.Background()

id := c.ids.RandomSchemaObjectIdentifier()
_, err := c.client().ExecForTests(ctx, fmt.Sprintf(`CREATE AGGREGATION POLICY %s AS () RETURNS AGGREGATION_CONSTRAINT -> AGGREGATION_CONSTRAINT(MIN_GROUP_SIZE => 5)`, id.Name()))
require.NoError(t, err)
return id, c.DropAggregationPolicyFunc(t, id)
}

func (c *AggregationPolicyClient) DropAggregationPolicyFunc(t *testing.T, id sdk.SchemaObjectIdentifier) func() {
t.Helper()
ctx := context.Background()

return func() {
_, err := c.client().ExecForTests(ctx, fmt.Sprintf(`DROP AGGREGATION POLICY IF EXISTS %s`, id.Name()))
require.NoError(t, err)
}
}
18 changes: 18 additions & 0 deletions pkg/acceptance/helpers/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package helpers

import (
"context"
"database/sql"
"fmt"
"log"
"testing"
Expand Down Expand Up @@ -70,3 +71,20 @@ func AssertErrorContainsPartsFunc(t *testing.T, parts []string) resource.ErrorCh
return nil
}
}

type PolicyReference struct {
PolicyDb string `db:"POLICY_DB"`
PolicySchema string `db:"POLICY_SCHEMA"`
PolicyName string `db:"POLICY_NAME"`
PolicyKind string `db:"POLICY_KIND"`
RefDatabaseName string `db:"REF_DATABASE_NAME"`
RefSchemaName string `db:"REF_SCHEMA_NAME"`
RefEntityName string `db:"REF_ENTITY_NAME"`
RefEntityDomain string `db:"REF_ENTITY_DOMAIN"`
RefColumnName sql.NullString `db:"REF_COLUMN_NAME"`
RefArgColumnNames sql.NullString `db:"REF_ARG_COLUMN_NAMES"`
TagDatabase sql.NullString `db:"TAG_DATABASE"`
TagSchema sql.NullString `db:"TAG_SCHEMA"`
TagName sql.NullString `db:"TAG_NAME"`
PolicyStatus string `db:"POLICY_STATUS"`
}
47 changes: 47 additions & 0 deletions pkg/acceptance/helpers/projection_policy_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package helpers

import (
"context"
"fmt"
"testing"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
"github.com/stretchr/testify/require"
)

// TODO(SNOW-TODO): change raw sqls to proper client
type ProjectionPolicyClient struct {
context *TestClientContext
ids *IdsGenerator
}

func NewProjectionPolicyClient(context *TestClientContext, idsGenerator *IdsGenerator) *ProjectionPolicyClient {
return &ProjectionPolicyClient{
context: context,
ids: idsGenerator,
}
}

func (c *ProjectionPolicyClient) client() *sdk.Client {
return c.context.client
}

func (c *ProjectionPolicyClient) CreateProjectionPolicy(t *testing.T) (sdk.SchemaObjectIdentifier, func()) {
t.Helper()
ctx := context.Background()

id := c.ids.RandomSchemaObjectIdentifier()
_, err := c.client().ExecForTests(ctx, fmt.Sprintf(`CREATE PROJECTION POLICY %s AS () RETURNS PROJECTION_CONSTRAINT -> PROJECTION_CONSTRAINT(ALLOW => false)`, id.Name()))
require.NoError(t, err)
return id, c.DropProjectionPolicyFunc(t, id)
}

func (c *ProjectionPolicyClient) DropProjectionPolicyFunc(t *testing.T, id sdk.SchemaObjectIdentifier) func() {
t.Helper()
ctx := context.Background()

return func() {
_, err := c.client().ExecForTests(ctx, fmt.Sprintf(`DROP PROJECTION POLICY IF EXISTS %s`, id.Name()))
require.NoError(t, err)
}
}
29 changes: 12 additions & 17 deletions pkg/acceptance/helpers/row_access_policy_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package helpers

import (
"context"
"database/sql"
"fmt"
"testing"

Expand Down Expand Up @@ -56,7 +55,7 @@ func (c *RowAccessPolicyClient) DropRowAccessPolicyFunc(t *testing.T, id sdk.Sch

// GetRowAccessPolicyFor is based on https://docs.snowflake.com/en/user-guide/security-row-intro#obtain-database-objects-with-a-row-access-policy.
// TODO: extract getting row access policies as resource (like getting tag in system functions)
func (c *RowAccessPolicyClient) GetRowAccessPolicyFor(t *testing.T, id sdk.SchemaObjectIdentifier, objectType sdk.ObjectType) (*PolicyReference, error) {
func (c *RowAccessPolicyClient) GetOneRowAccessPolicyFor(t *testing.T, id sdk.SchemaObjectIdentifier, objectType sdk.ObjectType) (*PolicyReference, error) {
t.Helper()
ctx := context.Background()

Expand All @@ -67,19 +66,15 @@ func (c *RowAccessPolicyClient) GetRowAccessPolicyFor(t *testing.T, id sdk.Schem
return s, err
}

type PolicyReference struct {
PolicyDb string `db:"POLICY_DB"`
PolicySchema string `db:"POLICY_SCHEMA"`
PolicyName string `db:"POLICY_NAME"`
PolicyKind string `db:"POLICY_KIND"`
RefDatabaseName string `db:"REF_DATABASE_NAME"`
RefSchemaName string `db:"REF_SCHEMA_NAME"`
RefEntityName string `db:"REF_ENTITY_NAME"`
RefEntityDomain string `db:"REF_ENTITY_DOMAIN"`
RefColumnName sql.NullString `db:"REF_COLUMN_NAME"`
RefArgColumnNames string `db:"REF_ARG_COLUMN_NAMES"`
TagDatabase sql.NullString `db:"TAG_DATABASE"`
TagSchema sql.NullString `db:"TAG_SCHEMA"`
TagName sql.NullString `db:"TAG_NAME"`
PolicyStatus string `db:"POLICY_STATUS"`
// GetRowAccessPolicyFor is based on https://docs.snowflake.com/en/user-guide/security-row-intro#obtain-database-objects-with-a-row-access-policy.
// TODO: this is a generic function for all kinds of policies. Move to commons and add filtering for other policies
func (c *RowAccessPolicyClient) GetAllPoliciesFor(t *testing.T, id sdk.SchemaObjectIdentifier, objectType sdk.ObjectType) ([]PolicyReference, error) {
t.Helper()
ctx := context.Background()

s := []PolicyReference{}
policyReferencesId := sdk.NewSchemaObjectIdentifier(id.DatabaseName(), "INFORMATION_SCHEMA", "POLICY_REFERENCES")
err := c.context.client.QueryForTests(ctx, &s, fmt.Sprintf(`SELECT * FROM TABLE(%s(REF_ENTITY_NAME => '%s', REF_ENTITY_DOMAIN => '%v'))`, policyReferencesId.FullyQualifiedName(), id.FullyQualifiedName(), objectType))

return s, err
}
4 changes: 4 additions & 0 deletions pkg/acceptance/helpers/test_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ type TestClient struct {
Ids *IdsGenerator

Account *AccountClient
AggregationPolicy *AggregationPolicyClient
Alert *AlertClient
ApiIntegration *ApiIntegrationClient
Application *ApplicationClient
Expand All @@ -31,6 +32,7 @@ type TestClient struct {
Parameter *ParameterClient
PasswordPolicy *PasswordPolicyClient
Pipe *PipeClient
ProjectionPolicy *ProjectionPolicyClient
ResourceMonitor *ResourceMonitorClient
Role *RoleClient
RowAccessPolicy *RowAccessPolicyClient
Expand Down Expand Up @@ -63,6 +65,7 @@ func NewTestClient(c *sdk.Client, database string, schema string, warehouse stri
Ids: idsGenerator,

Account: NewAccountClient(context),
AggregationPolicy: NewAggregationPolicyClient(context, idsGenerator),
Alert: NewAlertClient(context, idsGenerator),
ApiIntegration: NewApiIntegrationClient(context, idsGenerator),
Application: NewApplicationClient(context, idsGenerator),
Expand All @@ -84,6 +87,7 @@ func NewTestClient(c *sdk.Client, database string, schema string, warehouse stri
Parameter: NewParameterClient(context),
PasswordPolicy: NewPasswordPolicyClient(context, idsGenerator),
Pipe: NewPipeClient(context, idsGenerator),
ProjectionPolicy: NewProjectionPolicyClient(context, idsGenerator),
ResourceMonitor: NewResourceMonitorClient(context, idsGenerator),
Role: NewRoleClient(context, idsGenerator),
RowAccessPolicy: NewRowAccessPolicyClient(context, idsGenerator),
Expand Down
4 changes: 2 additions & 2 deletions pkg/acceptance/helpers/view_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func (c *ViewClient) RecreateView(t *testing.T, id sdk.SchemaObjectIdentifier, q
t.Helper()
ctx := context.Background()

err := c.client().Create(ctx, sdk.NewCreateViewRequest(id, query).WithOrReplace(sdk.Bool(true)))
err := c.client().Create(ctx, sdk.NewCreateViewRequest(id, query).WithOrReplace(true))
require.NoError(t, err)

view, err := c.client().ShowByID(ctx, id)
Expand All @@ -57,7 +57,7 @@ func (c *ViewClient) DropViewFunc(t *testing.T, id sdk.SchemaObjectIdentifier) f
ctx := context.Background()

return func() {
err := c.client().Drop(ctx, sdk.NewDropViewRequest(id).WithIfExists(sdk.Bool(true)))
err := c.client().Drop(ctx, sdk.NewDropViewRequest(id).WithIfExists(true))
require.NoError(t, err)
}
}
2 changes: 1 addition & 1 deletion pkg/datasources/views.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func ReadViews(d *schema.ResourceData, meta interface{}) error {

schemaId := sdk.NewDatabaseObjectIdentifier(databaseName, schemaName)
extractedViews, err := client.Views.Show(ctx, sdk.NewShowViewRequest().WithIn(
&sdk.In{Schema: schemaId},
sdk.ExtendedIn{In: sdk.In{Schema: schemaId}},
))
if err != nil {
log.Printf("[DEBUG] failed when searching views in schema (%s), err = %s", schemaId.FullyQualifiedName(), err.Error())
Expand Down
26 changes: 13 additions & 13 deletions pkg/resources/view.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,19 +105,19 @@ func CreateView(d *schema.ResourceData, meta interface{}) error {
createRequest := sdk.NewCreateViewRequest(id, s)

if v, ok := d.GetOk("or_replace"); ok && v.(bool) {
createRequest.WithOrReplace(sdk.Bool(true))
createRequest.WithOrReplace(true)
}

if v, ok := d.GetOk("is_secure"); ok && v.(bool) {
createRequest.WithSecure(sdk.Bool(true))
createRequest.WithSecure(true)
}

if v, ok := d.GetOk("copy_grants"); ok && v.(bool) {
createRequest.WithCopyGrants(sdk.Bool(true))
createRequest.WithCopyGrants(true)
}

if v, ok := d.GetOk("comment"); ok {
createRequest.WithComment(sdk.String(v.(string)))
createRequest.WithComment(v.(string))
}

err := client.Views.Create(ctx, createRequest)
Expand Down Expand Up @@ -206,13 +206,13 @@ func UpdateView(d *schema.ResourceData, meta interface{}) error {
oldTags, _ := d.GetChange("tag")

createRequest := sdk.NewCreateViewRequest(id, d.Get("statement").(string)).
WithOrReplace(sdk.Bool(true)).
WithCopyGrants(sdk.Bool(true)).
WithComment(sdk.String(oldComment.(string))).
WithOrReplace(true).
WithCopyGrants(true).
WithComment(oldComment.(string)).
WithTag(getTagsFromList(oldTags.([]any)))

if oldIsSecure.(bool) {
createRequest.WithSecure(sdk.Bool(true))
createRequest.WithSecure(true)
}

err := client.Views.Create(ctx, createRequest)
Expand All @@ -224,7 +224,7 @@ func UpdateView(d *schema.ResourceData, meta interface{}) error {
if d.HasChange("name") {
newId := sdk.NewSchemaObjectIdentifierInSchema(id.SchemaId(), d.Get("name").(string))

err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithRenameTo(&newId))
err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithRenameTo(newId))
if err != nil {
return fmt.Errorf("error renaming view %v err = %w", d.Id(), err)
}
Expand All @@ -235,12 +235,12 @@ func UpdateView(d *schema.ResourceData, meta interface{}) error {

if d.HasChange("comment") {
if comment := d.Get("comment").(string); comment == "" {
err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithUnsetComment(sdk.Bool(true)))
err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithUnsetComment(true))
if err != nil {
return fmt.Errorf("error unsetting comment for view %v", d.Id())
}
} else {
err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithSetComment(sdk.String(comment)))
err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithSetComment(comment))
if err != nil {
return fmt.Errorf("error updating comment for view %v", d.Id())
}
Expand All @@ -249,12 +249,12 @@ func UpdateView(d *schema.ResourceData, meta interface{}) error {

if d.HasChange("is_secure") {
if d.Get("is_secure").(bool) {
err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithSetSecure(sdk.Bool(true)))
err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithSetSecure(true))
if err != nil {
return fmt.Errorf("error setting secure for view %v", d.Id())
}
} else {
err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithUnsetSecure(sdk.Bool(true)))
err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithUnsetSecure(true))
if err != nil {
return fmt.Errorf("error unsetting secure for view %v", d.Id())
}
Expand Down
6 changes: 6 additions & 0 deletions pkg/sdk/common_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ type In struct {
Schema DatabaseObjectIdentifier `ddl:"identifier" sql:"SCHEMA"`
}

type ExtendedIn struct {
In
Application AccountObjectIdentifier `ddl:"identifier" sql:"APPLICATION"`
ApplicationPackage AccountObjectIdentifier `ddl:"identifier" sql:"APPLICATION PACKAGE"`
}

type Like struct {
Pattern *string `ddl:"keyword,single_quotes"`
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/sdk/poc/generator/keyword_builders.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ func (v *QueryStruct) OptionalIn() *QueryStruct {
return v.PredefinedQueryStructField("In", "*In", KeywordOptions().SQL("IN"))
}

func (v *QueryStruct) OptionalExtendedIn() *QueryStruct {
return v.PredefinedQueryStructField("In", "*ExtendedIn", KeywordOptions().SQL("IN"))
}

func (v *QueryStruct) OptionalStartsWith() *QueryStruct {
return v.PredefinedQueryStructField("StartsWith", "*string", ParameterOptions().NoEquals().SingleQuotes().SQL("STARTS WITH"))
}
Expand Down
1 change: 1 addition & 0 deletions pkg/sdk/poc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"os"
"strings"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/genhelpers"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/poc/example"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/poc/generator"
Expand Down
10 changes: 5 additions & 5 deletions pkg/sdk/testint/event_tables_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ func TestInt_EventTables(t *testing.T) {
err := client.EventTables.Alter(ctx, alterRequest)
require.NoError(t, err)

e, err := testClientHelper().RowAccessPolicy.GetRowAccessPolicyFor(t, table.ID(), sdk.ObjectTypeTable)
e, err := testClientHelper().RowAccessPolicy.GetOneRowAccessPolicyFor(t, table.ID(), sdk.ObjectTypeTable)
require.NoError(t, err)
assert.Equal(t, rowAccessPolicy.ID().Name(), e.PolicyName)
assert.Equal(t, "ROW_ACCESS_POLICY", e.PolicyKind)
Expand All @@ -237,15 +237,15 @@ func TestInt_EventTables(t *testing.T) {
err = client.EventTables.Alter(ctx, alterRequest)
require.NoError(t, err)

_, err = testClientHelper().RowAccessPolicy.GetRowAccessPolicyFor(t, table.ID(), sdk.ObjectTypeTable)
_, err = testClientHelper().RowAccessPolicy.GetOneRowAccessPolicyFor(t, table.ID(), sdk.ObjectTypeTable)
require.Error(t, err, "no rows in result set")

// add policy again
alterRequest = sdk.NewAlterEventTableRequest(table.ID()).WithAddRowAccessPolicy(sdk.NewEventTableAddRowAccessPolicyRequest(rowAccessPolicy.ID(), []string{"id"}))
err = client.EventTables.Alter(ctx, alterRequest)
require.NoError(t, err)

e, err = testClientHelper().RowAccessPolicy.GetRowAccessPolicyFor(t, table.ID(), sdk.ObjectTypeTable)
e, err = testClientHelper().RowAccessPolicy.GetOneRowAccessPolicyFor(t, table.ID(), sdk.ObjectTypeTable)
require.NoError(t, err)
assert.Equal(t, rowAccessPolicy.ID().Name(), e.PolicyName)

Expand All @@ -257,7 +257,7 @@ func TestInt_EventTables(t *testing.T) {
err = client.EventTables.Alter(ctx, alterRequest)
require.NoError(t, err)

e, err = testClientHelper().RowAccessPolicy.GetRowAccessPolicyFor(t, table.ID(), sdk.ObjectTypeTable)
e, err = testClientHelper().RowAccessPolicy.GetOneRowAccessPolicyFor(t, table.ID(), sdk.ObjectTypeTable)
require.NoError(t, err)
assert.Equal(t, rowAccessPolicy2.ID().Name(), e.PolicyName)

Expand All @@ -266,7 +266,7 @@ func TestInt_EventTables(t *testing.T) {
err = client.EventTables.Alter(ctx, alterRequest)
require.NoError(t, err)

_, err = testClientHelper().RowAccessPolicy.GetRowAccessPolicyFor(t, table.ID(), sdk.ObjectTypeView)
_, err = testClientHelper().RowAccessPolicy.GetOneRowAccessPolicyFor(t, table.ID(), sdk.ObjectTypeView)
require.Error(t, err, "no rows in result set")
})
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/sdk/testint/materialized_views_gen_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func TestInt_MaterializedViews(t *testing.T) {
view := createMaterializedViewWithRequest(t, request)

assertMaterializedViewWithOptions(t, view, id, true, "comment", fmt.Sprintf(`LINEAR("%s")`, "COLUMN_WITH_COMMENT"))
rowAccessPolicyReference, err := testClientHelper().RowAccessPolicy.GetRowAccessPolicyFor(t, view.ID(), sdk.ObjectTypeView)
rowAccessPolicyReference, err := testClientHelper().RowAccessPolicy.GetOneRowAccessPolicyFor(t, view.ID(), sdk.ObjectTypeView)
require.NoError(t, err)
assert.Equal(t, rowAccessPolicy.Name, rowAccessPolicyReference.PolicyName)
assert.Equal(t, "ROW_ACCESS_POLICY", rowAccessPolicyReference.PolicyKind)
Expand Down
Loading

0 comments on commit 3ea1e0d

Please sign in to comment.