diff --git a/ppsci/arch/gan.py b/ppsci/arch/gan.py index 3daf311e6..9a3f4a8ad 100644 --- a/ppsci/arch/gan.py +++ b/ppsci/arch/gan.py @@ -182,6 +182,14 @@ class Generator(base.Arch): >>> use_bns_tuple = ((True, True, True), ) * 3 + ((False, False, False), ) >>> acts_tuple = (("relu", None, None), ) * 4 >>> model = ppsci.arch.Generator(("in",), ("out",), in_channel, out_channels_tuple, kernel_sizes_tuple, strides_tuple, use_bns_tuple, acts_tuple) + >>> batch_size = 4 + >>> height = 64 + >>> width = 64 + >>> input_data = paddle.randn([batch_size, in_channel, height, width]) + >>> input_dict = {'in': input_data} + >>> output_data = model(input_dict) + >>> print(output_data['out'].shape) + [4, 1, 64, 64] """ def __init__(