Skip to content

Commit

Permalink
Update kitty_utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
HaozheQi authored Mar 9, 2020
1 parent e0b90a8 commit ccac70f
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions kitty_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,14 @@ def regularizePC2(input_size, PC,):
return regularizePC(PC=PC, input_size=input_size)


def regularizePC(PC,input_size):
def regularizePC(PC,input_size,istrain=True):
PC = np.array(PC.points, dtype=np.float32)
if np.shape(PC)[1] > 2:
if PC.shape[0] > 3:
PC = PC[0:3, :]
if PC.shape[1] != int(input_size/2):
np.random.seed(1)
if not istrain:
np.random.seed(1)
new_pts_idx = np.random.randint(
low=0, high=PC.shape[1], size=int(input_size/2), dtype=np.int64)
PC = PC[:, new_pts_idx]
Expand All @@ -108,13 +109,14 @@ def regularizePC(PC,input_size):

return torch.from_numpy(PC).float()

def regularizePCwithlabel(PC,label,reg, input_size):
def regularizePCwithlabel(PC,label,reg, input_size,istrain=True):
PC = np.array(PC.points, dtype=np.float32)
if np.shape(PC)[1] > 2:
if PC.shape[0] > 3:
PC = PC[0:3, :]
if PC.shape[1] != input_size:
np.random.seed(1)
if not istrain:
np.random.seed(1)
new_pts_idx = np.random.randint(
low=0, high=PC.shape[1], size=input_size, dtype=np.int64)
PC = PC[:, new_pts_idx]
Expand Down

0 comments on commit ccac70f

Please sign in to comment.