diff --git a/api/methods/report.go b/api/methods/report.go index 3e2db3b..165bc50 100644 --- a/api/methods/report.go +++ b/api/methods/report.go @@ -39,8 +39,9 @@ func SetUnitName(c *gin.Context) { unitId := c.MustGet("UnitId").(string) dbpool, dbctx := storage.ReportInstance() // check if uuid is valid - _, err := dbpool.Exec(dbctx, "SELECT name FROM units WHERE uuid = $1", unitId) - if err != nil { + var uuid string + err := dbpool.QueryRow(dbctx, "SELECT uuid FROM units WHERE uuid = $1", unitId).Scan(&uuid) + if err != nil || uuid != unitId { // insert a new unit and return the id _, err := dbpool.Exec(dbctx, "INSERT INTO units (uuid, name) VALUES ($1, $2)", unitId, req.Name) if err != nil { @@ -72,13 +73,10 @@ func checkUnitId(unitId string) error { dbpool, dbctx := storage.ReportInstance() // check if uuid is valid - _, err := dbpool.Exec(dbctx, "SELECT uuid FROM units WHERE uuid = $1", unitId) - if err != nil { - // insert a new unit and return the id - _, err := dbpool.Exec(dbctx, "INSERT INTO units (uuid) VALUES ($1)", unitId) - if err != nil { - return errors.New("error inserting unit") - } + var uuid string + err := dbpool.QueryRow(dbctx, "SELECT uuid FROM units WHERE uuid = $1", unitId).Scan(&uuid) + if err != nil || uuid != unitId { + return errors.New("unit not found") } return nil