Adding this as a tracking issue to unblock #181 from landing.

per @wanchaol:

IMO we should also register the fwd/bwd RMSNorm kernels as a PyTorch custom op (see the sketch after this list), so that:

- it is compatible with PT2: I believe torch.compile currently graph-breaks on the FusedRMSNorm path
- other components (e.g. DTensor) can provide a sharding rule for the custom op, making it compatible with tensor parallelism
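For concreteness, here is a minimal sketch of what such a registration could look like with `torch.library.custom_op` (PyTorch 2.4+). The op name `titan::fused_rms_norm` is hypothetical, and plain eager math stands in for the actual fused fwd/bwd kernels, which the real registration would wrap instead:

```python
import torch

@torch.library.custom_op("titan::fused_rms_norm", mutates_args=())
def fused_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Stand-in for the fused forward kernel: y = x / rms(x) * weight.
    rstd = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x * rstd * weight

@fused_rms_norm.register_fake
def _(x, weight, eps):
    # Shape/dtype propagation so PT2 can trace the op with FakeTensors
    # instead of graph-breaking on it.
    return torch.empty_like(x)

def _setup_context(ctx, inputs, output):
    x, weight, eps = inputs
    ctx.save_for_backward(x, weight)
    ctx.eps = eps

def _backward(ctx, grad_out):
    # Stand-in for the fused backward kernel, derived from the forward math.
    x, weight = ctx.saved_tensors
    rstd = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + ctx.eps)
    x_hat = x * rstd
    # Gradient w.r.t. weight: reduce over all non-feature dims.
    grad_weight = (grad_out * x_hat).sum(dim=tuple(range(x.dim() - 1)))
    # Gradient w.r.t. x.
    g = grad_out * weight
    grad_x = (g - x_hat * (g * x_hat).mean(-1, keepdim=True)) * rstd
    return grad_x, grad_weight, None

fused_rms_norm.register_autograd(_backward, setup_context=_setup_context)
```

Once registered, the op is callable as `torch.ops.titan.fused_rms_norm`, torch.compile can treat it as an opaque call rather than graph-breaking, and DTensor can be taught a sharding rule for it (e.g. via the experimental `register_sharding` hook in newer PyTorch releases).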