Skip to content

Commit

Permalink
fix: clean up zod generation (zenstackhq#883)
Browse files Browse the repository at this point in the history
  • Loading branch information
ymc9 authored Dec 8, 2023
1 parent aa705a4 commit 9d4a8ed
Show file tree
Hide file tree
Showing 17 changed files with 114 additions and 117 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "zenstack-monorepo",
"version": "1.4.0",
"version": "1.4.1",
"description": "",
"scripts": {
"build": "pnpm -r build",
Expand Down
2 changes: 1 addition & 1 deletion packages/language/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@zenstackhq/language",
"version": "1.4.0",
"version": "1.4.1",
"displayName": "ZenStack modeling language compiler",
"description": "ZenStack modeling language compiler",
"homepage": "https://zenstack.dev",
Expand Down
2 changes: 1 addition & 1 deletion packages/plugins/openapi/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@zenstackhq/openapi",
"displayName": "ZenStack Plugin and Runtime for OpenAPI",
"version": "1.4.0",
"version": "1.4.1",
"description": "ZenStack plugin and runtime supporting OpenAPI",
"main": "index.js",
"repository": {
Expand Down
2 changes: 1 addition & 1 deletion packages/plugins/swr/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@zenstackhq/swr",
"displayName": "ZenStack plugin for generating SWR hooks",
"version": "1.4.0",
"version": "1.4.1",
"description": "ZenStack plugin for generating SWR hooks",
"main": "index.js",
"repository": {
Expand Down
2 changes: 1 addition & 1 deletion packages/plugins/tanstack-query/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@zenstackhq/tanstack-query",
"displayName": "ZenStack plugin for generating tanstack-query hooks",
"version": "1.4.0",
"version": "1.4.1",
"description": "ZenStack plugin for generating tanstack-query hooks",
"main": "index.js",
"exports": {
Expand Down
2 changes: 1 addition & 1 deletion packages/plugins/trpc/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@zenstackhq/trpc",
"displayName": "ZenStack plugin for tRPC",
"version": "1.4.0",
"version": "1.4.1",
"description": "ZenStack plugin for tRPC",
"main": "index.js",
"repository": {
Expand Down
2 changes: 1 addition & 1 deletion packages/runtime/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@zenstackhq/runtime",
"displayName": "ZenStack Runtime Library",
"version": "1.4.0",
"version": "1.4.1",
"description": "Runtime of ZenStack for both client-side and server-side environments.",
"repository": {
"type": "git",
Expand Down
2 changes: 1 addition & 1 deletion packages/runtime/src/enhancements/policy/policy-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1051,7 +1051,7 @@ export class PolicyUtil {
if (!this.hasFieldValidation(model)) {
return undefined;
}
const schemaKey = `${upperCaseFirst(model)}${kind ? upperCaseFirst(kind) : ''}Schema`;
const schemaKey = `${upperCaseFirst(model)}${kind ? 'Prisma' + upperCaseFirst(kind) : ''}Schema`;
return this.zodSchemas?.models?.[schemaKey];
}

Expand Down
2 changes: 1 addition & 1 deletion packages/schema/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"publisher": "zenstack",
"displayName": "ZenStack Language Tools",
"description": "Build scalable web apps with minimum code by defining authorization and validation rules inside the data schema that closer to the database",
"version": "1.4.0",
"version": "1.4.1",
"author": {
"name": "ZenStack Team"
},
Expand Down
148 changes: 57 additions & 91 deletions packages/schema/src/plugins/zod/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import {
PluginOptions,
createProject,
emitProject,
getAttribute,
getAttributeArg,
getDataModels,
getLiteral,
getPrismaClientImportSpec,
Expand All @@ -17,16 +15,7 @@ import {
resolvePath,
saveProject,
} from '@zenstackhq/sdk';
import {
DataModel,
DataModelField,
DataSource,
EnumField,
Model,
isDataModel,
isDataSource,
isEnum,
} from '@zenstackhq/sdk/ast';
import { DataModel, DataSource, EnumField, Model, isDataModel, isDataSource, isEnum } from '@zenstackhq/sdk/ast';
import { addMissingInputObjectTypes, resolveAggregateOperationSupport } from '@zenstackhq/sdk/dmmf-helpers';
import { promises as fs } from 'fs';
import { streamAllContents } from 'langium';
Expand Down Expand Up @@ -271,18 +260,14 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
overwrite: true,
});
sf.replaceWithText((writer) => {
const fields = model.fields.filter(
const scalarFields = model.fields.filter(
(field) =>
// regular fields only
!isDataModel(field.type.reference?.ref) && !isForeignKeyField(field)
);

const relations = model.fields.filter((field) => isDataModel(field.type.reference?.ref));
const fkFields = model.fields.filter((field) => isForeignKeyField(field));
// unsafe version of relations: including foreign keys and relation fields without fk
const unsafeRelations = model.fields.filter(
(field) => isForeignKeyField(field) || (isDataModel(field.type.reference?.ref) && !hasForeignKey(field))
);

writer.writeLine('/* eslint-disable */');
writer.writeLine(`import { z } from 'zod';`);
Expand All @@ -304,7 +289,7 @@ async function generateModelSchema(model: DataModel, project: Project, output: s

// import enum schemas
const importedEnumSchemas = new Set<string>();
for (const field of fields) {
for (const field of scalarFields) {
if (field.type.reference?.ref && isEnum(field.type.reference?.ref)) {
const name = upperCaseFirst(field.type.reference?.ref.name);
if (!importedEnumSchemas.has(name)) {
Expand All @@ -315,29 +300,28 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
}

// import Decimal
if (fields.some((field) => field.type.type === 'Decimal')) {
if (scalarFields.some((field) => field.type.type === 'Decimal')) {
writer.writeLine(`import { DecimalSchema } from '../common';`);
writer.writeLine(`import { Decimal } from 'decimal.js';`);
}

// base schema
writer.write(`const baseSchema = z.object(`);
writer.inlineBlock(() => {
fields.forEach((field) => {
scalarFields.forEach((field) => {
writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`);
});
});
writer.writeLine(');');

// relation fields

let allRelationSchema: string | undefined;
let safeRelationSchema: string | undefined;
let unsafeRelationSchema: string | undefined;
let relationSchema: string | undefined;
let fkSchema: string | undefined;

if (relations.length > 0 || fkFields.length > 0) {
allRelationSchema = 'allRelationSchema';
writer.write(`const ${allRelationSchema} = z.object(`);
relationSchema = 'relationSchema';
writer.write(`const ${relationSchema} = z.object(`);
writer.inlineBlock(() => {
[...relations, ...fkFields].forEach((field) => {
writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`);
Expand All @@ -346,23 +330,12 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
writer.writeLine(');');
}

if (relations.length > 0) {
safeRelationSchema = 'safeRelationSchema';
writer.write(`const ${safeRelationSchema} = z.object(`);
if (fkFields.length > 0) {
fkSchema = 'fkSchema';
writer.write(`const ${fkSchema} = z.object(`);
writer.inlineBlock(() => {
relations.forEach((field) => {
writer.writeLine(`${field.name}: ${makeFieldSchema(field, true)},`);
});
});
writer.writeLine(');');
}

if (unsafeRelations.length > 0) {
unsafeRelationSchema = 'unsafeRelationSchema';
writer.write(`const ${unsafeRelationSchema} = z.object(`);
writer.inlineBlock(() => {
unsafeRelations.forEach((field) => {
writer.writeLine(`${field.name}: ${makeFieldSchema(field, true)},`);
fkFields.forEach((field) => {
writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`);
});
});
writer.writeLine(');');
Expand All @@ -383,25 +356,25 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
////////////////////////////////////////////////
// 1. Model schema
////////////////////////////////////////////////
let modelSchema = 'baseSchema';
let modelSchema = makePartial('baseSchema');

// omit fields
const fieldsToOmit = fields.filter((field) => hasAttribute(field, '@omit'));
const fieldsToOmit = scalarFields.filter((field) => hasAttribute(field, '@omit'));
if (fieldsToOmit.length > 0) {
modelSchema = makeOmit(
modelSchema,
fieldsToOmit.map((f) => f.name)
);
}

if (allRelationSchema) {
if (relationSchema) {
// export schema with only scalar fields
const modelScalarSchema = `${upperCaseFirst(model.name)}ScalarSchema`;
writer.writeLine(`export const ${modelScalarSchema} = ${modelSchema};`);
modelSchema = modelScalarSchema;

// merge relations
modelSchema = makeMerge(modelSchema, allRelationSchema);
modelSchema = makeMerge(modelSchema, makePartial(relationSchema));
}

// refine
Expand All @@ -413,10 +386,40 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
writer.writeLine(`export const ${upperCaseFirst(model.name)}Schema = ${modelSchema};`);

////////////////////////////////////////////////
// 2. Create schema
// 2. Prisma create & update
////////////////////////////////////////////////

// schema for validating prisma create input (all fields optional)
let prismaCreateSchema = makePartial('baseSchema');
if (refineFuncName) {
prismaCreateSchema = `${refineFuncName}(${prismaCreateSchema})`;
}
writer.writeLine(`export const ${upperCaseFirst(model.name)}PrismaCreateSchema = ${prismaCreateSchema};`);

// schema for validating prisma update input (all fields optional)
// note numeric fields can be simple update or atomic operations
let prismaUpdateSchema = `z.object({
${scalarFields
.map((field) => {
let fieldSchema = makeFieldSchema(field);
if (field.type.type === 'Int' || field.type.type === 'Float') {
fieldSchema = `z.union([${fieldSchema}, z.record(z.unknown())])`;
}
return `\t${field.name}: ${fieldSchema}`;
})
.join(',\n')}
})`;
prismaUpdateSchema = makePartial(prismaUpdateSchema);
if (refineFuncName) {
prismaUpdateSchema = `${refineFuncName}(${prismaUpdateSchema})`;
}
writer.writeLine(`export const ${upperCaseFirst(model.name)}PrismaUpdateSchema = ${prismaUpdateSchema};`);

////////////////////////////////////////////////
// 3. Create schema
////////////////////////////////////////////////
let createSchema = 'baseSchema';
const fieldsWithDefault = fields.filter(
const fieldsWithDefault = scalarFields.filter(
(field) => hasAttribute(field, '@default') || hasAttribute(field, '@updatedAt') || field.type.array
);
if (fieldsWithDefault.length > 0) {
Expand All @@ -426,30 +429,13 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
);
}

if (safeRelationSchema || unsafeRelationSchema) {
if (fkSchema) {
// export schema with only scalar fields
const createScalarSchema = `${upperCaseFirst(model.name)}CreateScalarSchema`;
writer.writeLine(`export const ${createScalarSchema} = ${createSchema};`);
createSchema = createScalarSchema;

if (safeRelationSchema && unsafeRelationSchema) {
// build a union of with relation object fields and with fk fields (mutually exclusive)

// TODO: we make all relation fields partial for now because in case of
// nested create, not all relation/fk fields are inside payload, need a
// better solution
createSchema = makeUnion(
makeMerge(createSchema, makePartial(safeRelationSchema)),
makeMerge(createSchema, makePartial(unsafeRelationSchema))
);
} else if (safeRelationSchema) {
// just relation

// TODO: we make all relation fields partial for now because in case of
// nested create, not all relation/fk fields are inside payload, need a
// better solution
createSchema = makeMerge(createSchema, makePartial(safeRelationSchema));
}

// merge fk fields
createSchema = makeMerge(createScalarSchema, fkSchema);
}

if (refineFuncName) {
Expand All @@ -465,22 +451,14 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
////////////////////////////////////////////////
let updateSchema = makePartial('baseSchema');

if (safeRelationSchema || unsafeRelationSchema) {
if (fkSchema) {
// export schema with only scalar fields
const updateScalarSchema = `${upperCaseFirst(model.name)}UpdateScalarSchema`;
writer.writeLine(`export const ${updateScalarSchema} = ${updateSchema};`);
updateSchema = updateScalarSchema;

if (safeRelationSchema && unsafeRelationSchema) {
// build a union of with relation object fields and with fk fields (mutually exclusive)
updateSchema = makeUnion(
makeMerge(updateSchema, makePartial(safeRelationSchema)),
makeMerge(updateSchema, makePartial(unsafeRelationSchema))
);
} else if (safeRelationSchema) {
// just relation
updateSchema = makeMerge(updateSchema, makePartial(safeRelationSchema));
}
// merge fk fields
updateSchema = makeMerge(updateSchema, makePartial(fkSchema));
}

if (refineFuncName) {
Expand Down Expand Up @@ -514,15 +492,3 @@ function makeOmit(schema: string, fields: string[]) {
function makeMerge(schema1: string, schema2: string): string {
return `${schema1}.merge(${schema2})`;
}

function makeUnion(...schemas: string[]): string {
return `z.union([${schemas.join(', ')}])`;
}

function hasForeignKey(field: DataModelField) {
const relAttr = getAttribute(field, '@relation');
if (!relAttr) {
return false;
}
return !!getAttributeArg(relAttr, 'fields');
}
17 changes: 5 additions & 12 deletions packages/schema/src/plugins/zod/utils/schema-gen.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,13 @@ import {
TypeScriptExpressionTransformerError,
} from '../../../utils/typescript-expression-transformer';

export function makeFieldSchema(field: DataModelField, forMutation = false) {
export function makeFieldSchema(field: DataModelField) {
if (isDataModel(field.type.reference?.ref)) {
if (!forMutation) {
// read schema, always optional
if (field.type.array) {
return `z.array(z.unknown()).optional()`;
} else {
return `z.record(z.unknown()).optional()`;
}
if (field.type.array) {
// array field is always optional
return `z.array(z.unknown()).optional()`;
} else {
// write schema
return `${
field.type.optional || field.type.array ? 'z.record(z.unknown()).optional()' : 'z.record(z.unknown())'
}`;
return field.type.optional ? `z.record(z.unknown()).optional()` : `z.record(z.unknown())`;
}
}

Expand Down
2 changes: 1 addition & 1 deletion packages/sdk/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@zenstackhq/sdk",
"version": "1.4.0",
"version": "1.4.1",
"description": "ZenStack plugin development SDK",
"main": "index.js",
"scripts": {
Expand Down
2 changes: 1 addition & 1 deletion packages/server/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@zenstackhq/server",
"version": "1.4.0",
"version": "1.4.1",
"displayName": "ZenStack Server-side Adapters",
"description": "ZenStack server-side adapters",
"homepage": "https://zenstack.dev",
Expand Down
2 changes: 1 addition & 1 deletion packages/testtools/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@zenstackhq/testtools",
"version": "1.4.0",
"version": "1.4.1",
"description": "ZenStack Test Tools",
"main": "index.js",
"private": true,
Expand Down
Loading

0 comments on commit 9d4a8ed

Please sign in to comment.