diff --git a/environments/environments.go b/environments/environments.go index 9ab58e22..ae1e2912 100644 --- a/environments/environments.go +++ b/environments/environments.go @@ -86,7 +86,7 @@ type TLSEnvironment struct { CarverInitPath string CarverBlockPath string AcceptEnrolls bool - UserID int + UserID uint } // MapEnvironments to hold the TLS environments by name and UUID diff --git a/nodes/nodes.go b/nodes/nodes.go index 1010bd00..5fad0287 100644 --- a/nodes/nodes.go +++ b/nodes/nodes.go @@ -36,7 +36,8 @@ type OsqueryNode struct { LastConfig time.Time LastQueryRead time.Time LastQueryWrite time.Time - UserID int + UserID uint + EnvironmentID uint } // ArchiveOsqueryNode as abstraction of an archived node @@ -66,7 +67,8 @@ type ArchiveOsqueryNode struct { LastConfig time.Time LastQueryRead time.Time LastQueryWrite time.Time - UserID int + UserID uint + EnvironmentID uint } // StatsData to display node stats @@ -136,6 +138,14 @@ func (n *NodeManager) CheckByUUIDEnv(uuid, environment string) bool { return (results > 0) } +// CheckByUUIDEnvID to check if node exists by UUID in a specific environment +// UUID is expected uppercase +func (n *NodeManager) CheckByUUIDEnvID(uuid string, envID int) bool { + var results int64 + n.DB.Model(&OsqueryNode{}).Where("uuid = ? AND environment_id = ?", strings.ToUpper(uuid), envID).Count(&results) + return (results > 0) +} + // CheckByHost to check if node exists by Hostname func (n *NodeManager) CheckByHost(host string) bool { var results int64 @@ -468,6 +478,7 @@ func nodeArchiveFromNode(node OsqueryNode, trigger string) ArchiveOsqueryNode { LastQueryRead: node.LastQueryRead, LastQueryWrite: node.LastQueryWrite, UserID: node.UserID, + EnvironmentID: node.EnvironmentID, } } diff --git a/tls/handlers/handlers.go b/tls/handlers/handlers.go index f032936c..d01dac62 100644 --- a/tls/handlers/handlers.go +++ b/tls/handlers/handlers.go @@ -255,7 +255,7 @@ func (h *HandlersTLS) EnrollHandler(w http.ResponseWriter, r *http.Request) { if h.checkValidSecret(t.EnrollSecret, env) { // Generate node_key using UUID as entropy nodeKey = generateNodeKey(t.HostIdentifier, time.Now()) - newNode = nodeFromEnroll(t, env.Name, utils.GetIP(r), nodeKey, len(body), env.UserID) + newNode = nodeFromEnroll(t, env, utils.GetIP(r), nodeKey, len(body)) // Check if UUID exists already, if so archive node and enroll new node if h.Nodes.CheckByUUIDEnv(t.HostIdentifier, env.Name) { if err := h.Nodes.Archive(t.HostIdentifier, "exists"); err != nil { diff --git a/tls/handlers/utils.go b/tls/handlers/utils.go index 28392bc2..64bbad05 100644 --- a/tls/handlers/utils.go +++ b/tls/handlers/utils.go @@ -68,7 +68,7 @@ func (h *HandlersTLS) checkExpiredPath(maybeExpired time.Time) bool { } // Helper to convert an enrollment request into a osquery node -func nodeFromEnroll(req types.EnrollRequest, environment, ipaddress, nodekey string, recBytes, envUserID int) nodes.OsqueryNode { +func nodeFromEnroll(req types.EnrollRequest, env environments.TLSEnvironment, ipaddress, nodekey string, recBytes int) nodes.OsqueryNode { // Prepare the enrollment request to be stored as raw JSON enrollRaw, err := json.Marshal(req) if err != nil { @@ -88,7 +88,7 @@ func nodeFromEnroll(req types.EnrollRequest, environment, ipaddress, nodekey str IPAddress: ipaddress, Username: "unknown", OsqueryUser: "unknown", - Environment: environment, + Environment: env.Name, CPU: strings.TrimRight(req.HostDetails.EnrollSystemInfo.CPUBrand, "\x00"), Memory: req.HostDetails.EnrollSystemInfo.PhysicalMemory, HardwareSerial: req.HostDetails.EnrollSystemInfo.HardwareSerial, @@ -100,7 +100,8 @@ func nodeFromEnroll(req types.EnrollRequest, environment, ipaddress, nodekey str LastConfig: time.Time{}, LastQueryRead: time.Time{}, LastQueryWrite: time.Time{}, - UserID: envUserID, + UserID: env.UserID, + EnvironmentID: env.ID, } }