Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Upgrade view sdk #2969

Merged
merged 10 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions pkg/acceptance/helpers/aggregation_policy_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package helpers

import (
"context"
"fmt"
"testing"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
"github.com/stretchr/testify/require"
)

// TODO(SNOW-1564954): change raw sqls to proper client
type AggregationPolicyClient struct {
context *TestClientContext
ids *IdsGenerator
}

func NewAggregationPolicyClient(context *TestClientContext, idsGenerator *IdsGenerator) *AggregationPolicyClient {
return &AggregationPolicyClient{
context: context,
ids: idsGenerator,
}
}

func (c *AggregationPolicyClient) client() *sdk.Client {
return c.context.client
}

func (c *AggregationPolicyClient) CreateAggregationPolicy(t *testing.T) (sdk.SchemaObjectIdentifier, func()) {
t.Helper()
ctx := context.Background()

id := c.ids.RandomSchemaObjectIdentifier()
_, err := c.client().ExecForTests(ctx, fmt.Sprintf(`CREATE AGGREGATION POLICY %s AS () RETURNS AGGREGATION_CONSTRAINT -> AGGREGATION_CONSTRAINT(MIN_GROUP_SIZE => 5)`, id.Name()))
require.NoError(t, err)
return id, c.DropAggregationPolicyFunc(t, id)
}

func (c *AggregationPolicyClient) DropAggregationPolicyFunc(t *testing.T, id sdk.SchemaObjectIdentifier) func() {
t.Helper()
ctx := context.Background()

return func() {
_, err := c.client().ExecForTests(ctx, fmt.Sprintf(`DROP AGGREGATION POLICY IF EXISTS %s`, id.Name()))
require.NoError(t, err)
}
}
18 changes: 18 additions & 0 deletions pkg/acceptance/helpers/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package helpers

import (
"context"
"database/sql"
"fmt"
"log"
"testing"
Expand Down Expand Up @@ -70,3 +71,20 @@ func AssertErrorContainsPartsFunc(t *testing.T, parts []string) resource.ErrorCh
return nil
}
}

type PolicyReference struct {
sfc-gh-jcieslak marked this conversation as resolved.
Show resolved Hide resolved
PolicyDb string `db:"POLICY_DB"`
PolicySchema string `db:"POLICY_SCHEMA"`
PolicyName string `db:"POLICY_NAME"`
PolicyKind string `db:"POLICY_KIND"`
RefDatabaseName string `db:"REF_DATABASE_NAME"`
RefSchemaName string `db:"REF_SCHEMA_NAME"`
RefEntityName string `db:"REF_ENTITY_NAME"`
RefEntityDomain string `db:"REF_ENTITY_DOMAIN"`
RefColumnName sql.NullString `db:"REF_COLUMN_NAME"`
RefArgColumnNames sql.NullString `db:"REF_ARG_COLUMN_NAMES"`
TagDatabase sql.NullString `db:"TAG_DATABASE"`
TagSchema sql.NullString `db:"TAG_SCHEMA"`
TagName sql.NullString `db:"TAG_NAME"`
PolicyStatus string `db:"POLICY_STATUS"`
}
47 changes: 47 additions & 0 deletions pkg/acceptance/helpers/projection_policy_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package helpers

import (
"context"
"fmt"
"testing"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
"github.com/stretchr/testify/require"
)

// TODO(SNOW-1564959): change raw sqls to proper client
type ProjectionPolicyClient struct {
context *TestClientContext
ids *IdsGenerator
}

func NewProjectionPolicyClient(context *TestClientContext, idsGenerator *IdsGenerator) *ProjectionPolicyClient {
return &ProjectionPolicyClient{
context: context,
ids: idsGenerator,
}
}

func (c *ProjectionPolicyClient) client() *sdk.Client {
return c.context.client
}

func (c *ProjectionPolicyClient) CreateProjectionPolicy(t *testing.T) (sdk.SchemaObjectIdentifier, func()) {
t.Helper()
ctx := context.Background()

id := c.ids.RandomSchemaObjectIdentifier()
_, err := c.client().ExecForTests(ctx, fmt.Sprintf(`CREATE PROJECTION POLICY %s AS () RETURNS PROJECTION_CONSTRAINT -> PROJECTION_CONSTRAINT(ALLOW => false)`, id.Name()))
require.NoError(t, err)
return id, c.DropProjectionPolicyFunc(t, id)
}

func (c *ProjectionPolicyClient) DropProjectionPolicyFunc(t *testing.T, id sdk.SchemaObjectIdentifier) func() {
t.Helper()
ctx := context.Background()

return func() {
_, err := c.client().ExecForTests(ctx, fmt.Sprintf(`DROP PROJECTION POLICY IF EXISTS %s`, id.Name()))
require.NoError(t, err)
}
}
35 changes: 35 additions & 0 deletions pkg/acceptance/helpers/references_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package helpers

import (
"context"
"fmt"
"testing"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
)

type ReferencesClient struct {
context *TestClientContext
}

func NewReferencesClient(context *TestClientContext) *ReferencesClient {
return &ReferencesClient{
context: context,
}
}

func (c *ReferencesClient) client() sdk.RowAccessPolicies {
return c.context.client.RowAccessPolicies
}

// GetAllPolicyReferences is based on https://docs.snowflake.com/en/sql-reference/functions/policy_references.
func (c *ReferencesClient) GetAllPolicyReferences(t *testing.T, id sdk.SchemaObjectIdentifier, objectType sdk.ObjectType) ([]PolicyReference, error) {
t.Helper()
ctx := context.Background()

s := []PolicyReference{}
policyReferencesId := sdk.NewSchemaObjectIdentifier(id.DatabaseName(), "INFORMATION_SCHEMA", "POLICY_REFERENCES")
err := c.context.client.QueryForTests(ctx, &s, fmt.Sprintf(`SELECT * FROM TABLE(%s(REF_ENTITY_NAME => '%s', REF_ENTITY_DOMAIN => '%v'))`, policyReferencesId.FullyQualifiedName(), id.FullyQualifiedName(), objectType))

return s, err
}
20 changes: 1 addition & 19 deletions pkg/acceptance/helpers/row_access_policy_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package helpers

import (
"context"
"database/sql"
"fmt"
"testing"

Expand Down Expand Up @@ -56,7 +55,7 @@ func (c *RowAccessPolicyClient) DropRowAccessPolicyFunc(t *testing.T, id sdk.Sch

// GetRowAccessPolicyFor is based on https://docs.snowflake.com/en/user-guide/security-row-intro#obtain-database-objects-with-a-row-access-policy.
// TODO: extract getting row access policies as resource (like getting tag in system functions)
func (c *RowAccessPolicyClient) GetRowAccessPolicyFor(t *testing.T, id sdk.SchemaObjectIdentifier, objectType sdk.ObjectType) (*PolicyReference, error) {
func (c *RowAccessPolicyClient) GetOneRowAccessPolicyFor(t *testing.T, id sdk.SchemaObjectIdentifier, objectType sdk.ObjectType) (*PolicyReference, error) {
sfc-gh-asawicki marked this conversation as resolved.
Show resolved Hide resolved
t.Helper()
ctx := context.Background()

Expand All @@ -66,20 +65,3 @@ func (c *RowAccessPolicyClient) GetRowAccessPolicyFor(t *testing.T, id sdk.Schem

return s, err
}

type PolicyReference struct {
PolicyDb string `db:"POLICY_DB"`
PolicySchema string `db:"POLICY_SCHEMA"`
PolicyName string `db:"POLICY_NAME"`
PolicyKind string `db:"POLICY_KIND"`
RefDatabaseName string `db:"REF_DATABASE_NAME"`
RefSchemaName string `db:"REF_SCHEMA_NAME"`
RefEntityName string `db:"REF_ENTITY_NAME"`
RefEntityDomain string `db:"REF_ENTITY_DOMAIN"`
RefColumnName sql.NullString `db:"REF_COLUMN_NAME"`
RefArgColumnNames string `db:"REF_ARG_COLUMN_NAMES"`
TagDatabase sql.NullString `db:"TAG_DATABASE"`
TagSchema sql.NullString `db:"TAG_SCHEMA"`
TagName sql.NullString `db:"TAG_NAME"`
PolicyStatus string `db:"POLICY_STATUS"`
}
6 changes: 6 additions & 0 deletions pkg/acceptance/helpers/test_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ type TestClient struct {
Ids *IdsGenerator

Account *AccountClient
AggregationPolicy *AggregationPolicyClient
Alert *AlertClient
ApiIntegration *ApiIntegrationClient
Application *ApplicationClient
Expand All @@ -31,6 +32,8 @@ type TestClient struct {
Parameter *ParameterClient
PasswordPolicy *PasswordPolicyClient
Pipe *PipeClient
ProjectionPolicy *ProjectionPolicyClient
References *ReferencesClient
ResourceMonitor *ResourceMonitorClient
Role *RoleClient
RowAccessPolicy *RowAccessPolicyClient
Expand Down Expand Up @@ -63,6 +66,7 @@ func NewTestClient(c *sdk.Client, database string, schema string, warehouse stri
Ids: idsGenerator,

Account: NewAccountClient(context),
AggregationPolicy: NewAggregationPolicyClient(context, idsGenerator),
Alert: NewAlertClient(context, idsGenerator),
ApiIntegration: NewApiIntegrationClient(context, idsGenerator),
Application: NewApplicationClient(context, idsGenerator),
Expand All @@ -84,6 +88,8 @@ func NewTestClient(c *sdk.Client, database string, schema string, warehouse stri
Parameter: NewParameterClient(context),
PasswordPolicy: NewPasswordPolicyClient(context, idsGenerator),
Pipe: NewPipeClient(context, idsGenerator),
ProjectionPolicy: NewProjectionPolicyClient(context, idsGenerator),
References: NewReferencesClient(context),
sfc-gh-jcieslak marked this conversation as resolved.
Show resolved Hide resolved
ResourceMonitor: NewResourceMonitorClient(context, idsGenerator),
Role: NewRoleClient(context, idsGenerator),
RowAccessPolicy: NewRowAccessPolicyClient(context, idsGenerator),
Expand Down
4 changes: 2 additions & 2 deletions pkg/acceptance/helpers/view_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func (c *ViewClient) RecreateView(t *testing.T, id sdk.SchemaObjectIdentifier, q
t.Helper()
ctx := context.Background()

err := c.client().Create(ctx, sdk.NewCreateViewRequest(id, query).WithOrReplace(sdk.Bool(true)))
err := c.client().Create(ctx, sdk.NewCreateViewRequest(id, query).WithOrReplace(true))
require.NoError(t, err)

view, err := c.client().ShowByID(ctx, id)
Expand All @@ -57,7 +57,7 @@ func (c *ViewClient) DropViewFunc(t *testing.T, id sdk.SchemaObjectIdentifier) f
ctx := context.Background()

return func() {
err := c.client().Drop(ctx, sdk.NewDropViewRequest(id).WithIfExists(sdk.Bool(true)))
err := c.client().Drop(ctx, sdk.NewDropViewRequest(id).WithIfExists(true))
require.NoError(t, err)
}
}
2 changes: 1 addition & 1 deletion pkg/datasources/views.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func ReadViews(d *schema.ResourceData, meta interface{}) error {

schemaId := sdk.NewDatabaseObjectIdentifier(databaseName, schemaName)
extractedViews, err := client.Views.Show(ctx, sdk.NewShowViewRequest().WithIn(
&sdk.In{Schema: schemaId},
sdk.ExtendedIn{In: sdk.In{Schema: schemaId}},
))
if err != nil {
log.Printf("[DEBUG] failed when searching views in schema (%s), err = %s", schemaId.FullyQualifiedName(), err.Error())
Expand Down
26 changes: 13 additions & 13 deletions pkg/resources/view.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,19 +105,19 @@ func CreateView(d *schema.ResourceData, meta interface{}) error {
createRequest := sdk.NewCreateViewRequest(id, s)

if v, ok := d.GetOk("or_replace"); ok && v.(bool) {
createRequest.WithOrReplace(sdk.Bool(true))
createRequest.WithOrReplace(true)
}

if v, ok := d.GetOk("is_secure"); ok && v.(bool) {
createRequest.WithSecure(sdk.Bool(true))
createRequest.WithSecure(true)
}

if v, ok := d.GetOk("copy_grants"); ok && v.(bool) {
createRequest.WithCopyGrants(sdk.Bool(true))
createRequest.WithCopyGrants(true)
}

if v, ok := d.GetOk("comment"); ok {
createRequest.WithComment(sdk.String(v.(string)))
createRequest.WithComment(v.(string))
}

err := client.Views.Create(ctx, createRequest)
Expand Down Expand Up @@ -206,13 +206,13 @@ func UpdateView(d *schema.ResourceData, meta interface{}) error {
oldTags, _ := d.GetChange("tag")

createRequest := sdk.NewCreateViewRequest(id, d.Get("statement").(string)).
WithOrReplace(sdk.Bool(true)).
WithCopyGrants(sdk.Bool(true)).
WithComment(sdk.String(oldComment.(string))).
WithOrReplace(true).
WithCopyGrants(true).
WithComment(oldComment.(string)).
WithTag(getTagsFromList(oldTags.([]any)))

if oldIsSecure.(bool) {
createRequest.WithSecure(sdk.Bool(true))
createRequest.WithSecure(true)
}

err := client.Views.Create(ctx, createRequest)
Expand All @@ -224,7 +224,7 @@ func UpdateView(d *schema.ResourceData, meta interface{}) error {
if d.HasChange("name") {
newId := sdk.NewSchemaObjectIdentifierInSchema(id.SchemaId(), d.Get("name").(string))

err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithRenameTo(&newId))
err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithRenameTo(newId))
if err != nil {
return fmt.Errorf("error renaming view %v err = %w", d.Id(), err)
}
Expand All @@ -235,12 +235,12 @@ func UpdateView(d *schema.ResourceData, meta interface{}) error {

if d.HasChange("comment") {
if comment := d.Get("comment").(string); comment == "" {
err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithUnsetComment(sdk.Bool(true)))
err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithUnsetComment(true))
if err != nil {
return fmt.Errorf("error unsetting comment for view %v", d.Id())
}
} else {
err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithSetComment(sdk.String(comment)))
err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithSetComment(comment))
if err != nil {
return fmt.Errorf("error updating comment for view %v", d.Id())
}
Expand All @@ -249,12 +249,12 @@ func UpdateView(d *schema.ResourceData, meta interface{}) error {

if d.HasChange("is_secure") {
if d.Get("is_secure").(bool) {
err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithSetSecure(sdk.Bool(true)))
err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithSetSecure(true))
if err != nil {
return fmt.Errorf("error setting secure for view %v", d.Id())
}
} else {
err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithUnsetSecure(sdk.Bool(true)))
err := client.Views.Alter(ctx, sdk.NewAlterViewRequest(id).WithUnsetSecure(true))
if err != nil {
return fmt.Errorf("error unsetting secure for view %v", d.Id())
}
Expand Down
6 changes: 6 additions & 0 deletions pkg/sdk/common_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ type In struct {
Schema DatabaseObjectIdentifier `ddl:"identifier" sql:"SCHEMA"`
}

type ExtendedIn struct {
sfc-gh-asawicki marked this conversation as resolved.
Show resolved Hide resolved
In
Application AccountObjectIdentifier `ddl:"identifier" sql:"APPLICATION"`
ApplicationPackage AccountObjectIdentifier `ddl:"identifier" sql:"APPLICATION PACKAGE"`
}

type Like struct {
Pattern *string `ddl:"keyword,single_quotes"`
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/sdk/poc/generator/keyword_builders.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ func (v *QueryStruct) OptionalIn() *QueryStruct {
return v.PredefinedQueryStructField("In", "*In", KeywordOptions().SQL("IN"))
}

func (v *QueryStruct) OptionalExtendedIn() *QueryStruct {
return v.PredefinedQueryStructField("In", "*ExtendedIn", KeywordOptions().SQL("IN"))
}

func (v *QueryStruct) OptionalStartsWith() *QueryStruct {
return v.PredefinedQueryStructField("StartsWith", "*string", ParameterOptions().NoEquals().SingleQuotes().SQL("STARTS WITH"))
}
Expand Down
1 change: 1 addition & 0 deletions pkg/sdk/poc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"os"
"strings"

"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/genhelpers"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/poc/example"
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/poc/generator"
Expand Down
Loading
Loading