Save Model Inputs, Model Outputs, Gradients, Custom Tensors, Layer Inputs, Layer Outputs #282
Merged
Commits
155 commits
be4f48a
save outputs
NihalHarish d32d017
assert updates
NihalHarish 8e95f12
update assert
NihalHarish 48f45d6
cleanup
NihalHarish 55f10d4
as_dtype:
NihalHarish ec82021
model outputs are now constants
NihalHarish 666bcd4
update to test
NihalHarish d867a9b
update import statement
NihalHarish 5fd3a74
tmp
NihalHarish 11c20c6
Revert "tmp"
NihalHarish 7f260e5
str_to_mode
NihalHarish 345f785
add tensor
NihalHarish beaa68d
add tensor
NihalHarish ab3d5c1
add dist tensor:
NihalHarish 61372e8
add tensor
NihalHarish 46c5e0f
for-loop
NihalHarish 650fd6a
fix append
NihalHarish 42fdc3a
fix assert
NihalHarish 16b38d1
add
NihalHarish 9e1d2c5
model output
NihalHarish 14d911b
rename
NihalHarish 20d0413
add to all collections
NihalHarish d46ebb6
revert
NihalHarish 960d383
add to all
NihalHarish 67f4efc
helper fn
NihalHarish 2df341e
helper fn
NihalHarish 94765d2
extend returns none
NihalHarish 9eff79b
ypred
NihalHarish 61d94e1
ypred
NihalHarish 07d72d3
change assert
NihalHarish d8a8ea9
init
NihalHarish f7ead88
do not match in metric
NihalHarish 6e24ca8
update
NihalHarish cda4e3e
inputs
NihalHarish 9b59d0d
id
NihalHarish 9e5606e
save outputs
NihalHarish 11ddcdd
assert updates
NihalHarish 34d2294
update assert
NihalHarish f87ce01
cleanup
NihalHarish bbb0dc6
as_dtype:
NihalHarish 82f0531
model outputs are now constants
NihalHarish 4663370
update to test
NihalHarish c64a7a1
update import statement
NihalHarish 15c1d61
tmp
NihalHarish be6186f
Revert "tmp"
NihalHarish ae8f96b
str_to_mode
NihalHarish 30bd425
add tensor
NihalHarish 1e7aa1b
add tensor
NihalHarish 85ea95a
add dist tensor:
NihalHarish 95b8bcc
add tensor
NihalHarish 07fd399
for-loop
NihalHarish 7151978
fix append
NihalHarish 72a7256
fix assert
NihalHarish 046d165
add
NihalHarish 070cd6f
model output
NihalHarish 8af4ce8
rename
NihalHarish 1761ca2
add to all collections
NihalHarish 6b581bf
revert
NihalHarish 6b14ee7
add to all
NihalHarish 5c89dff
helper fn
NihalHarish cc13566
helper fn
NihalHarish d07dd47
extend returns none
NihalHarish 766902a
ypred
NihalHarish 4e1b802
ypred
NihalHarish 5782846
change assert
NihalHarish f745186
Merge branch 'y_pred' of https://github.com/awslabs/sagemaker-debugge…
NihalHarish 07c6e75
init
NihalHarish 0d8c6cb
do not match in metric
NihalHarish ae526c0
update
NihalHarish bf82f9c
inputs
NihalHarish 101fcb2
id
NihalHarish cdaf7f8
Merge branch 'save_model_inputs' of https://github.com/awslabs/sagema…
NihalHarish bc84269
test
NihalHarish 5091415
fuse model inputs and outputs
NihalHarish 13ce988
set fix
NihalHarish 460e0e0
add tests
NihalHarish c20cc75
update test
NihalHarish 5766aa2
eager mode
NihalHarish 0428d62
update tests
NihalHarish 54ad7a5
rename fn
NihalHarish 40ded77
remove unused imports
NihalHarish 9ead6fa
save custom tensor fn
NihalHarish c9a6198
test_
NihalHarish 7c7fbb3
revert tests
NihalHarish ab8d103
save custom tensor fn
NihalHarish 63babf7
test_
NihalHarish 9633e2e
save custom tensor
NihalHarish a997bfa
save custom tensor
NihalHarish 1376045
init
NihalHarish 05b28c5
save gradients
NihalHarish 9ae86df
ignore smdebug metrics
NihalHarish c8a0844
update assert
NihalHarish 3db6856
gradients
NihalHarish 32affd2
save inputs
NihalHarish 582cd6e
merge master
NihalHarish ccde310
checks
NihalHarish 4e14182
change assert
NihalHarish a68dc3e
check if collection should be saved
NihalHarish 712f94b
set
NihalHarish cdb0882
revert assert
NihalHarish c692d8f
revert assert
NihalHarish cac439d
save inputs
NihalHarish cd36430
change regex
NihalHarish 60d671b
modify tests
NihalHarish 73b5362
collection
NihalHarish abdc64b
save fn
NihalHarish 027b022
move test
NihalHarish 6c5e4c9
run only for tf2
NihalHarish 29e1319
mark skip
NihalHarish 9e9092b
fn rename
NihalHarish e97de64
rename fn
NihalHarish cec3e09
correct boolean logic
NihalHarish 90a8f23
fix input output logic
NihalHarish 06ebf84
comments
NihalHarish 15851de
grad tape example
NihalHarish 41ca695
save layers
NihalHarish af1e411
rename
NihalHarish 8cdd13e
change boolean logic
NihalHarish 03e4f18
bug fix
NihalHarish 2660a76
retrigger CI
NihalHarish fccf7e8
fix flag
NihalHarish f221f74
duplicate set
NihalHarish 480db00
pred
NihalHarish c0817b9
nit
NihalHarish cb79e19
Merge remote-tracking branch 'origin' into save_inputs
NihalHarish 80a65c7
update
NihalHarish e7cb92a
rename default collection
NihalHarish 39b65df
model inputs
NihalHarish ca68f77
lint
NihalHarish 281011d
update tests
NihalHarish 74de9c9
modify assert
NihalHarish 6dd95d7
Merge remote-tracking branch 'origin' into save_inputs
NihalHarish 9abe494
modify assert
NihalHarish 33c21c0
save Layers
NihalHarish 7bd87c8
clear saved collections after saving
NihalHarish 651d6ea
refactor
NihalHarish 1aaabe7
nit
NihalHarish 6d3b733
pr comments
NihalHarish 0f08773
save tensor api
NihalHarish 3015cae
revert typo
NihalHarish cca7fea
save custom tensors
NihalHarish bbf1bf6
pr comments
NihalHarish a32a8d4
len
NihalHarish 259414a
default
NihalHarish b1ad7a0
save smdebug logs
NihalHarish d3b54c3
comments
NihalHarish fb548a9
update
NihalHarish 7ca2942
constants
NihalHarish 2df55e0
Implement Save Tensor For Mxnet and Pytorch (#291)
NihalHarish 067e724
parameterize test keras fit
NihalHarish 49550e8
tf eager
NihalHarish cb44a7d
nit
NihalHarish b67fa45
nit and remove duped fn
NihalHarish 075b2a0
refactor
NihalHarish 1a1838e
retrigger CI
NihalHarish
"""
This file is temporary, for testing with 2.X.
We'll need to integrate a more robust testing pipeline and make this part of pytest
before pushing to master.

This was tested with TensorFlow 2.1, by running
`python tests/tensorflow2/test_keras.py` from the main directory.
"""
# Standard Library
import shutil

# Third Party
import pytest
import tensorflow.compat.v2 as tf

# First Party
import smdebug.tensorflow as smd
from smdebug.core.collection import CollectionKeys
from smdebug.tensorflow import SaveConfig


@pytest.fixture(scope="function")
def out_dir():
    """Use this method to construct an out_dir.

    Then it will be automatically cleaned up for you, passed into the test method, and we'll have
    fewer folders lying around.
    """
    out_dir = "/tmp/test"
    shutil.rmtree(out_dir, ignore_errors=True)
    return out_dir


def helper_keras_fit(
    trial_dir,
    save_all=False,
    include_collections=None,
    reduction_config=None,
    save_config=None,
    hook=None,
    steps=None,
    add_callbacks=None,
    run_eagerly=False,
):
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255, x_test / 255

    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(128, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10, activation="softmax"),
        ]
    )

    if hook is None:
        if save_config is None:
            save_config = SaveConfig(save_interval=3)

        hook = smd.KerasHook(
            trial_dir,
            save_config=save_config,
            save_all=save_all,
            include_collections=include_collections,
            reduction_config=reduction_config,
        )

        if not save_all and include_collections is not None:
            for cname in hook.include_collections:
                if cname not in include_collections:
                    hook.get_collection(cname).save_config = SaveConfig(end_step=0)

    opt = tf.keras.optimizers.Adam()

    opt = hook.wrap_optimizer(opt)
    model.compile(
        optimizer=opt,
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
        run_eagerly=run_eagerly,
    )
    hooks = []
    if add_callbacks:
        if "tensorboard" in add_callbacks:
            hooks.append(
                tf.keras.callbacks.TensorBoard(
                    log_dir="/tmp/logs", histogram_freq=1, write_grads=True, write_images=True
                )
            )
    hooks.append(hook)

    if steps is None:
        steps = ["train"]
    for step in steps:
        if step == "train":
            model.fit(x_train, y_train, epochs=1, steps_per_epoch=10, callbacks=hooks, verbose=0)
        elif step == "eval":
            model.evaluate(x_test, y_test, steps=10, callbacks=hooks, verbose=0)
        elif step == "predict":
            model.predict(x_test[:100], callbacks=hooks, verbose=0)

    hook.close()


def test_keras_fit_eager(out_dir, tf_eager_mode=True):
    test_include_collections = [
        CollectionKeys.LOSSES,
        CollectionKeys.METRICS,
        CollectionKeys.WEIGHTS,
        CollectionKeys.BIASES,
        CollectionKeys.GRADIENTS,
        CollectionKeys.INPUTS,
        CollectionKeys.OUTPUTS,
        CollectionKeys.LAYERS,
        CollectionKeys.OPTIMIZER_VARIABLES,
    ]
    hook = smd.KerasHook(out_dir=out_dir, include_collections=test_include_collections)
    helper_keras_fit(
        include_collections=test_include_collections,
        trial_dir=out_dir,
        hook=hook,
        run_eagerly=tf_eager_mode,
        steps=["train", "eval", "predict", "train"],
    )
    trial = smd.create_trial(path=out_dir)

    # We first assert that none of the collections we requested are empty
    assert len(trial.tensor_names(collection=CollectionKeys.LOSSES)) == 1
    assert len(trial.tensor_names(collection=CollectionKeys.METRICS)) == 2
    assert len(trial.tensor_names(collection=CollectionKeys.WEIGHTS)) == 2
    assert len(trial.tensor_names(collection=CollectionKeys.BIASES)) == 2
    assert len(trial.tensor_names(collection=CollectionKeys.GRADIENTS)) == 4
    assert len(trial.tensor_names(collection=CollectionKeys.INPUTS)) == 1  # 1 Model Input
    assert len(trial.tensor_names(collection=CollectionKeys.OUTPUTS)) == 2  # 2 Model outputs
    assert len(trial.tensor_names(collection=CollectionKeys.OPTIMIZER_VARIABLES)) == 5

    # We assert that all the tensors saved have a valid value
    for tname in trial.tensor_names():
        assert trial.tensor(tname).value(0) is not None

    # We then analyse Layer Inputs and Layer Outputs
    # Check that the output of each layer is equal to the input of the next
    boolean_matrix = trial.tensor("flatten/outputs").value(0) == trial.tensor("dense/inputs").value(0)
    assert boolean_matrix.all()
    boolean_matrix = trial.tensor("dense/outputs").value(0) == trial.tensor("dropout/inputs").value(0)
    assert boolean_matrix.all()
    boolean_matrix = trial.tensor("dropout/outputs").value(0) == trial.tensor("dense_1/inputs").value(0)
    assert boolean_matrix.all()


def test_keras_fit_false(out_dir, tf_eager_mode=False):
    test_include_collections = [
        CollectionKeys.LOSSES,
        CollectionKeys.METRICS,
        CollectionKeys.WEIGHTS,
        CollectionKeys.BIASES,
        CollectionKeys.GRADIENTS,
        CollectionKeys.INPUTS,
        CollectionKeys.OUTPUTS,
        CollectionKeys.LAYERS,
        CollectionKeys.OPTIMIZER_VARIABLES,
    ]
    hook = smd.KerasHook(out_dir=out_dir, include_collections=test_include_collections)
    helper_keras_fit(
        include_collections=test_include_collections,
        trial_dir=out_dir,
        hook=hook,
        run_eagerly=tf_eager_mode,
        steps=["train", "eval", "predict", "train"],
    )
    trial = smd.create_trial(path=out_dir)

    # We first assert that none of the collections we requested are empty
    assert len(trial.tensor_names(collection=CollectionKeys.LOSSES)) == 1
    assert len(trial.tensor_names(collection=CollectionKeys.METRICS)) == 2
    assert len(trial.tensor_names(collection=CollectionKeys.WEIGHTS)) == 2
    assert len(trial.tensor_names(collection=CollectionKeys.BIASES)) == 2
    assert len(trial.tensor_names(collection=CollectionKeys.GRADIENTS)) == 4
    assert len(trial.tensor_names(collection=CollectionKeys.INPUTS)) == 1  # 1 Model Input
    assert len(trial.tensor_names(collection=CollectionKeys.OUTPUTS)) == 2  # 2 Model outputs
    assert len(trial.tensor_names(collection=CollectionKeys.OPTIMIZER_VARIABLES)) == 5

    # We assert that all the tensors saved have a valid value
    for tname in trial.tensor_names():
        assert trial.tensor(tname).value(0) is not None

    # We then analyse Layer Inputs and Layer Outputs
    # Check that the output of each layer is equal to the input of the next
    boolean_matrix = trial.tensor("flatten_1/outputs").value(0) == trial.tensor("dense_2/inputs").value(0)
    assert boolean_matrix.all()
    boolean_matrix = trial.tensor("dense_2/outputs").value(0) == trial.tensor("dropout_1/inputs").value(0)
    assert boolean_matrix.all()
    boolean_matrix = trial.tensor("dropout_1/outputs").value(0) == trial.tensor("dense_3/inputs").value(0)
    assert boolean_matrix.all()
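The layer-wiring assertions above compare each layer's saved output tensor with the next layer's saved input tensor. The same consistency check can be sketched without TensorFlow or smdebug; this is a minimal illustration of the logic only, with tensor values mocked as nested lists and the helper name `check_layer_chain` being hypothetical, not part of the library:

```python
# Mocked saved tensors, keyed the way the tests above name them
# ("<layer>/outputs", "<layer>/inputs"); values stand in for tensor contents.
saved = {
    "flatten/outputs": [[0.1, 0.2], [0.3, 0.4]],
    "dense/inputs": [[0.1, 0.2], [0.3, 0.4]],
    "dense/outputs": [[0.5], [0.6]],
    "dropout/inputs": [[0.5], [0.6]],
}


def check_layer_chain(tensors, chain):
    """Return True iff each layer's output equals the next layer's input."""
    for producer, consumer in zip(chain, chain[1:]):
        if tensors[f"{producer}/outputs"] != tensors[f"{consumer}/inputs"]:
            return False
    return True


print(check_layer_chain(saved, ["flatten", "dense", "dropout"]))  # → True
```

In the real tests this equality is checked elementwise on NumPy arrays (`boolean_matrix.all()`); the sketch uses list equality, which is the same idea for exact matches.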
examples/tensorflow2/scripts/tf_save_metrics_gradient_tape.py (161 additions, 0 deletions)
"""
This file is temporary, for testing with 2.X.
We'll need to integrate a more robust testing pipeline and make this part of pytest
before pushing to master.
"""
# Standard Library
import shutil

# Third Party
import pytest
import tensorflow.compat.v2 as tf

# First Party
import smdebug.tensorflow as smd
from smdebug.core.collection import CollectionKeys
from smdebug.tensorflow import SaveConfig


@pytest.fixture(scope="function")
def out_dir():
    """Use this method to construct an out_dir.

    Then it will be automatically cleaned up for you, passed into the test method, and we'll have
    fewer folders lying around.
    """
    out_dir = "/tmp/test"
    shutil.rmtree(out_dir, ignore_errors=True)
    return out_dir


def helper_keras_gradtape(
    trial_dir,
    save_all=False,
    include_collections=None,
    reduction_config=None,
    save_config=None,
    hook=None,
    batch_size=64,
    persistent=False,
):
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), _ = mnist.load_data()
    dataset = tf.data.Dataset.from_tensor_slices(
        (tf.cast(x_train[..., tf.newaxis] / 255, tf.float32), tf.cast(y_train, tf.int64))
    )
    dataset = dataset.shuffle(1000).batch(batch_size)

    model = tf.keras.models.Sequential(
        [
            # WA for TF issue https://github.com/tensorflow/tensorflow/issues/36279
            tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
            tf.keras.layers.Dense(128, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10, activation="softmax"),
        ]
    )

    if hook is None:
        if save_config is None:
            save_config = SaveConfig(save_interval=3)

        hook = smd.KerasHook(
            trial_dir,
            save_config=save_config,
            save_all=save_all,
            include_collections=include_collections,
            reduction_config=reduction_config,
        )

        if not save_all and include_collections is not None:
            for cname in hook.include_collections:
                if cname not in include_collections:
                    hook.get_collection(cname).save_config = SaveConfig(end_step=0)

    opt = tf.keras.optimizers.Adam()
    hook.wrap_optimizer(opt)
    hook.register_model(model)  # Can be skipped in ZCC

    cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()

    n_epochs = 1
    for epoch in range(n_epochs):
        for data, labels in dataset:
            dataset_labels = labels
            labels = tf.one_hot(labels, depth=10)
            with hook.wrap_tape(tf.GradientTape(persistent=persistent)) as tape:
                logits = model(data, training=True)
                loss_value = cce(labels, logits)
                hook.save_tensor("y_labels", labels, "outputs")
            grads = tape.gradient(loss_value, model.variables)

            # By default, the resources held by a GradientTape are released as
            # soon as the GradientTape.gradient() method is called. To compute
            # multiple gradients over the same computation, create a persistent
            # gradient tape. This allows multiple calls to the gradient() method,
            # as resources are released only when the tape object is garbage collected.
            if persistent:
                _ = tape.gradient(loss_value, model.variables)
            opt.apply_gradients(zip(grads, model.variables))
            acc = train_acc_metric(dataset_labels, logits)
            hook.save_tensor(
                tensor_name="accuracy",
                tensor_value=acc,
                collections_to_write=CollectionKeys.METRICS,
            )
            train_acc_metric.reset_states()

    hook.close()


def test_keras_gradtape(out_dir):
    """
    Test save all and save default collection
    """
    include_collections = [
        CollectionKeys.WEIGHTS,
        CollectionKeys.BIASES,
        CollectionKeys.GRADIENTS,
        CollectionKeys.LAYERS,
        CollectionKeys.LOSSES,
        CollectionKeys.INPUTS,
        CollectionKeys.OUTPUTS,
        CollectionKeys.METRICS,
        CollectionKeys.OPTIMIZER_VARIABLES,
    ]
    hook = smd.KerasHook(
        out_dir=out_dir,
        save_config=SaveConfig(save_interval=1),
        include_collections=include_collections,
    )
    helper_keras_gradtape(trial_dir=out_dir, hook=hook)

    trial = smd.create_trial(path=out_dir)
    assert len(trial.tensor_names(collection=CollectionKeys.BIASES)) == 2
    assert len(trial.tensor_names(collection=CollectionKeys.WEIGHTS)) == 2
    assert len(trial.tensor_names(collection=CollectionKeys.OPTIMIZER_VARIABLES)) == 5
    assert len(trial.tensor_names(collection=CollectionKeys.LAYERS)) == 8
    assert len(trial.tensor_names(collection=CollectionKeys.OUTPUTS)) == 2
    assert len(trial.tensor_names(collection=CollectionKeys.INPUTS)) == 1
    assert len(trial.tensor_names(collection=CollectionKeys.LOSSES)) == 1
    assert len(trial.tensor_names(collection=CollectionKeys.METRICS)) == 1

    # We assert that all the tensors saved have a valid value
    for tname in trial.tensor_names():
        assert trial.tensor(tname).value(5) is not None

    # We then analyse Layer Inputs and Layer Outputs
    # Check that the output of a layer is equal to the input of the next
    boolean_matrix = trial.tensor("flatten/outputs").value(0) == trial.tensor("dense/inputs").value(0)
    assert boolean_matrix.all()
    boolean_matrix = trial.tensor("dense/outputs").value(0) == trial.tensor("dropout/inputs").value(0)
    assert boolean_matrix.all()
    boolean_matrix = trial.tensor("dropout/outputs").value(0) == trial.tensor("dense_1/inputs").value(0)
    assert boolean_matrix.all()
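The script above saves custom tensors with `hook.save_tensor(tensor_name, tensor_value, collections_to_write)`, routing a value into one or more named collections. A toy, smdebug-free mock of that routing behavior can make the call shape concrete; the `MockHook` class and its storage scheme are illustrative assumptions, not the library's actual implementation:

```python
# Toy stand-in for the save_tensor routing used above: a value is recorded in
# every requested collection, and a bare string names a single collection.
class MockHook:
    def __init__(self):
        self.collections = {}  # collection name -> {tensor name: value}

    def save_tensor(self, tensor_name, tensor_value, collections_to_write="default"):
        if isinstance(collections_to_write, str):
            collections_to_write = [collections_to_write]
        for cname in collections_to_write:
            self.collections.setdefault(cname, {})[tensor_name] = tensor_value


hook = MockHook()
hook.save_tensor("y_labels", [1, 0, 1], "outputs")
hook.save_tensor("accuracy", 0.92, ["metrics", "outputs"])
print(sorted(hook.collections))  # → ['metrics', 'outputs']
```

This mirrors why the tests above can assert per-collection tensor counts: each saved tensor appears under exactly the collections it was written to.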
Review comment: so these files will be deleted just before merge?

Reply: These files are currently run on the AWS TF test pipeline. They will be either modified or deleted after merging.

Review comment: Create an issue to follow up on the testing after the PR is approved.