Suppress infinite-recursion warning (#2307)
Summary:
Pull Request resolved: #2307

- Suppress the infinite-recursion warning in a location that is known
not to cause infinite recursion, since it is an implementation of a
pure virtual method via CRTP (see the sketch below)
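
For context, here is a minimal sketch, with hypothetical class names (PackBase, PackA) rather than the actual fbgemm types, of the CRTP pattern involved and of the clang pragma guard added in include/fbgemm/Fbgemm.h. The base class forwards the call to the derived class through a static_cast of `this`, which clang's -Winfinite-recursion check can flag as a self-call even though every concrete derived type supplies its own implementation:

#include <cstdint>

template <typename Derived>
class PackBase {
 public:
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winfinite-recursion"
#endif
  // To the compiler this looks like a self-call, but it resolves to
  // Derived::getRowOffsetBuffer() for every concrete Derived that
  // shadows the method, so it never recurses in practice.
  const std::int32_t* getRowOffsetBuffer() const {
    return static_cast<const Derived*>(this)->getRowOffsetBuffer();
  }
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
};

// Hypothetical derived packing class: it supplies the real buffer,
// so the forwarding call in the base never reaches itself again.
class PackA final : public PackBase<PackA> {
 public:
  const std::int32_t* getRowOffsetBuffer() const {
    return buffer_;
  }

 private:
  const std::int32_t* buffer_ = nullptr;
};

Scoping the suppression with push/ignored/pop keeps it limited to the one forwarding method instead of disabling the warning file-wide.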

Reviewed By: r-barnes

Differential Revision: D53359145

fbshipit-source-id: 4a06134efecbda49d353a17fe60b9a6496ee1b32
q10 authored and facebook-github-bot committed Feb 6, 2024
1 parent 7889f64 commit 86acdbd
Showing 7 changed files with 78 additions and 39 deletions.
2 changes: 1 addition & 1 deletion bench/EmbeddingSpMDMNBit2Benchmark.cc
@@ -153,7 +153,7 @@ static void print_benchmark_results() {
<< "autovec b/w (GB/s), autovec effective b/w (GB/s), autovec time, "
<< "ref b/w (GB/s), ref effective b/w (GB/s), ref time, "
<< "asmjit speedup ratio, autovec speedup ratio" << std::endl;
for (int i = 0; i < benchmarks.size(); ++i) {
for (size_t i = 0; i < benchmarks.size(); ++i) {
BenchmarkSpec& spec = benchmarks[i].first;
BenchmarkResult& res = benchmarks[i].second;
float asmjit_speedup = res.ref_bw > 0.0 ? res.asmjit_bw / res.ref_bw : 0;
30 changes: 18 additions & 12 deletions fbgemm_gpu/fbgemm_gpu/split_embedding_inference_converter.py
@@ -130,9 +130,11 @@ def _process_split_embs(self, model: torch.nn.Module) -> None:
pruned_weight.size()[0],
D,
weight_ty,
EmbeddingLocation.HOST
if use_cpu
else EmbeddingLocation.DEVICE,
(
EmbeddingLocation.HOST
if use_cpu
else EmbeddingLocation.DEVICE
),
)
)
index_remapping_list.append(index_remapping)
@@ -144,19 +146,23 @@ def _process_split_embs(self, model: torch.nn.Module) -> None:

q_child = IntNBitTableBatchedEmbeddingBagsCodegen(
embedding_specs=new_embedding_specs,
index_remapping=index_remapping_list
if self.pruning_ratio is not None
else None,
index_remapping=(
index_remapping_list if self.pruning_ratio is not None else None
),
pooling_mode=child.pooling_mode,
device="cpu" if use_cpu else torch.cuda.current_device(),
weight_lists=weight_lists,
use_array_for_index_remapping=self.use_array_for_index_remapping,
fp8_exponent_bits=self._get_quantization_config("exponent_bits")
if is_fp8_weight
else None,
fp8_exponent_bias=self._get_quantization_config("exponent_bias")
if is_fp8_weight
else None,
fp8_exponent_bits=(
self._get_quantization_config("exponent_bits")
if is_fp8_weight
else None
),
fp8_exponent_bias=(
self._get_quantization_config("exponent_bias")
if is_fp8_weight
else None
),
)
setattr(model, name, q_child)
else:

@@ -98,9 +98,11 @@ def construct_cache_state(
start, end = _cache_hash_size_cumsum[t_], _cache_hash_size_cumsum[t_ + 1]
cache_index_table_map[start:end] = [t] * (end - start)
cache_hash_size_cumsum = [
_cache_hash_size_cumsum[t_]
if location_list[t_] == EmbeddingLocation.MANAGED_CACHING
else -1
(
_cache_hash_size_cumsum[t_]
if location_list[t_] == EmbeddingLocation.MANAGED_CACHING
else -1
)
for t_ in feature_table_map
]
cache_hash_size_cumsum.append(total_cache_hash_size)

@@ -947,7 +947,9 @@ def _apply_cache_state(
], "Only 1-way or 32-way(64-way for AMD) implmeneted for now"

self.cache_algorithm = cache_algorithm
# pyre-ignore[16]
self.timestep_counter = torch.classes.fbgemm.AtomicCounter()
# pyre-ignore[16]
self.timestep_prefetch_size = torch.classes.fbgemm.AtomicCounter()

self.max_prefetch_depth = MAX_PREFETCH_DEPTH
@@ -959,6 +961,7 @@ def _apply_cache_state(
lxu_cache_locations_empty = torch.empty(
0, device=self.current_device, dtype=torch.int32
).fill_(-1)
# pyre-ignore[16]
self.lxu_cache_locations_list = torch.classes.fbgemm.TensorQueue(
lxu_cache_locations_empty
)
@@ -1100,9 +1103,11 @@ def _apply_cache_state(
self.register_buffer(
"lxu_state",
torch.zeros(
size=(self.total_cache_hash_size + 1,)
if cache_algorithm == CacheAlgorithm.LFU
else (cache_sets, self.cache_assoc),
size=(
(self.total_cache_hash_size + 1,)
if cache_algorithm == CacheAlgorithm.LFU
else (cache_sets, self.cache_assoc)
),
device=self.current_device,
dtype=torch.int64,
),
@@ -1294,7 +1299,7 @@ def split_embedding_weights_with_scale_bias(
@torch.jit.export
def split_embedding_weights(
self,
split_scale_shifts: bool = True
split_scale_shifts: bool = True,
# When true, return list of two tensors, the first with weights and
# the second with scale_bias.
# This should've been named as split_scale_bias.
@@ -1303,11 +1308,13 @@ def split_embedding_weights(
"""
Returns a list of weights, split by table
"""
splits: List[
Tuple[Tensor, Optional[Tensor], Optional[Tensor]]
] = self.split_embedding_weights_with_scale_bias(
split_scale_bias_mode=(1 if split_scale_shifts else 0)
# fmt: off
splits: List[Tuple[Tensor, Optional[Tensor], Optional[Tensor]]] = (
self.split_embedding_weights_with_scale_bias(
split_scale_bias_mode=(1 if split_scale_shifts else 0)
)
)
# fmt: on
return [
(split_weight_scale_bias[0], split_weight_scale_bias[1])
for split_weight_scale_bias in splits
@@ -1411,9 +1418,11 @@ def set_index_remappings(
# Hash mapping pruning
if not use_array_for_index_remapping:
capacities = [
round_up(int(row * 1.0 / pruning_hash_load_factor), 32)
if index_remap is not None
else 0
(
round_up(int(row * 1.0 / pruning_hash_load_factor), 32)
if index_remap is not None
else 0
)
for (index_remap, row) in zip(index_remapping, rows)
]
hash_table = torch.empty(
@@ -1445,6 +1454,7 @@ def set_index_remappings(

if self.use_cpu:
self.index_remapping_hash_table_cpu = (
# pyre-ignore[16]
torch.classes.fbgemm.PrunedMapCPU()
)
self.index_remapping_hash_table_cpu.insert(

@@ -646,9 +646,11 @@ def __init__( # noqa C901
embedding_specs,
rowwise=rowwise,
cacheable=False,
placement=EmbeddingLocation.MANAGED
if ((not rowwise) and uvm_non_rowwise_momentum)
else None,
placement=(
EmbeddingLocation.MANAGED
if ((not rowwise) and uvm_non_rowwise_momentum)
else None
),
),
prefix="momentum1",
# pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param
@@ -671,9 +673,11 @@ def __init__( # noqa C901
embedding_specs,
rowwise=rowwise,
cacheable=False,
placement=EmbeddingLocation.MANAGED
if ((not rowwise) and uvm_non_rowwise_momentum)
else None,
placement=(
EmbeddingLocation.MANAGED
if ((not rowwise) and uvm_non_rowwise_momentum)
else None
),
),
prefix="momentum2",
# pyre-fixme[6]: Expected `Type[Type[torch._dtype]]` for 3rd param
@@ -1411,9 +1415,11 @@ def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]:
or self.optimizer == OptimType.EXACT_ADAGRAD
):
list_of_state_dict = [
{"sum": states[0], "prev_iter": states[1], "row_counter": states[2]}
if self._used_rowwise_adagrad_with_counter
else {"sum": states[0]}
(
{"sum": states[0], "prev_iter": states[1], "row_counter": states[2]}
if self._used_rowwise_adagrad_with_counter
else {"sum": states[0]}
)
for states in split_optimizer_states
]
elif self.optimizer == OptimType.SGD or self.optimizer == OptimType.EXACT_SGD:
@@ -1741,9 +1747,11 @@ def _apply_cache_state(
self.register_buffer(
"lxu_state",
torch.zeros(
size=(self.total_cache_hash_size + 1,)
if cache_algorithm == CacheAlgorithm.LFU
else (cache_sets, DEFAULT_ASSOC),
size=(
(self.total_cache_hash_size + 1,)
if cache_algorithm == CacheAlgorithm.LFU
else (cache_sets, DEFAULT_ASSOC)
),
device=self.current_device,
dtype=torch.int64,
),

@@ -213,6 +213,7 @@ def __init__(
prefix="ssd_table_batched_embeddings", dir=ssd_storage_directory
)
# pyre-fixme[4]: Attribute must be annotated.
# pyre-ignore[16]
self.ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper(
ssd_directory,
ssd_shards,
@@ -770,6 +771,7 @@ def max_ty_D(ty: SparseType) -> int:
prefix="ssd_table_batched_embeddings", dir=ssd_storage_directory
)
# pyre-fixme[4]: Attribute must be annotated.
# pyre-ignore[16]
self.ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper(
ssd_directory,
ssd_shards,
@@ -794,8 +796,10 @@ def max_ty_D(ty: SparseType) -> int:
self.ssd_set_end = torch.cuda.Event()

# pyre-fixme[4]: Attribute must be annotated.
# pyre-ignore[16]
self.timestep_counter = torch.classes.fbgemm.AtomicCounter()
# pyre-fixme[4]: Attribute must be annotated.
# pyre-ignore[16]
self.timestep_prefetch_size = torch.classes.fbgemm.AtomicCounter()

self.weights_dev: torch.Tensor = torch.empty(
9 changes: 9 additions & 0 deletions include/fbgemm/Fbgemm.h
@@ -139,6 +139,11 @@ class PackMatrix {
int cols = 0,
const BlockingFactors* params = nullptr);

#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winfinite-recursion"
#endif

/**
* @return Pointer to a buffer containing row offset results. Some packing
* objects fuse row offset computation for later requantization step.
@@ -147,6 +152,10 @@
return static_cast<const PT*>(this)->getRowOffsetBuffer();
}

#if defined(__clang__)
#pragma clang diagnostic pop
#endif

/**
* @brief When k loop is also tiled/blocked, this function is used to check if
* have executed computations for the last k block so that we can perform
