diff --git a/torchsig/datasets/synthetic.py b/torchsig/datasets/synthetic.py index ca15842..a98b8bc 100644 --- a/torchsig/datasets/synthetic.py +++ b/torchsig/datasets/synthetic.py @@ -155,7 +155,7 @@ def __init__( modulations=gfsks, num_iq_samples=num_iq_samples, num_samples_per_class=num_samples_per_class, - iq_samples_per_symbol=8 if iq_samples_per_symbol is None else iq_samples_per_symbol, + iq_samples_per_symbol=8, random_data=random_data, random_pulse_shaping=random_pulse_shaping, **kwargs @@ -692,13 +692,15 @@ def __init__( for freq_idx, freq_name in enumerate(map(str.lower, self.modulations)): for idx in range(self.num_samples_per_class): - if "g" in freq_name: - bandwidth = np.random.uniform(0.1, 0.5) if self.random_pulse_shaping else 0.35 - else: - bandwidth = np.random.uniform( - (1 / self.iq_samples_per_symbol) * 1.25, - (1 / self.iq_samples_per_symbol) * 3.75, - ) if self.random_pulse_shaping else 0.0 + # modulation index scales the bandwidth of the signal, and + # iq_samples_per_symbol is used as an oversampling rate in + # FSKDataset class, therefore the signal bandwidth can be + # approximated by mod_idx/iq_samples_per_symbol + mod_idx = self._mod_index(freq_name) + bandwidth = np.random.uniform( + (mod_idx / self.iq_samples_per_symbol) * 1.25, + (mod_idx / self.iq_samples_per_symbol) * 3.75, + ) if self.random_pulse_shaping else 0.0 signal_description = SignalDescription( sample_rate=0, bits_per_symbol=np.log2(len(freq_map[freq_name])), @@ -739,28 +741,24 @@ def _generate_samples(self, item: Tuple) -> np.ndarray: symbols = const_oversampled[symbol_nums] symbols_repeat = xp.repeat(symbols, samples_per_symbol_recalculated) - symbols_repeat = np.insert(symbols_repeat,0,0) # start at zero phase - filtered = symbols_repeat if "g" in const_name: + # GMSK, GFSK taps = self._gaussian_taps(samples_per_symbol_recalculated,bandwidth) signal_description.excess_bandwidth = bandwidth filtered = xp.convolve(xp.array(symbols_repeat), xp.array(taps), "same") - - if ("gfsk" in const_name): - # bluetooth - mod_idx = 0.32 - elif ("msk" in const_name): - # MSK, GMSK - mod_idx = 0.5 else: - # FSK - mod_idx = 1.0 + # FSK, MSK + filtered = symbols_repeat + + # insert a zero at first sample to start at zero phase + filtered = np.insert(filtered,0,0) + mod_idx = self._mod_index(const_name) phase = xp.cumsum(xp.array(filtered) * 1j * mod_idx * np.pi) modulated = xp.exp(phase) - if "g" not in const_name and self.random_pulse_shaping: + if self.random_pulse_shaping: # Apply a randomized LPF simulating a noisy detector/burst extractor, then downsample to ~fs/2 bw lpf_bandwidth = bandwidth num_taps = int(np.ceil(50 * 2 * np.pi / lpf_bandwidth / .125 / 22)) @@ -814,6 +812,20 @@ def _gaussian_taps(self, samples_per_symbol, BT: float = 0.35) -> np.ndarray: return p + def _mod_index(self, const_name): + # returns the modulation index based on the modulation + if ("gfsk" in const_name): + # bluetooth + mod_idx = 0.32 + elif ("msk" in const_name): + # MSK, GMSK + mod_idx = 0.5 + else: + # FSK + mod_idx = 1.0 + return mod_idx + + class AMDataset(SyntheticDataset): """AM Dataset