-
Notifications
You must be signed in to change notification settings - Fork 80
/
__init__.py
5721 lines (4511 loc) · 201 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations
import builtins
import itertools
import math
import operator
import collections
import re
from collections.abc import Sequence
from enum import Enum
from functools import partial, reduce, wraps
from numbers import Number
from typing import Any, overload
from types import NoneType, ModuleType
from collections.abc import Callable
import opt_einsum
# Initializes the language context
from thunder.torch.langctx import register_method, register_property
from thunder.core.baseutils import run_once
import thunder.clang as clang
import thunder.core.devices as devices
from thunder.core.devices import to_device, device_from_string
import thunder.core.dtypes as dtypes
from thunder.core.dtypes import to_torch_dtype, to_dtype, _thunder_to_torch_dtype_map, _torch_to_thunder_dtype_map
import thunder.core.prims as prims
import thunder.core.utils as utils
import thunder.distributed.prims as dist_prims
from thunder.core.langctxs import langctx, Languages, get_langctx
from thunder.core.proxies import (
FloatProxy,
IntegerProxy,
NumberProxy,
NumberLike,
TensorProxy,
FutureTensorProxy,
pyval,
TupleProxy,
ListProxy,
DictProxy,
numberproxy,
)
from thunder.core.pytree import tree_map, tree_flatten, tree_unflatten
from thunder.core.symbol import Symbol
from thunder.core.transforms import register_grad
from thunder.core.prims import get_grad, put_grad
from thunder.core.baseutils import run_once
import thunder
from thunder.torch.default_torch_ops import _auto_registered_operators_returning_views
# Names exported by a star-import of this module.
# NOTE(review): `is_available` is not defined in the visible part of this file —
# confirm that it is defined elsewhere in the module.
__all__ = [
"is_available",
]
# NOTE torch is a requirement
import torch
import warnings
# Type annotation helpers
# These aliases are used in the signatures of the operations defined below.
# TensorLike and FutureTensorLike alias the proxy classes imported from thunder.core.proxies.
TensorLike = TensorProxy
FutureTensorLike = FutureTensorProxy
# A device may be written as a string, a thunder Device, or a torch.device
DeviceLike = str | devices.Device | torch.device
# A dtype may be written as a thunder dtype or a torch dtype
dtypeLike = dtypes.dtype | torch.dtype
# TODO RC1 Remove this map
# NOTE(review): despite the TODO calling it a "map", this is a set of helper
# functions from torch.nn.modules.utils; its consumers are not visible here.
_torch_noinline_functions = {
torch.nn.modules.utils._single,
torch.nn.modules.utils._pair,
torch.nn.modules.utils._triple,
torch.nn.modules.utils._quadruple,
}
# Maps torch functions, like torch.foo, to their corresponding thunder.torch functions
# NOTE This is defined here and populated as functions are defined below
#   (by the torchsymbol decorator and by register_function)
_torch_to_thunder_function_map: dict[Callable, Callable] = {}
#
# torch operation definitions
#
# in-place sym -> out-of-place (= functional) sym with index of `inplace` argument
# NOTE The torchsymbol decorator fills this in for symbols tagged with prims.OpTags.IN_PLACE,
#   always storing -1 as the index
_inplace_to_out_of_place: dict[Callable, tuple[Callable, int]] = {}
# Helpers for factory functions to get default dtype and device.
def get_default_dtype():
    """Return torch's current default dtype after checking it against the jit cache info.

    `thunder.jit` creates cache info and stashes the default dtype observed at
    the beginning of jitting. If the live torch default no longer matches the
    stashed value, `utils.check` fails.
    """
    stashed_default = thunder._get_cache_info()["default_dtype"]
    live_default = torch.get_default_dtype()

    # Currently, changing dtype during the jitted function is unsupported.
    utils.check(
        stashed_default == live_default,
        lambda: "Default dtype is changed during the execution of jitted function. This is currently unsupported.",
    )
    return live_default
def maybe_get_default_dtype(dtype):
    """Return `dtype` when it is truthy, otherwise the result of `get_default_dtype()`."""
    if dtype:
        return dtype
    return get_default_dtype()
def get_default_device():
    """Return torch's current default device after checking it against the jit cache info.

    `thunder.jit` creates cache info and stashes the default device observed at
    the beginning of jitting. If the live torch default no longer matches the
    stashed value, `utils.check` fails.
    """
    stashed_default = thunder._get_cache_info()["default_device"]
    live_default = torch.get_default_device()

    # Currently, changing device during the jitted function is unsupported.
    utils.check(
        stashed_default == live_default,
        lambda: "Default device is changed during the execution of jitted function. This is currently unsupported.",
    )
    return live_default
def maybe_get_default_device(device):
    """Return `device` when it is truthy, otherwise the result of `get_default_device()`."""
    if device:
        return device
    return get_default_device()
# A wrapper that executes the operations within the torch language context
# NOTE because this module defines the torch language context, a reference to itself
# is acquired by inspecting the __module__ attribute of the is_available function defined
# above
# NOTE Functions that set is_method=True must be able to accept a tensor as their first positional input
class torchsymbol:
    """Decorator that wraps a function into a thunder `Symbol` run in the torch language context.

    Applying an instance to a function:
      * wraps the function with `langctx(Languages.TORCH)`,
      * picks an id (the explicit `id`, or one inferred from where torch exposes
        an attribute with the function's name),
      * builds the `Symbol`,
      * optionally registers it as a tensor method or property,
      * records it in `_torch_to_thunder_function_map` for every torch function given,
      * and, for symbols tagged `prims.OpTags.IN_PLACE`, records the out-of-place
        counterpart in `_inplace_to_out_of_place`.
    """

    def __init__(
        self,
        *torchfns,
        is_method: bool = False,
        method_name: None | str = None,
        is_property: bool = False,
        id: str | None = None,
        is_prim: bool = False,
        tags: None | list[Any] = None,
    ):
        """Store the decorator configuration.

        Args:
            *torchfns: torch callables that should map to the created symbol.
            is_method: register the symbol as a method; implied by `method_name`.
            method_name: name to register under (defaults to the function name).
            is_property: register the symbol as a property (used when not a method).
            id: explicit symbol id; inferred from the function name when None.
            is_prim: treat the function as a primitive.
            tags: tags forwarded to the `Symbol`.
        """
        self.torchfns = torchfns
        # Supplying a method name implies the symbol is a method
        self.is_method = is_method or method_name is not None
        self.method_name: None | str = method_name
        self.is_property = is_property
        self.id = id

        # When is_prim is True, the function is treated as a primitive, so that
        # executors must execute it directly without decomposition.
        self.is_prim = is_prim
        self.tags = tags

    def __call__(self, fn: Callable) -> Symbol:
        """Create, register and return the `Symbol` for `fn`."""
        _fn = langctx(Languages.TORCH)(fn)

        sym_id: str
        if self.id is not None:
            sym_id = self.id
        else:
            name = fn.__name__
            # Probed in order; the first namespace exposing `name` decides the id prefix
            id_namespaces = (
                (torch, "torch"),
                (torch.nn.functional, "torch.nn.functional"),
                (torch.Tensor, "torch.Tensor"),
                (torch.ops.aten, "torch.ops.aten"),
                (torch.special, "torch.special"),
            )
            for namespace, prefix in id_namespaces:
                if hasattr(namespace, name):
                    sym_id = f"{prefix}.{name}"
                    break
            else:
                utils.check(
                    False,
                    lambda: f"The torchsymbol decorator failed to infer an id for {name}, specify one explicitly (with id=<your id>)",
                    exception_type=AssertionError,
                )

        # Primitives additionally wrap their meta function in the prims language context
        meta = langctx(Languages.PRIMS)(_fn) if self.is_prim else _fn
        sym = Symbol(name=fn.__name__, meta=meta, id=sym_id, is_prim=self.is_prim, tags=self.tags)

        # Method registration takes precedence over property registration
        if self.is_method or self.is_property:
            attr_name: str = self.method_name if self.method_name is not None else fn.__name__
            register = register_method if self.is_method else register_property
            register(attr_name, sym)

            # Also map the same-named torch.Tensor attribute, when there is one
            torch_attr = getattr(torch.Tensor, attr_name, None)
            if torch_attr is not None:
                _torch_to_thunder_function_map[torch_attr] = sym

        if self.torchfns is not None:
            _torch_to_thunder_function_map.update(dict.fromkeys(self.torchfns, sym))

        if self.tags and prims.OpTags.IN_PLACE in self.tags:
            # The out-of-place counterpart is looked up in this module's globals under
            # the in-place name minus its last character
            inplace_name = self.id if self.id is not None else name
            _inplace_to_out_of_place[sym] = globals()[inplace_name[:-1]], -1

        return sym
# This function maps an implementation to a `torch` operation without creating a Symbol.
# So, the registered implementation will not show up in the trace as a Symbol (but will get inlined).
# This is helpful if we want to support a torch operation and bake its output directly into the trace.
# See `clone` and `torch.device` for examples.
def register_function(torchfn, thunderfn_impl):
    """Record `thunderfn_impl` as the thunder implementation of `torchfn`."""
    _torch_to_thunder_function_map.update({torchfn: thunderfn_impl})
#
# Tensor properties
#
@torchsymbol(torch.Tensor.dim, is_method=True)
def dim(a: TensorLike, /) -> int:
    """Return the `ndim` attribute of `a`."""
    rank = a.ndim
    return rank
# NOTE: Named `compute_len` so that it doesn't
# conflict with built-in `len`
def compute_len(a: TensorLike, /) -> int:
    """Return the length of the first dimension of `a`.

    Raises:
        TypeError: if `a` has zero dimensions.
    """
    if a.ndim != 0:
        return a.shape[0]
    raise TypeError("len() of a 0-d tensor")


# Registered under the method name "len"
register_method("len", compute_len)
@torchsymbol(torch.is_floating_point, is_method=True)
def is_floating_point(a: TensorLike, /) -> bool:
    """Return whether `a.dtype` is a float dtype according to `dtypes.is_float_dtype`."""
    tensor_dtype = a.dtype
    return dtypes.is_float_dtype(tensor_dtype)
# Handles the size method
def size(a: TensorLike, /, dim: None | int = None) -> int | Sequence[int]:
    """Return `a.shape`, or only its entry at index `dim` when `dim` is given."""
    return a.shape if dim is None else a.shape[dim]


# Registered under the method name "size"
register_method("size", size)
@torchsymbol(torch.numel, torch.Tensor.numel, is_method=True)
def numel(a: TensorLike, /) -> int:
    """Return the `_numel` attribute of `a`."""
    element_count = a._numel
    return element_count


# Explicit registration under the method name "numel" (in addition to is_method=True above)
register_method("numel", numel)
@torchsymbol(torch.Tensor.is_complex, is_property=True, id="torch.is_complex")
def is_complex(a: TensorLike, /) -> bool:
    """Return whether `a.dtype` is a complex dtype according to `dtypes.is_complex_dtype`."""
    tensor_dtype = a.dtype
    return dtypes.is_complex_dtype(tensor_dtype)


# Map both the function and the Tensor-method spelling to the symbol, and
# register it under the method name "is_complex"
_torch_to_thunder_function_map.update(
    {
        torch.is_complex: is_complex,
        torch.Tensor.is_complex: is_complex,
    }
)
register_method("is_complex", is_complex)
@torchsymbol(torch.Tensor.is_cuda, is_property=True, id="torch.is_cuda")
def is_cuda(a: TensorLike, /) -> bool:
    """Return whether the device type of `a` is `devices.DeviceType.CUDA`."""
    devicetype = a.device.devicetype
    return devicetype is devices.DeviceType.CUDA
# is nested always returns False for now:
# https://github.com/Lightning-AI/lightning-thunder/issues/93#issuecomment-2030416883
@torchsymbol(torch.Tensor.is_nested, is_property=True, id="torch.is_nested")
def is_nested(a: TensorLike, /) -> bool:
    """Report whether `a` is a nested tensor; currently this is always False."""
    nested = False
    return nested
# Maps a torch dtype to the class-name part of the corresponding old-style
# torch type string (for example "FloatTensor")
_torch_dtype_to_old_torch_typestring_map = {
    # floating-point dtypes
    torch.float32: "FloatTensor",
    torch.float64: "DoubleTensor",
    torch.float16: "HalfTensor",
    torch.bfloat16: "BFloat16Tensor",
    # integer dtypes
    torch.uint8: "ByteTensor",
    torch.int8: "CharTensor",
    torch.int16: "ShortTensor",
    torch.int32: "IntTensor",
    torch.long: "LongTensor",
    # boolean dtype
    torch.bool: "BoolTensor",
}

# The reverse lookup: class-name part of the type string -> torch dtype
_old_torch_typestring_to_torch_dtype_map = {
    typestring: torch_dtype for torch_dtype, typestring in _torch_dtype_to_old_torch_typestring_map.items()
}
def _device_and_dtype_to_old_torch_typestring(device: DeviceLike, dtype: dtypeLike) -> str:
    """Build the old-style torch type string for a device/dtype pair.

    The result is "torch.<DtypeTensor>" for CPU devices and
    "torch.<devicetype>.<DtypeTensor>" otherwise. The dtype part comes from
    `_torch_dtype_to_old_torch_typestring_map.get`, so a dtype missing from that
    map is rendered as "None".
    """
    dtype_str = _torch_dtype_to_old_torch_typestring_map.get(to_torch_dtype(dtype))

    # The device type is omitted from the string for CPU
    if device.devicetype is devices.DeviceType.CPU:
        return f"torch.{dtype_str}"
    return f"torch.{devices.devicetype_string(device.devicetype)}.{dtype_str}"
def _old_torch_typestring_to_devicetype_and_dtype(typestring: str) -> tuple[DeviceLike, dtypeLike]:
    """Parse an old-style torch type string into a device type string and a torch dtype.

    Two forms are accepted:
      - "torch.DtypeTensor" (the device type defaults to "cpu")
      - "torch.device.DtypeTensor"

    Args:
        typestring: the type string to parse, e.g. "torch.cuda.FloatTensor".

    Returns:
        A `(devicetype_str, torch_dtype)` tuple, where `devicetype_str` is
        "cpu" or "cuda".

    Raises:
        ValueError: if the string does not split into one or two components after
            the leading prefix, if the device type is neither "cpu" nor "cuda",
            or if the dtype name is not a known old-style tensor type name.
    """
    _, *dev_and_dtype = typestring.split(".")

    # When the device type is omitted, the device type is CPU
    devicetype_str = "cpu"
    dtype_name = ""
    if len(dev_and_dtype) == 1:
        (dtype_name,) = dev_and_dtype
    elif len(dev_and_dtype) == 2:
        devicetype_str, dtype_name = dev_and_dtype

    # NOTE The lookup uses .get so that an unknown dtype name is reported through the
    #   ValueError below instead of escaping as a bare KeyError
    torch_dtype = _old_torch_typestring_to_torch_dtype_map.get(dtype_name)

    utils.check(
        devicetype_str in ("cpu", "cuda") and 1 <= len(dev_and_dtype) <= 2 and torch_dtype is not None,
        lambda: f"type(): unrecognized torch typestring {typestring}",
        exception_type=ValueError,
    )

    return devicetype_str, torch_dtype
@torchsymbol(torch.Tensor.type, is_method=True)
def type(a: TensorLike, /, dtype: None | str | dtypeLike = None, non_blocking: bool = False) -> str | TensorLike:
    """Return the type string of `a`, or convert `a` to the requested type.

    Args:
        a: the input tensor.
        dtype: None to query the type string, an old-style type string such as
            "torch.cuda.FloatTensor", or a dtype.
        non_blocking: must be False; True is not supported.

    Returns:
        The old-style type string of `a` when `dtype` is None, otherwise the
        result of `to(a, device, dtype)`.
    """
    utils.check(
        not non_blocking,
        lambda: f"type(): `non_blocking==True` is currently not supported.",
        exception_type=NotImplementedError,
    )

    # Query mode: report the type of the input tensor as a string
    if dtype is None:
        return _device_and_dtype_to_old_torch_typestring(a.device, a.dtype)

    # When dtype is not a string it is assumed to be a dtype (e.g. torch.int32),
    # and the tensor stays on its current device
    dev = a.device
    if isinstance(dtype, str):
        devtype, dtype = _old_torch_typestring_to_devicetype_and_dtype(dtype)
        # The tensor also stays on its current device when the parsed device type matches it:
        #   1. The tensor is already on a CUDA device and the type string names CUDA.
        #   2. The tensor is on a CPU device and the device type is omitted from the string.
        if devtype != a.device.type:
            dev = to_device(devtype)

    return to(a, dev, dtype)


# Explicit registration under the method name "type" (in addition to is_method=True above)
register_method("type", type)
#
# Data movement and transformation operations
#
# NOTE This handles a.float()
# It avoids using the name "float" to not collide with the builtin
# "float"
def to_float(a: NumberLike | TensorLike) -> Number | TensorLike:
    """Convert `a` to `dtypes.float32` via `clang.maybe_convert_to_dtype`."""
    target_dtype = dtypes.float32
    return clang.maybe_convert_to_dtype(a, target_dtype)


# Registered under the method name "float" (this is what handles a.float())
register_method("float", to_float)
# NOTE to's parsing is a little whacky
# to supports five first positional arguments
# 1) a tensor, in which case device and dtype cannot be specified (although we allow them to be)
# 2) a dtype, in which case device cannot be specified (although we allow this)
# 3) a device, in which case dtype can be specified,
# 4) None, in which case device and dtype come from kwargs (which may also be None, a.to() is valid and just returns)
# a itself
# 5) device and dtype
def _parse_to_device_and_dtype(
    tensor_dtype_or_device: None | TensorLike | dtypeLike | DeviceLike = None,
    optional_positional_dtype: None | dtypeLike = None,
    /,
    device: None | DeviceLike = None,
    dtype: None | dtypeLike = None,
) -> tuple[devices.Device, dtypes.dtype]:
    """Resolve the overloaded arguments of `to` into a `(device, dtype)` pair.

    The first positional argument selects the parsing case:
      - a device (torch.device, thunder Device or string), optionally followed
        by a positional dtype,
      - a dtype,
      - None, in which case only the keyword arguments are used,
      - anything else, which must be a TensorProxy whose device and true dtype
        are used unless overridden by the keyword arguments.

    Supplying the same piece of information both positionally and by keyword
    fails a `utils.check`.
    """
    first = tensor_dtype_or_device

    # Case 3 and 5 -- device first
    if isinstance(first, (torch.device, devices.Device, str)):
        utils.check(device is None, lambda: f"to received both a positional and keyword device argument")
        parsed_device = to_device(first)
        if optional_positional_dtype is None:
            return parsed_device, to_dtype(dtype)
        utils.check(dtype is None, lambda: f"to received both a positional and keyword dtype argument")
        return parsed_device, to_dtype(optional_positional_dtype)

    # Case 2 -- dtype first
    if isinstance(first, (torch.dtype, dtypes.dtype)):
        utils.check(dtype is None, lambda: f"to received both a positional and keyword dtype argument")
        parsed_device = None if device is None else to_device(device)
        return parsed_device, to_dtype(first)

    # Case 4 -- None first
    if first is None:
        parsed_device = None if device is None else to_device(device)
        return parsed_device, to_dtype(dtype)

    # Case 1 -- tensor first
    # It'd be nice to write torch.Tensor here instead of TensorProxy.
    # See issue "Translate isinstance(a, torch.Tensor) calls so that
    # TensorProxies can pass as torch.Tensors"
    utils.check_type(first, TensorProxy)
    parsed_device = to_device(device) if device is not None else first.device
    parsed_dtype = to_dtype(dtype) if dtype is not None else first.true_dtype
    return parsed_device, parsed_dtype
# TODO Model non_blocking (as kwargs)
@torchsymbol(torch.Tensor.to, is_method=True)
def to(
    a: TensorLike,
    tensor_dtype_or_device: None | TensorLike | dtypeLike | DeviceLike = None,
    optional_positional_dtype: None | dtypeLike = None,
    /,
    *,
    device: None | DeviceLike = None,
    dtype: None | dtypeLike = None,
    copy: bool = False,
    memory_format: None | torch.memory_format = None,
) -> TensorLike:
    """Move and/or convert `a` to the requested device, dtype and memory format.

    Args:
        a: the input tensor.
        tensor_dtype_or_device: overloaded first positional argument; see
            `_parse_to_device_and_dtype` for the accepted forms.
        optional_positional_dtype: a dtype following a positional device.
        device: keyword form of the target device.
        dtype: keyword form of the target dtype.
        copy: when True, the device and dtype conversions go through the prims
            directly; when False they go through the clang helpers.
        memory_format: torch.channels_last and torch.channels_last_3d reorder the
            strides; any other value leaves the strides untouched.

    Returns:
        The resulting tensor.
    """
    device, dtype = _parse_to_device_and_dtype(
        tensor_dtype_or_device, optional_positional_dtype, device=device, dtype=dtype
    )

    if copy:
        if device is not None:
            device = to_device(device)
            a = prims.device_put(a, device)
        if dtype is not None:
            dtype = to_dtype(dtype)
            a = prims.convert_element_type(a, dtype)
    else:
        # NOTE copy == False
        # NOTE to() returns the tensor unmodified if the device and dtype requested are the same
        #   (and copy=False)
        # NOTE clang.device_put does nothing when device is None or a.device == device
        a = clang.device_put(a, device)
        if dtype is not None:
            # NOTE This must not return early: the memory format still has to be applied
            #   below, exactly as on the copy=True path
            a = clang.maybe_convert_to_dtype(a, dtype)

    if memory_format is not None:
        # NOTE not sure if we need to handle torch.preserve_format explicitly
        if memory_format == torch.channels_last:
            a = prims.stride_order(a, (3, 0, 2, 1))
        elif memory_format == torch.channels_last_3d:
            a = prims.stride_order(a, (4, 0, 3, 2, 1))

    return a
@torchsymbol(torch.Tensor.cuda, is_method=True)
def cuda(
    a: TensorLike,
    /,
    device: None | DeviceLike = None,
    non_blocking: bool = False,
    memory_format: None | torch.memory_format = None,
) -> TensorLike:
    """Move `a` to a CUDA device by delegating to `to`.

    Modeled similar to PyTorch:
    https://github.com/pytorch/pytorch/blob/e3ac61587aa368c613ef01df1f328a396b64cd5d/tools/autograd/templates/python_variable_methods.cpp#L496-L501

    Args:
        a: the input tensor.
        device: the target device; when None, the current CUDA device reported
            by `torch.cuda.current_device()` is used.
        non_blocking: must be False, since `to` doesn't model it currently.
        memory_format: forwarded to `to`.
    """
    utils.check(not non_blocking, lambda: "cuda(): `non_blocking==True` is currently not supported.")

    if device is not None:
        device = to_device(device)
    else:
        # Move tensor to `current` GPU device.
        device = devices.Device(devices.DeviceType.CUDA, torch.cuda.current_device())

    utils.check(
        device.devicetype == devices.DeviceType.CUDA,
        lambda: f"cuda(): Invalid device {device.device_str()}, must be cuda device",
    )

    # Actual data movement and layout ordering are handled by `to`
    return to(a, device=device, memory_format=memory_format)
@torchsymbol(torch.Tensor.type_as, is_method=True)
def type_as(a: TensorProxy, b: TensorProxy, /) -> TensorProxy:
    """Convert `a` to the true dtype and the device of `b`."""
    # NOTE This type check is intentional since we're accessing the true_dtype
    #   attribute of the TensorProxy
    # TODO Create a generic Tensor annotation, and support both PyTorch
    #   tensors and TensorProxies being passed to this operation
    utils.check_type(b, TensorProxy)

    target_dtype = b.true_dtype
    target_device = b.device
    return to(a, target_dtype, device=target_device)
@torchsymbol(torch.Tensor.long, is_method=True)
def long(a: TensorLike, /, memory_format: torch.memory_format = torch.preserve_format) -> TensorLike:
    """Convert `a` to `dtypes.int64` through `to`, forwarding `memory_format`."""
    target_dtype = dtypes.int64
    return to(a, dtype=target_dtype, memory_format=memory_format)
#
# Tensor creation operations
#
@torchsymbol(torch.arange)
def arange(
    start: NumberLike,
    end: None | Number = None,
    step: NumberLike = 1,
    *,
    device: None | DeviceLike = None,
    dtype: None | dtypeLike = None,
) -> TensorLike:
    """Create a range tensor via `clang.arange`.

    When `end` is omitted, the single positional value is used as the end and
    the range starts at 0. A missing `device` falls back to the default device.

    From the torch docs (https://pytorch.org/docs/stable/generated/torch.arange.html):
    if any of start, end, or step are floating-point, the dtype is inferred to
    be the default dtype, see get_default_dtype(). Otherwise, the dtype is
    inferred to be torch.int64.
    """
    device = to_device(maybe_get_default_device(device))

    if dtype is None:  # infer the dtype
        has_float_arg = any(isinstance(x, float) for x in (start, end, step))
        dtype = maybe_get_default_dtype(dtype) if has_float_arg else torch.int64
    dtype = to_dtype(dtype)

    # arange(n) means arange(0, n)
    if end is None:
        start, end = 0, start

    return clang.arange(start=start, step=step, stop=end, device=device, dtype=dtype)
# Infers dtype from the fill_value and dtype
def _infer_full_dtype(fill_value: NumberLike, dtype: None | dtypeLike) -> dtypeLike:
    """Pick the dtype for a tensor filled with `fill_value`.

    An explicitly given `dtype` always wins. Otherwise the dtype follows the kind
    of the fill value: boolean -> bool8, non-boolean integer -> int64,
    complex -> the complex dtype corresponding to the current default dtype,
    anything else -> the current default dtype.
    """
    # Short-circuits if dtype is explicitly specified
    if dtype is not None:
        return to_dtype(dtype)

    # NOTE dtype is None
    value_dtype = dtypes.numbertype_to_dtype(dtypes.to_dtype(fill_value))

    if dtypes.is_boolean_dtype(value_dtype):
        inferred = dtypes.bool8
    elif dtypes.is_nonboolean_integer_dtype(value_dtype):
        inferred = dtypes.int64
    else:
        default_dtype = get_default_dtype()
        if dtypes.is_complex_dtype(value_dtype):
            # NOTE When the `fill_value' is a complex dtype, Thunder infers a slightly different dtype than Torch.
            # Torch (2.5.0a0+git8927fc2):
            #   float64 -> complex128
            #   float32, float16, bfloat16 -> complex64
            # (Ref: the torch function: https://github.com/pytorch/pytorch/blob/cd307fb0b1a833f9297d2233653b514ed4aa3163/aten/src/ATen/native/TensorFactories.cpp#L584-L604)
            # Thunder uses `dtypes.corresponding_complex_dtype` (see its implementation for details)
            # The only difference is for float16, where Thunder returns complex32 but Torch returns complex64
            inferred = dtypes.corresponding_complex_dtype(default_dtype)
        else:
            # NOTE value_dtype is a non-complex floating-point type
            inferred = to_dtype(default_dtype)

    return inferred
@torchsymbol(torch.full)
def full(
    shape: Sequence[int], fill_value: NumberLike, *, device: None | DeviceLike = None, dtype: None | dtypeLike = None
) -> TensorLike:
    """Create a tensor of `shape` filled with `fill_value` via `clang.full`.

    A missing `device` falls back to the default device; the dtype is resolved
    by `_infer_full_dtype`.
    """
    resolved_device = to_device(maybe_get_default_device(device))
    resolved_dtype = _infer_full_dtype(fill_value, dtype)
    return clang.full(shape, fill_value, device=resolved_device, dtype=resolved_dtype)
@torchsymbol(torch.full_like)
def full_like(
    a: TensorLike, /, fill_value: NumberLike, *, device: None | DeviceLike = None, dtype: None | dtypeLike = None
) -> TensorLike:
    """Create a tensor like `a` filled with `fill_value` via `clang.full_like`.

    `device` and `dtype` are passed through `to_device` / `to_dtype` before being forwarded.
    """
    converted_device = to_device(device)
    converted_dtype = to_dtype(dtype)
    return clang.full_like(a, fill_value, device=converted_device, dtype=converted_dtype)
# NOTE ones, unlike full, can accept an integer shape
@torchsymbol(torch.ones)
def ones(*shape: int, device: None | DeviceLike = None, dtype: None | dtypeLike = None) -> TensorLike:
    """Create a tensor of ones via `full`; a missing `dtype` falls back to the default dtype.

    The shape is extracted from the variadic arguments by `utils.extract_shape_from_varargs`.
    """
    resolved_shape = utils.extract_shape_from_varargs(shape)
    resolved_dtype = maybe_get_default_dtype(dtype)
    return full(resolved_shape, 1, device=device, dtype=resolved_dtype)
@torchsymbol(torch.ones_like)
def ones_like(a: TensorLike, /, *, device: None | DeviceLike = None, dtype: None | dtypeLike = None) -> TensorLike:
    """Create a tensor like `a` filled with ones by delegating to `full_like`."""
    fill_value = 1
    return full_like(a, fill_value, device=device, dtype=dtype)
@torchsymbol(torch.tensor, is_method=False, id="torch.tensor")
def tensor(
    seq_or_number: Sequence | Number,
    *,
    device: None | DeviceLike = None,
    dtype: None | dtypeLike = None,
    requires_grad: bool = False,
    pin_memory: bool = False,
) -> TensorLike:
    """Construct a tensor from a single number or a sequence of numbers.

    Args:
        seq_or_number: the number or sequence to build the tensor from.
        device: the target device, forwarded to the constructing call.
        dtype: the target dtype; for a single number it is inferred from the
            value when None.
        requires_grad: must be False.
        pin_memory: must be False.

    Raises:
        NotImplementedError: for unsupported inputs, `requires_grad=True` or
            `pin_memory=True`.
    """
    # TODO: Support torch.Tensor/np.ndarray as input similar to `torch.tensor`
    # NOTE The message must only reference names that exist in this scope, otherwise
    #   building it raises a NameError that hides the intended NotImplementedError
    utils.check(
        isinstance(seq_or_number, (Number, Sequence)),
        lambda: f"Currently only directly constructing tensors with a single number or a Sequence of numbers is supported, but received {seq_or_number}",
        exception_type=NotImplementedError,
    )
    utils.check(
        not requires_grad, lambda: "requires_grad=True is not yet supported within thunder.jit", NotImplementedError
    )
    utils.check(not pin_memory, lambda: "pin_memory=True is not supported within thunder.jit", NotImplementedError)

    if isinstance(seq_or_number, (Number, NumberProxy)):
        # Infer dtype from value (as `full` will use default dtype if dtype=None).
        if dtype is None:
            dtype = dtypes.numbertype_to_dtype(dtypes.to_dtype(seq_or_number))
        return full((), seq_or_number, dtype=dtype, device=device)

    return clang.tensor_from_sequence(seq_or_number, dtype=dtype, device=device)
# TODO based on uniform_, check if Torch now has a functional uniform
# NOTE the uniform_ documentation suggests the interval is specified using "from" and "to",
# but from is a reserved keyword in Python
@torchsymbol(is_method=False, id="torch.uniform")
def uniform(
    shape: Sequence[int],
    minval: NumberLike = 0.0,
    maxval: NumberLike = 1.0,
    *,
    device: DeviceLike,
    dtype: dtypeLike,
) -> TensorLike:
    """Delegate to `clang.uniform` with the given shape and `minval`/`maxval` bounds.

    `device` and `dtype` are required keywords; a falsy value falls back to the
    corresponding default.
    """
    resolved_device = to_device(maybe_get_default_device(device))
    resolved_dtype = to_dtype(maybe_get_default_dtype(dtype))
    return clang.uniform(shape, minval, maxval, device=resolved_device, dtype=resolved_dtype)
@torchsymbol(is_method=False, id="torch.uniform_like")
def uniform_like(
    a: TensorLike,
    /,
    minval: NumberLike = 0.0,
    maxval: NumberLike = 1.0,
    *,
    device: None | DeviceLike = None,
    dtype: None | dtypeLike = None,
) -> TensorLike:
    """Delegate to `clang.uniform_like` for `a` with the given `minval`/`maxval` bounds.

    `device` and `dtype` are passed through `to_device` / `to_dtype` before being forwarded.
    """
    converted_device = to_device(device)
    converted_dtype = to_dtype(dtype)
    return clang.uniform_like(a, minval, maxval, device=converted_device, dtype=converted_dtype)
@torchsymbol(torch.multinomial, is_method=True, id="torch.multinomial")
def multinomial(
    a: TensorLike,
    num_samples: int,
    replacement: bool = False,
    *,
    generator: torch.Generator | None = None,
    out: TensorLike | None = None,
) -> TensorLike:
    """Call `prims.multinomial` on `a` with `num_samples` and `replacement`.

    Raises:
        NotImplementedError: if `out` or `generator` is specified.
    """
    utils.check(out is None, lambda: "Non-None out is not supported", NotImplementedError)

    # See issue "randomness: enable PyTorch generators for operations like
    #   multinomial"
    utils.check(
        generator is None, lambda: f"multinomial does not yet support specifying a generator", NotImplementedError
    )

    # No seed is passed to the prim
    return prims.multinomial(a, num_samples, replacement, None)
# TODO Maybe update this to return an offset of how far to advance the seed to acquire new values
# See issue "Maybe return offset from thunder.torch.uniform_philox"
@torchsymbol(is_method=False, id="torch.uniform_philox")
def uniform_philox(
    shape: Sequence[int],
    minval: NumberLike = 0.0,
    maxval: NumberLike = 1.0,
    *,
    device: DeviceLike,
    dtype: dtypeLike,
    seed: int | TensorProxy,
    offset: int | TensorProxy,
) -> TensorLike:
    """Delegate to `clang.uniform_philox`, forwarding the bounds, `seed` and `offset`.

    `device` and `dtype` are required keywords; a falsy value falls back to the
    corresponding default.
    """
    resolved_device = to_device(maybe_get_default_device(device))
    resolved_dtype = to_dtype(maybe_get_default_dtype(dtype))
    return clang.uniform_philox(
        shape, minval, maxval, device=resolved_device, dtype=resolved_dtype, seed=seed, offset=offset
    )
@torchsymbol(torch.randn)
def randn(
    *shape,
    generator: None | torch.Generator = None,
    dtype: None | dtypeLike = None,
    device: None | DeviceLike = None,
    layout: torch.layout = torch.strided,
    requires_grad: bool = False,
    pin_memory: bool = False,
    out: TensorLike = None,
):
    """Create a tensor through `prims.randn` with the given shape.

    Missing `device` / `dtype` fall back to the defaults.

    Raises:
        NotImplementedError: for `requires_grad=True`, a non-strided layout,
            `pin_memory=True`, a generator, or an `out` tensor.
    """
    # Each entry is (condition that must hold, message used when it does not);
    # the order decides which problem is reported first
    supported = (
        (not requires_grad, "requires_grad=True is not yet supported within thunder.jit"),
        (layout == torch.strided, "Only torch.strided layout is supported"),
        (not pin_memory, "pin_memory=True is not supported within thunder.jit"),
        # NOTE: Currently, we don't model randomness
        (generator is None, "generator is not None which is currently unsupported"),
        (out is None, "out is not None which is currently unsupported"),
    )
    for is_ok, message in supported:
        utils.check(is_ok, lambda message=message: message, NotImplementedError)

    resolved_device = to_device(maybe_get_default_device(device))
    resolved_dtype = to_dtype(maybe_get_default_dtype(dtype))
    resolved_shape = tuple(utils.extract_shape_from_varargs(shape))
    return prims.randn(resolved_shape, device=resolved_device, dtype=resolved_dtype)
@torchsymbol(torch.randn_like)
def randn_like(
    a,
    /,
    *,
    dtype: None | dtypeLike = None,
    device: None | DeviceLike = None,
    layout: None | torch.layout = None,
    requires_grad: bool = False,
    memory_format: torch.memory_format = torch.preserve_format,
):
    """Call `randn` with the shape of `a`.

    Args:
        a: the tensor whose shape (and, by default, dtype and device) is used.
        dtype: the dtype of the result; defaults to `a.dtype`.
        device: the device of the result; defaults to `a.device`.
        layout: must be None or torch.strided.
        requires_grad: must be False.
        memory_format: must be torch.preserve_format.

    Raises:
        NotImplementedError: if any of the unsupported options is requested.
    """
    utils.check(
        not requires_grad, lambda: "requires_grad=True is not supported within thunder.jit", NotImplementedError
    )
    utils.check(
        layout is None or layout == torch.strided, lambda: "Only torch.strided layout is supported", NotImplementedError
    )
    # NOTE The message names the `memory_format` parameter, which is what the user has to change
    utils.check(
        memory_format == torch.preserve_format,
        lambda: "memory_format!=torch.preserve_format is not supported within thunder.jit",
        NotImplementedError,
    )

    if dtype is None:
        dtype = a.dtype
    if device is None:
        device = a.device

    return randn(a.shape, dtype=dtype, device=device)
@torchsymbol(torch.bernoulli, is_method=True)
def bernoulli(a: TensorLike, *, generator=None, out=None):
    """Compare a `uniform_like(a)` sample against `a` and convert the result to `a.dtype`.

    `a` must have a floating-point dtype.

    Raises:
        NotImplementedError: if `generator` or `out` is specified.
    """
    # NOTE: Currently, we don't model randomness
    utils.check(
        generator is None,
        lambda: "bernoulli: generator is not None which is currently unsupported",
        NotImplementedError,
    )
    utils.check(out is None, lambda: "bernoulli: out is not None which is currently unsupported", NotImplementedError)
    utils.check(dtypes.is_float_dtype(a.dtype), lambda: f"bernoulli only supports floating point dtypes, got {a.dtype}")

    uniform_sample = uniform_like(a)
    below_probability = uniform_sample < a
    return below_probability.to(a.dtype)
# NOTE zeros, like ones, and unlike full, can accept an integer shape
@torchsymbol(torch.zeros)
def zeros(*shape: int, device: None | DeviceLike = None, dtype: None | dtypeLike = None) -> TensorLike:
    """Create a tensor of zeros via `full`; a missing `dtype` falls back to the default dtype.

    The shape is extracted from the variadic arguments by `utils.extract_shape_from_varargs`.
    """
    resolved_shape = utils.extract_shape_from_varargs(shape)
    resolved_dtype = maybe_get_default_dtype(dtype)
    return full(resolved_shape, 0, device=device, dtype=resolved_dtype)
@torchsymbol(torch.zeros_like)
def zeros_like(a: TensorLike, /, *, device: DeviceLike | None = None, dtype: dtypeLike | None = None) -> TensorLike:
    """Create a tensor like `a` filled with zeros by delegating to `full_like`."""
    fill_value = 0
    return full_like(a, fill_value, device=device, dtype=dtype)
@torchsymbol(torch.empty)
def empty(
    *size: int,
    device: None | DeviceLike = None,
    dtype: None | dtypeLike = None,
    out: None | TensorLike = None,
    layout: torch.layout = torch.strided,
    requires_grad: bool = False,
    pin_memory: bool = False,
    memory_format: torch.memory_format = torch.contiguous_format,
) -> TensorLike:
    """Create a tensor of the given size through `clang.empty`.

    Missing `dtype` / `device` fall back to the defaults.

    Raises:
        NotImplementedError: for an `out` tensor, a non-strided layout,
            `requires_grad=True`, `pin_memory=True`, or a memory format other
            than torch.contiguous_format.
    """
    resolved_size = utils.extract_shape_from_varargs(size)

    # Each entry is (condition that must hold, message used when it does not);
    # the order decides which problem is reported first
    supported = (
        (out is None, "empty(): out is not None which is currently unsupported"),
        (layout == torch.strided, "Only torch.strided layout is supported"),
        (not requires_grad, "requires_grad=True is not yet supported within thunder.jit"),
        (not pin_memory, "pin_memory=True is not supported within thunder.jit"),
        (memory_format == torch.contiguous_format, "Only torch.contiguous_format is supported"),
    )
    for is_ok, message in supported:
        utils.check(is_ok, lambda message=message: message, NotImplementedError)

    resolved_dtype = to_dtype(maybe_get_default_dtype(dtype))
    resolved_device = to_device(maybe_get_default_device(device))
    return clang.empty(resolved_size, device=resolved_device, dtype=resolved_dtype)
#
# Shape operations
#
# TODO Update this to take a *args series of tensors or a sequence of tensors
@torchsymbol(torch.cat)
def cat(tensors: Sequence[TensorLike], dim: int = 0) -> TensorLike:
    """Concatenate `tensors` along `dim` by delegating to `clang.cat`."""
    concatenated = clang.cat(tensors, dim)
    return concatenated
@torchsymbol(torch.chunk, is_method=True)
def chunk(a: TensorLike, chunks: int, dim: int = 0) -> Sequence[TensorLike]:
    """Split `a` along `dim` into at most `chunks` slices.

    Every slice except possibly the last has length ceil(a.shape[dim] / chunks)
    along `dim`; the last slice takes whatever remains, so fewer than `chunks`
    slices may be returned.

    Returns:
        A tuple of slices of `a` created with `clang.slice_in_dim`.
    """
    utils.check(a.ndim > 0, lambda: f"chunk: a ({a.ndim=}) must be at least 1-dimensional")
    utils.check(chunks > 0, lambda: f"chunk: chunks ({chunks=}) must be greater than 0")

    axis = utils.canonicalize_dim(a.ndim, dim)
    axis_len = a.shape[axis]

    # Empty dimension: return `chunks` slices, each taken as slice(0, 1) along the axis
    if axis_len == 0:
        return tuple(clang.slice_in_dim(a, 0, 1, dim=axis) for _ in range(chunks))

    # Single chunk: one slice covering the whole axis
    if chunks == 1:
        return (clang.slice_in_dim(a, 0, axis_len, dim=axis),)

    # NOTE: from here on axis_len > 0 and chunks > 1.
    # All slices but the last have length ceil(axis_len / chunks)
    chunk_len = (axis_len + chunks - 1) // chunks

    # With that length, the number of slices is ceil(axis_len / chunk_len),
    # which is `chunks` or less
    num_chunks = (axis_len + chunk_len - 1) // chunk_len
    last_chunk_len = axis_len - (num_chunks - 1) * chunk_len
    last_chunk_start = axis_len - last_chunk_len

    # (start, stop) pairs: the full-length slices first, then the last one
    bounds = [(start, start + chunk_len) for start in range(0, last_chunk_start, chunk_len)]
    bounds.append((last_chunk_start, axis_len))

    return tuple(clang.slice_in_dim(a, start, stop, dim=axis) for start, stop in bounds)
@torchsymbol(torch.Tensor.contiguous, is_method=True)
def contiguous(a: TensorLike, /, *, memory_format: torch.memory_format = torch.contiguous_format) -> TensorLike:
    """Returns ``a`` arranged according to ``memory_format``.

    Handled memory formats:

    - ``torch.preserve_format``: ``a`` is returned as-is.
    - ``torch.contiguous_format``: ``clang.stride_order(a)`` with no explicit order.
    - ``torch.channels_last`` (a.k.a. channels_last_2d): requires a 4D tensor
      (NCHW dims with NHWC strides).
    - ``torch.channels_last_3d``: requires a 5D tensor (NCDHW dims with NDHWC strides).

    Any other value fails a ``utils.check`` whose exception type is ``ValueError``.

    Args:
        a: The tensor to rearrange.
        memory_format: The requested memory format. Defaults to ``torch.contiguous_format``.

    Returns:
        ``a`` itself for ``torch.preserve_format``, otherwise the result of ``clang.stride_order``.
    """
    if memory_format is torch.preserve_format:
        # TODO Should this case raise a NotImplementedError? The format of a is unknown,
        #   so there is nothing concrete to preserve
        return a

    if memory_format is torch.contiguous_format:
        return clang.stride_order(a)

    if memory_format is torch.channels_last:
        utils.check(a.ndim == 4, lambda: f"Expected a 4D tensor for the channels last memory format")
        channels_last_order = (3, 0, 2, 1)
        return clang.stride_order(a, channels_last_order)

    if memory_format is torch.channels_last_3d:
        utils.check(a.ndim == 5, lambda: f"Expected a 5D tensor for the channels last 3D memory format")
        channels_last_3d_order = (4, 0, 3, 2, 1)
        return clang.stride_order(a, channels_last_3d_order)

    # None of the known memory formats matched
    utils.check(False, lambda: f"Found unexpected memory_format={memory_format}", exception_type=ValueError)
@torchsymbol(torch.diagonal, is_method=True)
def diagonal(a: TensorLike, /, offset: int = 0, dim1: int = 0, dim2: int = 1) -> TensorLike:
    """Extracts a diagonal of ``a`` by delegating to ``clang.diagonal``.

    Args:
        a: The input tensor.
        offset: Forwarded to ``clang.diagonal``. Defaults to 0.
        dim1: Forwarded to ``clang.diagonal``. Defaults to 0.
        dim2: Forwarded to ``clang.diagonal``. Defaults to 1.

    Returns:
        Whatever ``clang.diagonal(a, offset, dim1, dim2)`` returns.
    """
    diag = clang.diagonal(a, offset, dim1, dim2)
    return diag
@torchsymbol(torch.Tensor.expand, is_method=True)
def expand(a: TensorLike, /, *shape: int) -> TensorLike:
    """Expands ``a`` to ``shape`` by delegating to ``clang.expand``.

    Args:
        a: The tensor to expand.
        *shape: The target shape; the varargs are unpacked again when calling ``clang.expand``.

    Returns:
        Whatever ``clang.expand(a, *shape)`` returns.
    """
    expanded = clang.expand(a, *shape)
    return expanded
@torchsymbol(torch.Tensor.expand_as, is_method=True)
def expand_as(a: TensorLike, b: TensorLike, /) -> TensorLike:
    """Expands ``a`` to the size of ``b``.

    Args:
        a: The tensor to expand.
        b: The tensor whose ``size()`` is used as the target shape.

    Returns:
        Whatever ``expand(a, b.size())`` returns.
    """
    # The size is passed to expand as one argument, not unpacked into separate integers.
    target_size = b.size()
    return expand(a, target_size)
@torchsymbol(torch.flatten, is_method=True)
def flatten(a: TensorLike, /, start_dim: int = 0, end_dim: int = -1) -> TensorLike:
    """Flattens a range of dimensions of ``a`` by delegating to ``clang.flatten``.

    Args:
        a: The tensor to flatten.
        start_dim: Forwarded to ``clang.flatten``. Defaults to 0.
        end_dim: Forwarded to ``clang.flatten``. Defaults to -1.

    Returns:
        Whatever ``clang.flatten(a, start_dim, end_dim)`` returns.
    """
    flattened = clang.flatten(a, start_dim, end_dim)
    return flattened
@torchsymbol(torch.flip, is_method=True)
def flip(a: TensorLike, /, *dims: int) -> TensorLike:
    """Flips ``a`` along ``dims`` by delegating to ``clang.flip``.

    A 0-dim ``a`` given a non-empty ``dims`` is special-cased: ``dims`` must then hold
    exactly one entry equal to 0 or -1, and the flip is performed with an empty tuple
    of dims.

    Args:
        a: The tensor to flip.
        *dims: The dimensions to flip, normalized with ``utils.extract_shape_from_varargs``.

    Returns:
        Whatever ``clang.flip`` returns.
    """
    dims = utils.extract_shape_from_varargs(dims)

    # PyTorch supports 0-dim inputs with len(dims) <= 1
    zero_dim_with_dims = a.ndim == 0 and isinstance(dims, Sequence) and len(dims) > 0
    if not zero_dim_with_dims:
        return clang.flip(a, dims)

    def _has_single_valid_dim() -> bool:
        if len(dims) != 1:
            return False
        only_dim = dims[0]
        # Plain integers and integer proxies are compared directly...
        if isinstance(only_dim, (int, IntegerProxy)) and only_dim in (0, -1):
            return True
        # ...while other number proxies are compared through their Python value.
        return isinstance(only_dim, NumberProxy) and pyval(only_dim) in (0, -1)

    utils.check(
        _has_single_valid_dim(),
        lambda: f"Expected {dims=} to be a sequence of integers in range [-1, 0], and of length 1",
    )
    return clang.flip(a, ())
@torchsymbol(torch.Tensor.__getitem__, id="torch.Tensor.__getitem__", method_name="getitem")
def getitem(a: TensorLike, /, key) -> TensorLike:
    """Indexes ``a`` with ``key`` by delegating to ``clang.getitem``.

    Args:
        a: The tensor being indexed.
        key: The index expression, forwarded unchanged to ``clang.getitem``.

    Returns:
        Whatever ``clang.getitem(a, key)`` returns.
    """
    selected = clang.getitem(a, key)
    return selected
def matrix_transpose(a: TensorLike, /) -> TensorLike:
    """Swaps the last two dimensions of a tensor.

    This is the implementation behind the ``.mT`` attribute, which is registered
    right after this definition.

    Args:
        a (TensorProxy): The tensor whose last two dimensions are swapped.

    Returns:
        TensorProxy: The result of ``clang.matrix_transpose(a)``.

    Examples:
        >>> a = torch.tensor([[1, 2, 3], [4, 5, 6]])
        >>> def func(x): return x.mT
        >>> traced_func = thunder.compile(func)
        >>> traced_func(a)
        tensor([[1, 4],
                [2, 5],
                [3, 6]])
    """
    transposed = clang.matrix_transpose(a)
    return transposed


register_method("mT", matrix_transpose)
@torchsymbol(torch.movedim, is_method=True)
def movedim(a: TensorLike, /, source: int | Sequence[int], destination: int | Sequence[int]) -> TensorLike:
    """Moves dimensions of ``a`` by delegating to ``clang.movedim``.

    Args:
        a: The input tensor.
        source: One dimension or a sequence of dimensions, forwarded to ``clang.movedim``.
        destination: One dimension or a sequence of dimensions, forwarded to ``clang.movedim``.

    Returns:
        Whatever ``clang.movedim(a, source, destination)`` returns.
    """
    moved = clang.movedim(a, source, destination)
    return moved
@torchsymbol(torch.nn.functional.pad)
def pad(a: TensorProxy, /, pad: tuple[int, ...], mode: str | None = "constant", value: NumberLike | None = None):
utils.check(mode == "constant", lambda: f"Mode arguments other than constant are not supported")
utils.check(len(pad) % 2 == 0, lambda: f"Padding length must be divisible by 2")
utils.check(
len(pad) <= a.ndim * 2,
lambda: f"Padding length should be less than or equal to two times the input dimension.",
)
pad_config = []
for dim in range(a.ndim * 2 - 1, 0, -2):