diff --git a/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketHandler.java b/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketHandler.java index 483dfb936..0e44485b7 100644 --- a/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketHandler.java +++ b/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketHandler.java @@ -8,6 +8,7 @@ import com.intuit.tank.vm.agent.messages.AgentWsEnvelope; import com.intuit.tank.vm.agent.messages.AgentWsEnvelope.AckStatus; import com.intuit.tank.vm.agent.messages.AgentWsEnvelope.Type; +import com.intuit.tank.vm.api.enumerated.JobStatus; import com.intuit.tank.vm.vmManager.VMTracker; import com.intuit.tank.vm.settings.AgentConfig; import com.intuit.tank.vm.settings.TankConfig; @@ -102,6 +103,7 @@ public enum AgentWsState { private ServletContext servletContext; private volatile JobManager cachedJobManager; + private volatile AgentStatusLifecycle cachedAgentStatusLifecycle; private volatile VMTracker cachedVMTracker; private volatile AgentConfig cachedAgentConfig; @@ -299,16 +301,31 @@ private void handleStatusUpdate(WebSocketSession session, AgentWsEnvelope envelo status.setInstanceId(boundId); - VMTracker tracker = resolveVMTracker(); - if (tracker == null) { - LOG.warn(new ObjectMessage(Map.of("Message", "Unable to resolve VMTracker for WS status update from " + boundId))); + AgentStatusLifecycle statusLifecycle = resolveAgentStatusLifecycle(); + if (statusLifecycle != null) { + try { + statusLifecycle.setVmStatus(boundId, status); + return; + } catch (Exception e) { + LOG.warn(new ObjectMessage(Map.of("Message", "Failed WS lifecycle status update from " + boundId + ": " + e.getMessage()))); + } + } + + if (isTerminalStatus(status)) { + LOG.error(new ObjectMessage(Map.of("Message", "Unable to process terminal WS status for " + boundId + + " - lifecycle handler unavailable, not falling back to non-terminating tracker update"))); return; } - try { - tracker.setStatus(status); - } catch (Exception e) { - LOG.warn(new ObjectMessage(Map.of("Message", "Failed WS status update from " + boundId + ": " + e.getMessage()))); + VMTracker tracker = resolveVMTracker(); + if (tracker != null) { + try { + tracker.setStatus(status); + } catch (Exception e) { + LOG.warn(new ObjectMessage(Map.of("Message", "Failed WS tracker status update from " + boundId + ": " + e.getMessage()))); + } + } else { + LOG.warn(new ObjectMessage(Map.of("Message", "Unable to resolve status handler for WS status update from " + boundId))); } } @@ -573,6 +590,24 @@ private JobManager resolveJobManager() { } } + private AgentStatusLifecycle resolveAgentStatusLifecycle() { + AgentStatusLifecycle statusLifecycle = cachedAgentStatusLifecycle; + if (statusLifecycle != null) { + return statusLifecycle; + } + if (servletContext == null) { + return null; + } + try { + statusLifecycle = new ServletInjector().getManagedBean(servletContext, AgentStatusLifecycle.class); + cachedAgentStatusLifecycle = statusLifecycle; + return statusLifecycle; + } catch (Exception e) { + LOG.error(new ObjectMessage(Map.of("Message", "Error resolving AgentStatusLifecycle: " + e.getMessage())), e); + return null; + } + } + private VMTracker resolveVMTracker() { VMTracker tracker = cachedVMTracker; if (tracker != null) { @@ -649,6 +684,12 @@ private boolean isTerminalVmStatus(VMStatus status) { || status == VMStatus.disconnected; } + private boolean isTerminalStatus(CloudVmStatus status) { + return status.getJobStatus() == JobStatus.Completed + || status.getVmStatus() == VMStatus.terminated + || status.getVmStatus() == VMStatus.replaced; + } + private void handleJobTransferComplete(String instanceId) { String jobId = agentJobs.get(instanceId); if (jobId == null) { diff --git a/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentStatusLifecycle.java b/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentStatusLifecycle.java new file mode 100644 index 000000000..592c2827e --- /dev/null +++ b/rest-mvc/impl/src/main/java/com/intuit/tank/rest/mvc/AgentStatusLifecycle.java @@ -0,0 +1,91 @@ +package com.intuit.tank.rest.mvc; + +import com.intuit.tank.dao.JobInstanceDao; +import com.intuit.tank.project.JobInstance; +import com.intuit.tank.vm.api.enumerated.JobQueueStatus; +import com.intuit.tank.vm.api.enumerated.JobStatus; +import com.intuit.tank.vm.vmManager.VMTerminator; +import com.intuit.tank.vm.vmManager.VMTracker; +import com.intuit.tank.vm.vmManager.models.CloudVmStatus; +import com.intuit.tank.vm.vmManager.models.CloudVmStatusContainer; +import com.intuit.tank.vm.vmManager.models.VMStatus; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; +import org.apache.commons.lang3.math.NumberUtils; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.util.Arrays; +import java.util.Date; +import java.util.List; + +@ApplicationScoped +public class AgentStatusLifecycle { + + private static final Logger LOG = LogManager.getLogger(AgentStatusLifecycle.class); + + @Inject + private VMTracker vmTracker; + + @Inject + private VMTerminator terminator; + + public void setVmStatus(final String instanceId, final CloudVmStatus status) { + status.setInstanceId(instanceId); + vmTracker.setStatus(status); + if (isTerminalStatus(status)) { + terminator.terminate(status.getInstanceId()); + checkJobStatus(status.getJobId()); + } + } + + private boolean isTerminalStatus(CloudVmStatus status) { + return status.getJobStatus() == JobStatus.Completed + || status.getVmStatus() == VMStatus.terminated + || status.getVmStatus() == VMStatus.replaced; + } + + public void checkJobStatus(String jobId) { + CloudVmStatusContainer container = vmTracker.getVmStatusForJob(jobId); + LOG.info("Checking Job Status to see if we can kill reporting instances. Container=" + container); + if (container != null) { + if (container.getEndTime() != null) { + JobInstanceDao dao = new JobInstanceDao(); + + JobInstance finishedJob = dao.findById(Integer.valueOf(jobId)); + if (finishedJob != null && finishedJob.getEndTime() == null + && finishedJob.getStatus() != JobQueueStatus.Deleted) { + finishedJob.setEndTime(new Date()); + finishedJob.setStatus(JobQueueStatus.Completed); + dao.saveOrUpdate(finishedJob); + } + List statuses = Arrays.asList(JobQueueStatus.Running, JobQueueStatus.Starting); + List instances = dao.getForStatus(statuses); + LOG.info("Checking Job Status to see if we can kill reporting instances. found running instances: " + + instances.size()); + boolean killModal = true; + boolean killNonRegional = true; + + for (JobInstance job : instances) { + CloudVmStatusContainer statusForJob = vmTracker.getVmStatusForJob(Integer.toString(job.getId())); + if (!jobId.equals(Integer.toString(job.getId())) && statusForJob != null + && statusForJob.getEndTime() == null) { + LOG.info("Found another job that is not finished: " + job); + } + } + if (killNonRegional || killModal) { + for (CloudVmStatusContainer statusForJob : vmTracker.getAllJobs()) { + if (statusForJob.getEndTime() == null && !NumberUtils.isCreatable(statusForJob.getJobId())) { + killNonRegional = false; + killModal = false; + LOG.info("Cannot kill Reporting instances because of automation job id: " + + statusForJob.getJobId()); + } + } + } + } else { + LOG.info("Container does not have end time set so cannot kill reporting instances."); + } + } + } +} diff --git a/rest-mvc/impl/src/test/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketHandlerTest.java b/rest-mvc/impl/src/test/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketHandlerTest.java index df8135f2d..e703b4cab 100644 --- a/rest-mvc/impl/src/test/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketHandlerTest.java +++ b/rest-mvc/impl/src/test/java/com/intuit/tank/rest/mvc/AgentCommandWebSocketHandlerTest.java @@ -303,4 +303,68 @@ void testStatusUpdateDelegatesToVmTrackerAndEnforcesBoundIdentity() throws Excep assertEquals("job-1", statusCaptor.getValue().getJobId()); assertEquals(JobStatus.Running, statusCaptor.getValue().getJobStatus()); } + + @Test + void testTerminalStatusUpdateDelegatesToAgentStatusLifecycleForTermination() throws Exception { + AgentWsEnvelope hello = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null); + handler.handleTextMessage(session, new TextMessage(hello.toJson())); + + AgentStatusLifecycle statusLifecycle = mock(AgentStatusLifecycle.class); + Field statusLifecycleField = AgentCommandWebSocketHandler.class.getDeclaredField("cachedAgentStatusLifecycle"); + statusLifecycleField.setAccessible(true); + statusLifecycleField.set(handler, statusLifecycle); + + CloudVmStatus status = new CloudVmStatus( + "i-spoofed", + "job-1", + "sg-1", + JobStatus.Completed, + VMImageType.AGENT, + VMRegion.US_EAST, + VMStatus.terminated, + new ValidationStatus(), + 5, + 0, + new Date(), + new Date()); + + AgentWsEnvelope statusUpdate = AgentWsEnvelope.statusUpdate("i-123", "job-1", status); + handler.handleTextMessage(session, new TextMessage(statusUpdate.toJson())); + + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(CloudVmStatus.class); + verify(statusLifecycle).setVmStatus(eq("i-123"), statusCaptor.capture()); + assertEquals("i-123", statusCaptor.getValue().getInstanceId()); + assertEquals(JobStatus.Completed, statusCaptor.getValue().getJobStatus()); + assertEquals(VMStatus.terminated, statusCaptor.getValue().getVmStatus()); + } + + @Test + void testTerminalStatusDoesNotFallBackToTrackerWhenLifecycleUnavailable() throws Exception { + AgentWsEnvelope hello = AgentWsEnvelope.hello("i-123", "job-1", "sess-1", null); + handler.handleTextMessage(session, new TextMessage(hello.toJson())); + + VMTracker vmTracker = mock(VMTracker.class); + Field vmTrackerField = AgentCommandWebSocketHandler.class.getDeclaredField("cachedVMTracker"); + vmTrackerField.setAccessible(true); + vmTrackerField.set(handler, vmTracker); + + CloudVmStatus status = new CloudVmStatus( + "i-spoofed", + "job-1", + "sg-1", + JobStatus.Completed, + VMImageType.AGENT, + VMRegion.US_EAST, + VMStatus.terminated, + new ValidationStatus(), + 5, + 0, + new Date(), + new Date()); + + AgentWsEnvelope statusUpdate = AgentWsEnvelope.statusUpdate("i-123", "job-1", status); + handler.handleTextMessage(session, new TextMessage(statusUpdate.toJson())); + + verify(vmTracker, never()).setStatus(any(CloudVmStatus.class)); + } } diff --git a/rest-mvc/impl/src/test/java/com/intuit/tank/rest/mvc/AgentStatusLifecycleTest.java b/rest-mvc/impl/src/test/java/com/intuit/tank/rest/mvc/AgentStatusLifecycleTest.java new file mode 100644 index 000000000..2632bc1b1 --- /dev/null +++ b/rest-mvc/impl/src/test/java/com/intuit/tank/rest/mvc/AgentStatusLifecycleTest.java @@ -0,0 +1,84 @@ +package com.intuit.tank.rest.mvc; + +import com.intuit.tank.vm.api.enumerated.JobStatus; +import com.intuit.tank.vm.api.enumerated.VMImageType; +import com.intuit.tank.vm.api.enumerated.VMRegion; +import com.intuit.tank.vm.vmManager.VMTerminator; +import com.intuit.tank.vm.vmManager.VMTracker; +import com.intuit.tank.vm.vmManager.models.CloudVmStatus; +import com.intuit.tank.vm.vmManager.models.VMStatus; +import com.intuit.tank.vm.vmManager.models.ValidationStatus; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.Date; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +@ExtendWith(MockitoExtension.class) +public class AgentStatusLifecycleTest { + + @Mock + private VMTracker vmTracker; + + @Mock + private VMTerminator terminator; + + @InjectMocks + private AgentStatusLifecycle lifecycle; + + @Test + void testRunningStatusUpdatesTrackerWithoutTermination() { + CloudVmStatus status = createStatus("i-running", JobStatus.Running, VMStatus.running); + + lifecycle.setVmStatus("i-running", status); + + verify(vmTracker).setStatus(status); + verify(terminator, never()).terminate("i-running"); + } + + @Test + void testCompletedStatusTriggersTerminationAndEnforcesInstanceId() { + CloudVmStatus status = createStatus("i-spoofed", JobStatus.Completed, VMStatus.running); + + lifecycle.setVmStatus("i-done", status); + + verify(vmTracker).setStatus(argThat(updated -> "i-done".equals(updated.getInstanceId()) + && updated.getJobStatus() == JobStatus.Completed)); + verify(terminator).terminate("i-done"); + verify(vmTracker).getVmStatusForJob("123"); + assertEquals("i-done", status.getInstanceId()); + } + + @Test + void testTerminatedVmStatusTriggersTermination() { + CloudVmStatus status = createStatus("i-term", JobStatus.Running, VMStatus.terminated); + + lifecycle.setVmStatus("i-term", status); + + verify(vmTracker).setStatus(status); + verify(terminator).terminate("i-term"); + } + + private CloudVmStatus createStatus(String instanceId, JobStatus jobStatus, VMStatus vmStatus) { + return new CloudVmStatus( + instanceId, + "123", + "sg-1", + jobStatus, + VMImageType.AGENT, + VMRegion.US_EAST, + vmStatus, + new ValidationStatus(), + 5, + jobStatus == JobStatus.Completed ? 0 : 1, + new Date(), + jobStatus == JobStatus.Completed ? new Date() : null); + } +} diff --git a/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClient.java b/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClient.java index 189a39bec..73a6bca2b 100644 --- a/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClient.java +++ b/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClient.java @@ -6,9 +6,12 @@ import com.intuit.tank.vm.agent.messages.AgentWsEnvelope; import com.intuit.tank.vm.agent.messages.AgentWsEnvelope.AckStatus; import com.intuit.tank.vm.agent.messages.DataFileRequest; +import com.intuit.tank.vm.api.enumerated.JobStatus; import com.intuit.tank.vm.settings.TankConfig; +import com.intuit.tank.vm.vmManager.VMTerminator; import com.intuit.tank.vm.vmManager.VMTracker; import com.intuit.tank.vm.vmManager.models.CloudVmStatus; +import com.intuit.tank.vm.vmManager.models.VMStatus; import jakarta.enterprise.context.ApplicationScoped; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -61,9 +64,11 @@ public class ControllerInitiatedAgentWsClient implements AgentWsCommandSender { private final ConcurrentHashMap agentLastSeen = new ConcurrentHashMap<>(); private final ConcurrentHashMap agentWsState = new ConcurrentHashMap<>(); private final ConcurrentHashMap agentTransferProgress = new ConcurrentHashMap<>(); + private final java.util.Set terminationRequestedInstances = ConcurrentHashMap.newKeySet(); private volatile byte[] cachedHarnessJarBytes; private volatile VMTracker vmTracker; + private volatile VMTerminator vmTerminator; public ControllerInitiatedAgentWsClient() { } @@ -72,6 +77,10 @@ public void setVmTracker(VMTracker vmTracker) { this.vmTracker = vmTracker; } + public void setVmTerminator(VMTerminator vmTerminator) { + this.vmTerminator = vmTerminator; + } + public Optional connect(String instanceId, String wsUrl, String token, long helloTimeoutMillis) { try { SessionContext existing = sessions.get(instanceId); @@ -665,12 +674,43 @@ private void handleStatusUpdate(String instanceId, AgentWsEnvelope envelope) { } try { status.setInstanceId(instanceId); + requestTerminationForTerminalStatus(instanceId, status); vmTracker.setStatus(status); } catch (Exception e) { LOG.warn(new ObjectMessage(Map.of("Message", "[WS] Failed status update from " + instanceId + ": " + e.getMessage()))); } } + private void requestTerminationForTerminalStatus(String instanceId, CloudVmStatus status) { + if (!isTerminalStatus(status)) { + return; + } + VMTerminator terminator = vmTerminator; + if (terminator == null) { + LOG.error(new ObjectMessage(Map.of("Message", "[WS] Terminal status from " + instanceId + + " but VMTerminator is unavailable; instance termination was not scheduled"))); + return; + } + if (!terminationRequestedInstances.add(instanceId)) { + return; + } + try { + LOG.info(new ObjectMessage(Map.of("Message", "[WS] Scheduling VM termination for terminal status from " + + instanceId + " job " + status.getJobId()))); + terminator.terminate(instanceId); + } catch (Exception e) { + terminationRequestedInstances.remove(instanceId); + LOG.error(new ObjectMessage(Map.of("Message", "[WS] Failed scheduling VM termination for " + + instanceId + ": " + e.getMessage())), e); + } + } + + private boolean isTerminalStatus(CloudVmStatus status) { + return status.getJobStatus() == JobStatus.Completed + || status.getVmStatus() == VMStatus.terminated + || status.getVmStatus() == VMStatus.replaced; + } + private void onClosed(String instanceId, WebSocket webSocket) { SessionContext context = sessions.get(instanceId); if (context == null) { @@ -693,6 +733,7 @@ private void onClosed(String instanceId, WebSocket webSocket) { "[WS] Ignoring close for replaced session " + instanceId))); return; } + terminationRequestedInstances.remove(instanceId); context.markClosed(); fileTransferReady.remove(instanceId); PendingChunkAck pending = pendingChunkAcks.remove(instanceId); diff --git a/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/JobManager.java b/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/JobManager.java index aa329d272..a9c8b99d6 100644 --- a/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/JobManager.java +++ b/tank_vmManager/src/main/java/com/intuit/tank/perfManager/workLoads/JobManager.java @@ -61,6 +61,7 @@ import com.intuit.tank.vm.vmManager.JobRequest; import com.intuit.tank.vm.vmManager.JobVmCalculator; import com.intuit.tank.vm.vmManager.RegionRequest; +import com.intuit.tank.vm.vmManager.VMTerminator; import com.intuit.tank.vmManager.environment.amazon.AmazonInstance; import org.apache.logging.log4j.message.ObjectMessage; @@ -80,6 +81,9 @@ public class JobManager implements Serializable { @Inject private VMTracker vmTracker; + @Inject + private VMTerminator vmTerminator; + @Inject private StandaloneAgentTracker standaloneTracker; @@ -521,6 +525,7 @@ private com.intuit.tank.vm.agent.messages.AgentWsCommandSender getWsCommandSende public ControllerInitiatedAgentWsClient getControllerInitiatedAgentWsClient() { controllerInitiatedAgentWsClient.setVmTracker(vmTracker); + controllerInitiatedAgentWsClient.setVmTerminator(vmTerminator); com.intuit.tank.vm.agent.messages.AgentWsCommandSender.setStaticInstance(controllerInitiatedAgentWsClient); return controllerInitiatedAgentWsClient; } diff --git a/tank_vmManager/src/test/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClientTest.java b/tank_vmManager/src/test/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClientTest.java new file mode 100644 index 000000000..8f646d2fa --- /dev/null +++ b/tank_vmManager/src/test/java/com/intuit/tank/perfManager/workLoads/ControllerInitiatedAgentWsClientTest.java @@ -0,0 +1,147 @@ +package com.intuit.tank.perfManager.workLoads; + +import com.intuit.tank.vm.agent.messages.AgentWsEnvelope; +import com.intuit.tank.vm.api.enumerated.JobStatus; +import com.intuit.tank.vm.api.enumerated.VMImageType; +import com.intuit.tank.vm.api.enumerated.VMRegion; +import com.intuit.tank.vm.vmManager.VMTerminator; +import com.intuit.tank.vm.vmManager.VMTracker; +import com.intuit.tank.vm.vmManager.models.CloudVmStatus; +import com.intuit.tank.vm.vmManager.models.VMStatus; +import com.intuit.tank.vm.vmManager.models.ValidationStatus; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.net.http.WebSocket; +import java.util.Date; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class ControllerInitiatedAgentWsClientTest { + + @Test + void testTerminalStatusSchedulesTerminationAndUpdatesTracker() throws Exception { + ControllerInitiatedAgentWsClient client = new ControllerInitiatedAgentWsClient(); + VMTracker vmTracker = mock(VMTracker.class); + VMTerminator vmTerminator = mock(VMTerminator.class); + client.setVmTracker(vmTracker); + client.setVmTerminator(vmTerminator); + + CloudVmStatus status = createStatus("i-spoofed", JobStatus.Completed, VMStatus.terminated); + invokeHandleStatusUpdate(client, "i-agent", AgentWsEnvelope.statusUpdate("i-agent", "job-1", status)); + + verify(vmTerminator).terminate("i-agent"); + verify(vmTracker).setStatus(argThat(updated -> "i-agent".equals(updated.getInstanceId()) + && updated.getJobStatus() == JobStatus.Completed + && updated.getVmStatus() == VMStatus.terminated)); + assertEquals("i-agent", status.getInstanceId()); + } + + @Test + void testTerminalStatusSchedulesTerminationOnlyOnce() throws Exception { + ControllerInitiatedAgentWsClient client = new ControllerInitiatedAgentWsClient(); + VMTracker vmTracker = mock(VMTracker.class); + VMTerminator vmTerminator = mock(VMTerminator.class); + client.setVmTracker(vmTracker); + client.setVmTerminator(vmTerminator); + + CloudVmStatus status = createStatus("i-agent", JobStatus.Completed, VMStatus.terminated); + AgentWsEnvelope statusUpdate = AgentWsEnvelope.statusUpdate("i-agent", "job-1", status); + + invokeHandleStatusUpdate(client, "i-agent", statusUpdate); + invokeHandleStatusUpdate(client, "i-agent", statusUpdate); + + verify(vmTerminator, times(1)).terminate("i-agent"); + verify(vmTracker, times(2)).setStatus(status); + } + + @Test + void testSessionCloseClearsTerminationDedup() throws Exception { + ControllerInitiatedAgentWsClient client = new ControllerInitiatedAgentWsClient(); + VMTracker vmTracker = mock(VMTracker.class); + VMTerminator vmTerminator = mock(VMTerminator.class); + client.setVmTracker(vmTracker); + client.setVmTerminator(vmTerminator); + + CloudVmStatus status = createStatus("i-agent", JobStatus.Completed, VMStatus.terminated); + AgentWsEnvelope statusUpdate = AgentWsEnvelope.statusUpdate("i-agent", "job-1", status); + invokeHandleStatusUpdate(client, "i-agent", statusUpdate); + + WebSocket webSocket = mock(WebSocket.class); + installSession(client, "i-agent", webSocket); + invokeOnClosed(client, "i-agent", webSocket); + invokeHandleStatusUpdate(client, "i-agent", statusUpdate); + + verify(vmTerminator, times(2)).terminate("i-agent"); + } + + @Test + void testNonTerminalStatusUpdatesTrackerWithoutTermination() throws Exception { + ControllerInitiatedAgentWsClient client = new ControllerInitiatedAgentWsClient(); + VMTracker vmTracker = mock(VMTracker.class); + VMTerminator vmTerminator = mock(VMTerminator.class); + client.setVmTracker(vmTracker); + client.setVmTerminator(vmTerminator); + + CloudVmStatus status = createStatus("i-agent", JobStatus.Running, VMStatus.running); + invokeHandleStatusUpdate(client, "i-agent", AgentWsEnvelope.statusUpdate("i-agent", "job-1", status)); + + verify(vmTerminator, never()).terminate("i-agent"); + verify(vmTracker).setStatus(status); + } + + private void invokeHandleStatusUpdate(ControllerInitiatedAgentWsClient client, String instanceId, + AgentWsEnvelope envelope) throws Exception { + Method method = ControllerInitiatedAgentWsClient.class.getDeclaredMethod( + "handleStatusUpdate", String.class, AgentWsEnvelope.class); + method.setAccessible(true); + method.invoke(client, instanceId, envelope); + } + + private void invokeOnClosed(ControllerInitiatedAgentWsClient client, String instanceId, WebSocket webSocket) throws Exception { + Method method = ControllerInitiatedAgentWsClient.class.getDeclaredMethod( + "onClosed", String.class, WebSocket.class); + method.setAccessible(true); + method.invoke(client, instanceId, webSocket); + } + + @SuppressWarnings("unchecked") + private void installSession(ControllerInitiatedAgentWsClient client, String instanceId, WebSocket webSocket) throws Exception { + Class sessionContextClass = Class.forName( + "com.intuit.tank.perfManager.workLoads.ControllerInitiatedAgentWsClient$SessionContext"); + Constructor constructor = sessionContextClass.getDeclaredConstructor(WebSocket.class, CompletableFuture.class); + constructor.setAccessible(true); + Object sessionContext = constructor.newInstance(webSocket, new CompletableFuture()); + + Field sessionsField = ControllerInitiatedAgentWsClient.class.getDeclaredField("sessions"); + sessionsField.setAccessible(true); + ConcurrentHashMap sessions = + (ConcurrentHashMap) sessionsField.get(client); + sessions.put(instanceId, sessionContext); + } + + private CloudVmStatus createStatus(String instanceId, JobStatus jobStatus, VMStatus vmStatus) { + return new CloudVmStatus( + instanceId, + "job-1", + "sg-1", + jobStatus, + VMImageType.AGENT, + VMRegion.US_EAST, + vmStatus, + new ValidationStatus(), + 5, + jobStatus == JobStatus.Completed ? 0 : 1, + new Date(), + jobStatus == JobStatus.Completed ? new Date() : null); + } +}