diff --git a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java index 65d62f68ad5b4..d2ee6269c134e 100644 --- a/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java +++ b/core/src/main/java/org/apache/spark/memory/MemoryConsumer.java @@ -73,7 +73,9 @@ public void spill() throws IOException { * * This should be implemented by subclass. * - * Note: In order to avoid possible deadlock, should not call acquireMemory() from spill(). + * Note: In order to avoid possible deadlock, implementations must release memory synchronously + * on the calling thread and must not acquire task memory from spill(), either directly or from + * another thread. * * Note: today, this only frees Tungsten-managed pages. * @@ -115,7 +117,8 @@ public void freeArray(LongArray array) { * @throws SparkOutOfMemoryError */ protected MemoryBlock allocatePage(long required) { - MemoryBlock page = taskMemoryManager.allocatePage(Math.max(pageSize, required), this); + MemoryBlock page = + taskMemoryManager.allocatePageWithMinimum(Math.max(pageSize, required), required, this); if (page == null || page.size() < required) { throwOom(page, required); } @@ -131,6 +134,11 @@ protected void freePage(MemoryBlock page) { taskMemoryManager.freePage(page, this); } + /** Returns whether this page came from a minimum retry after a partial allocation failed. */ + protected boolean isPageAllocationFromMinimumRetry(MemoryBlock page) { + return taskMemoryManager.isPageAllocationFromMinimumRetry(page); + } + /** * Allocates memory of `size`. */ diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index 7d2a3fb63f9e2..099d95c6f9de5 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -29,6 +29,7 @@ import org.apache.spark.internal.SparkLoggerFactory; import org.apache.spark.internal.LogKeys; import org.apache.spark.internal.MDC; +import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.util.Utils; @@ -94,8 +95,14 @@ public class TaskMemoryManager { */ private final BitSet allocatedPages = new BitSet(PAGE_TABLE_SIZE); + /** Pages allocated at the caller's minimum after a partial allocation also failed. */ + @GuardedBy("this") + private final BitSet pagesAllocatedFromMinimumRetry = new BitSet(PAGE_TABLE_SIZE); + private final MemoryManager memoryManager; + private final MemoryAllocator tungstenMemoryAllocator; + private final long taskAttemptId; /** @@ -114,7 +121,25 @@ public class TaskMemoryManager { /** * The amount of memory that is acquired but not used. */ - private volatile long acquiredButNotUsed = 0L; + @GuardedBy("this") + private long acquiredButNotUsed = 0L; + + /** + * Prevent nested page allocations while spilling from recursively entering allocator recovery. + */ + private final ThreadLocal inPageAllocationRecovery = + ThreadLocal.withInitial(() -> false); + + private static final class PageAllocationRequest { + private MemoryConsumer consumer; + private long minimumSize; + } + + /** + * Carries a padded page request's minimum usable size through the existing virtual allocatePage + * entry point, so subclasses overriding that method continue to intercept page allocations. + */ + private final ThreadLocal pageAllocationRequest = new ThreadLocal<>(); /** * Current off heap memory usage by this task. @@ -144,8 +169,17 @@ public class TaskMemoryManager { * Construct a new TaskMemoryManager. */ public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) { + this(memoryManager, taskAttemptId, memoryManager.tungstenMemoryAllocator()); + } + + @VisibleForTesting + TaskMemoryManager( + MemoryManager memoryManager, + long taskAttemptId, + MemoryAllocator tungstenMemoryAllocator) { this.tungstenMemoryMode = memoryManager.tungstenMemoryMode(); this.memoryManager = memoryManager; + this.tungstenMemoryAllocator = tungstenMemoryAllocator; this.taskAttemptId = taskAttemptId; this.consumers = new HashSet<>(); } @@ -253,6 +287,24 @@ private long trySpillAndAcquire( int idx) { MemoryMode mode = requestingConsumer.getMode(); MemoryConsumer consumerToSpill = cList.get(idx); + long released = spillConsumer(requestingConsumer, requested, consumerToSpill); + if (released > 0) { + // When our spill handler releases memory, `ExecutionMemoryPool#releaseMemory()` will + // immediately notify other tasks that memory has been freed, and they may acquire the + // newly-freed memory before we have a chance to do so (SPARK-35486). Therefore we may + // not be able to acquire all the memory that was just spilled. In that case, we will + // try again in the next loop iteration. + return memoryManager.acquireExecutionMemory(requested, taskAttemptId, mode); + } else { + cList.remove(idx); + return 0; + } + } + + private long spillConsumer( + MemoryConsumer requestingConsumer, + long requested, + MemoryConsumer consumerToSpill) { if (logger.isDebugEnabled()) { logger.debug("Task {} try to spill {} from {} for {}", taskAttemptId, Utils.bytesToString(requested), consumerToSpill, requestingConsumer); @@ -265,17 +317,8 @@ private long trySpillAndAcquire( Utils.bytesToString(released), Utils.bytesToString(requested), consumerToSpill, requestingConsumer); } - - // When our spill handler releases memory, `ExecutionMemoryPool#releaseMemory()` will - // immediately notify other tasks that memory has been freed, and they may acquire the - // newly-freed memory before we have a chance to do so (SPARK-35486). Therefore we may - // not be able to acquire all the memory that was just spilled. In that case, we will - // try again in the next loop iteration. - return memoryManager.acquireExecutionMemory(requested, taskAttemptId, mode); - } else { - cList.remove(idx); - return 0; } + return released; } catch (ClosedByInterruptException | InterruptedIOException e) { // This called by user to kill a task (e.g: speculative task). logger.error("Error while calling spill() on {}", e, @@ -295,6 +338,96 @@ private long trySpillAndAcquire( } } + /** + * Spill task-managed memory after the allocator rejects a grant which the memory manager thought + * was available. Unlike acquireExecutionMemory(), this does not request another grant and cannot + * block waiting for fair-share memory. + */ + private synchronized long spillConsumersForPageAllocation( + long required, + MemoryConsumer requestingConsumer) { + TreeMap> sortedConsumers = new TreeMap<>(); + for (MemoryConsumer c : consumers) { + if (c.getUsed() > 0 && c.getMode() == requestingConsumer.getMode()) { + long key = c == requestingConsumer ? 0 : c.getUsed(); + List list = + sortedConsumers.computeIfAbsent(key, k -> new ArrayList<>(1)); + list.add(c); + } + } + + long released = 0L; + while (released < required && !sortedConsumers.isEmpty()) { + Map.Entry> currentEntry = + sortedConsumers.ceilingEntry(required - released); + if (currentEntry == null) { + currentEntry = sortedConsumers.lastEntry(); + } + List cList = currentEntry.getValue(); + int idx = cList.size() - 1; + MemoryConsumer consumerToSpill = cList.get(idx); + long usedBeforeSpill = consumerToSpill.getUsed(); + spillConsumer(requestingConsumer, required - released, consumerToSpill); + // Measure net tracked memory released. Spill callbacks must not reacquire execution memory; + // if a custom consumer violates that contract, conservatively treat it as making no progress. + long actuallyReleased = Math.max(0L, usedBeforeSpill - consumerToSpill.getUsed()); + if (actuallyReleased > 0) { + released = Math.addExact(released, actuallyReleased); + } else { + cList.remove(idx); + } + if (cList.isEmpty()) { + sortedConsumers.remove(currentEntry.getKey()); + } + } + return released; + } + + private long recoverFromPageAllocationFailure( + long required, + MemoryConsumer requestingConsumer) { + if (inPageAllocationRecovery.get()) { + return 0; + } + + inPageAllocationRecovery.set(true); + try { + return spillConsumersForPageAllocation(required, requestingConsumer); + } finally { + inPageAllocationRecovery.remove(); + } + } + + private long acquireAdditionalExecutionMemoryForPageAllocation( + long required, + MemoryConsumer requestingConsumer) { + if (inPageAllocationRecovery.get()) { + return 0; + } + + inPageAllocationRecovery.set(true); + try { + return acquireExecutionMemory(required, requestingConsumer); + } finally { + inPageAllocationRecovery.remove(); + } + } + + private void logPageAllocationFailure(long allocationSize, int retryCount, OutOfMemoryError e) { + try { + if (retryCount == 0) { + logger.warn("Failed to allocate a page ({} bytes), try spilling task memory.", e, + MDC.of(LogKeys.PAGE_SIZE, allocationSize)); + } else { + logger.warn("Failed to allocate a page ({} bytes) after {} spill retries.", + MDC.of(LogKeys.PAGE_SIZE, allocationSize), + MDC.of(LogKeys.NUM_RETRY, retryCount)); + } + } catch (OutOfMemoryError ignored) { + // Preserve allocator recovery even if diagnostics cannot allocate memory. + } + } + /** * Release N bytes of execution memory for a MemoryConsumer. */ @@ -355,10 +488,6 @@ public long pageSizeBytes() { return memoryManager.pageSizeBytes(); } - public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { - return allocatePage(size, consumer, 0); - } - /** * Allocate a block of memory that will be tracked in the MemoryManager's page table; this is * intended for allocating large blocks of Tungsten memory that will be shared between operators. @@ -368,12 +497,45 @@ public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { * * @throws TooLargePageException */ - private MemoryBlock allocatePage( + public MemoryBlock allocatePage(long size, MemoryConsumer consumer) { + PageAllocationRequest request = pageAllocationRequest.get(); + long minimumSize = request != null && request.consumer == consumer + ? Math.min(request.minimumSize, size) + : size; + return allocatePageInternal(size, minimumSize, consumer); + } + + MemoryBlock allocatePageWithMinimum( + long size, + long minimumSize, + MemoryConsumer consumer) { + assert(minimumSize >= 0 && minimumSize <= size); + PageAllocationRequest request = pageAllocationRequest.get(); + if (request == null) { + request = new PageAllocationRequest(); + pageAllocationRequest.set(request); + } + MemoryConsumer previousConsumer = request.consumer; + long previousMinimumSize = request.minimumSize; + request.consumer = consumer; + request.minimumSize = minimumSize; + try { + return allocatePage(size, consumer); + } finally { + request.consumer = previousConsumer; + request.minimumSize = previousMinimumSize; + } + } + + private MemoryBlock allocatePageInternal( long size, - MemoryConsumer consumer, - int retryCount) { + long minimumSize, + MemoryConsumer consumer) { assert(consumer != null); assert(consumer.getMode() == tungstenMemoryMode); + if (inPageAllocationRecovery.get()) { + return null; + } if (size > MAXIMUM_PAGE_SIZE_BYTES) { throw new TooLargePageException(size); } @@ -394,33 +556,108 @@ private MemoryBlock allocatePage( allocatedPages.set(pageNumber); } MemoryBlock page = null; + boolean pageAllocated = false; + int retryCount = 0; + long allocationSize = acquired; + long partialAllocationSize = 0; + boolean tryingPartialAllocation = false; + boolean minimumRetryAfterPartialAllocationFailure = false; try { - page = memoryManager.tungstenMemoryAllocator().allocate(acquired); - } catch (OutOfMemoryError e) { - if (retryCount == 0) { - logger.warn("Failed to allocate a page ({} bytes) for {} times, try again.", e, - MDC.of(LogKeys.PAGE_SIZE, acquired), - MDC.of(LogKeys.NUM_RETRY, retryCount)); - } else { - logger.warn("Failed to allocate a page ({} bytes) for {} times, try again.", - MDC.of(LogKeys.PAGE_SIZE, acquired), - MDC.of(LogKeys.NUM_RETRY, retryCount)); + while (true) { + try { + page = tungstenMemoryAllocator.allocate(allocationSize); + break; + } catch (OutOfMemoryError e) { + logPageAllocationFailure(allocationSize, retryCount, e); + if (tryingPartialAllocation) { + if (minimumSize > 0 && minimumSize < allocationSize) { + // Reuse the retained grant for one final attempt at the caller's usable minimum. + long surplus = Math.subtractExact(acquired, minimumSize); + releaseExecutionMemory(surplus, consumer); + acquired = minimumSize; + allocationSize = minimumSize; + minimumRetryAfterPartialAllocationFailure = true; + retryCount++; + continue; + } + return null; + } + long released = recoverFromPageAllocationFailure(allocationSize, consumer); + if (released > 0) { + long remaining = allocationSize - partialAllocationSize; + partialAllocationSize += Math.min(released, remaining); + } else if (partialAllocationSize > 0 && partialAllocationSize < allocationSize) { + // Preserve one bounded attempt to combine the memory made available by spilling with + // any remaining free-tail grant, then fall back to a partial page for callers that can + // use one. The additional grant may already include the spilled memory, so do not add + // it to partialAllocationSize. + long additionalAcquired = + acquireAdditionalExecutionMemoryForPageAllocation(size, consumer); + if (additionalAcquired > 0) { + long overlap = + additionalAcquired >= partialAllocationSize ? partialAllocationSize : 0L; + if (overlap > 0) { + releaseExecutionMemory(overlap, consumer); + } + acquired = Math.addExact(acquired, additionalAcquired - overlap); + } + allocationSize = + additionalAcquired > 0 ? additionalAcquired : partialAllocationSize; + tryingPartialAllocation = true; + } else if (partialAllocationSize == 0) { + // Preserve one bounded attempt to acquire a smaller free-tail grant. The previous + // recursive implementation could return a partial page this way after retaining the + // rejected grant, but could also retry without bound. + long additionalAcquired = + acquireAdditionalExecutionMemoryForPageAllocation(size, consumer); + if (additionalAcquired <= 0) { + return null; + } + acquired = Math.addExact(acquired, additionalAcquired); + allocationSize = additionalAcquired; + tryingPartialAllocation = true; + } else if (minimumSize > 0 && minimumSize < allocationSize) { + // The original grant is still reserved. If the caller padded a smaller allocation to + // the configured page size, make one bounded attempt at the minimum usable size without + // acquiring more execution memory. + long surplus = Math.subtractExact(acquired, minimumSize); + releaseExecutionMemory(surplus, consumer); + acquired = minimumSize; + allocationSize = minimumSize; + tryingPartialAllocation = true; + minimumRetryAfterPartialAllocationFailure = true; + } else { + return null; + } + retryCount++; + } } - // there is no enough memory actually, it means the actual free memory is smaller than - // MemoryManager thought, we should keep the acquired memory. + page.pageNumber = pageNumber; + pageTable[pageNumber] = page; synchronized (this) { - acquiredButNotUsed += acquired; - allocatedPages.clear(pageNumber); + acquiredButNotUsed = Math.addExact(acquiredButNotUsed, acquired - page.size()); + if (minimumRetryAfterPartialAllocationFailure) { + pagesAllocatedFromMinimumRetry.set(pageNumber); + } + } + pageAllocated = true; + if (logger.isTraceEnabled()) { + logger.trace("Allocate page number {} ({} bytes)", pageNumber, allocationSize); + } + return page; + } finally { + if (!pageAllocated) { + if (page != null) { + page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; + tungstenMemoryAllocator.free(page); + } + synchronized (this) { + pageTable[pageNumber] = null; + allocatedPages.clear(pageNumber); + acquiredButNotUsed = Math.addExact(acquiredButNotUsed, acquired); + } } - // this could trigger spilling to free some pages. - return allocatePage(size, consumer, retryCount + 1); - } - page.pageNumber = pageNumber; - pageTable[pageNumber] = page; - if (logger.isTraceEnabled()) { - logger.trace("Allocate page number {} ({} bytes)", pageNumber, acquired); } - return page; } /** @@ -437,6 +674,7 @@ public void freePage(MemoryBlock page, MemoryConsumer consumer) { pageTable[page.pageNumber] = null; synchronized (this) { allocatedPages.clear(page.pageNumber); + pagesAllocatedFromMinimumRetry.clear(page.pageNumber); } if (logger.isTraceEnabled()) { logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size()); @@ -446,10 +684,19 @@ public void freePage(MemoryBlock page, MemoryConsumer consumer) { // Doing this allows the MemoryAllocator to detect when a TaskMemoryManager-managed // page has been inappropriately directly freed without calling TMM.freePage(). page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; - memoryManager.tungstenMemoryAllocator().free(page); + tungstenMemoryAllocator.free(page); releaseExecutionMemory(pageSize, consumer); } + boolean isPageAllocationFromMinimumRetry(MemoryBlock page) { + synchronized (this) { + int pageNumber = page.pageNumber; + return pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE && + allocatedPages.get(pageNumber) && + pagesAllocatedFromMinimumRetry.get(pageNumber); + } + } + /** * Given a memory page and offset within that page, encode this address into a 64-bit long. * This address will remain valid as long as the corresponding page has not been freed. @@ -526,6 +773,7 @@ public long getOffsetInPage(long pagePlusOffsetAddress) { * value can be used to detect memory leaks. */ public long cleanUpAllAllocatedMemory() { + final long acquiredButNotUsedToRelease; synchronized (this) { for (MemoryConsumer c: consumers) { if (c != null && c.getUsed() > 0) { @@ -543,14 +791,19 @@ public long cleanUpAllAllocatedMemory() { logger.debug("unreleased page: {} in task {}", page, taskAttemptId); } page.pageNumber = MemoryBlock.FREED_IN_TMM_PAGE_NUMBER; - memoryManager.tungstenMemoryAllocator().free(page); + tungstenMemoryAllocator.free(page); } } Arrays.fill(pageTable, null); + allocatedPages.clear(); + pagesAllocatedFromMinimumRetry.clear(); + acquiredButNotUsedToRelease = acquiredButNotUsed; + acquiredButNotUsed = 0L; } // release the memory that is not used by any consumer (acquired for pages in tungsten mode). - memoryManager.releaseExecutionMemory(acquiredButNotUsed, taskAttemptId, tungstenMemoryMode); + memoryManager.releaseExecutionMemory( + acquiredButNotUsedToRelease, taskAttemptId, tungstenMemoryMode); return memoryManager.releaseAllExecutionMemoryForTask(taskAttemptId); } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index f053135c4dbd7..f899698cecb85 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -304,7 +304,8 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { } writeSortedFile(false); - final long spillSize = freeMemory(); + final long spillSize = getMemoryUsage(); + freeMemory(); inMemSorter.reset(); // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the // records. Otherwise, if the task is over allocated memory, then without freeing the memory @@ -363,6 +364,11 @@ public void cleanupResources() { } } + private void allocateInitialPointerArray() { + LongArray array = allocateArray(inMemSorter.getInitialSizeWithUsableCapacity()); + inMemSorter.expandPointerArray(array); + } + /** * Checks whether there is enough space to insert an additional record in to the sort pointer * array and grows the array if additional space is required. If the required space cannot be @@ -371,29 +377,59 @@ public void cleanupResources() { private void growPointerArrayIfNecessary() throws IOException { assert(inMemSorter != null); if (!inMemSorter.hasSpaceForAnotherRecord()) { + if (!inMemSorter.hasPointerArray()) { + allocateInitialPointerArray(); + return; + } + long used = inMemSorter.getMemoryUsage(); - LongArray array; + LongArray array = null; try { // could trigger spilling array = allocateArray(used / 8 * 2); } catch (TooLargePageException e) { // The pointer array is too big to fix in a single page, spill. spill(); - return; } catch (SparkOutOfMemoryError e) { // should have trigger spilling - if (!inMemSorter.hasSpaceForAnotherRecord()) { + if (!"UNABLE_TO_ACQUIRE_MEMORY".equals(e.getCondition()) || + inMemSorter.hasPointerArray()) { logger.error("Unable to grow the pointer array"); throw e; } - return; } // check if spilling is triggered or not - if (inMemSorter.hasSpaceForAnotherRecord()) { - freeArray(array); - } else { - inMemSorter.expandPointerArray(array); + if (!inMemSorter.hasPointerArray()) { + // A spill reset the pointer array while allocateArray() was in progress. Reuse a successful + // growth allocation, or restore the minimum usable initial array if allocation failed. + if (array != null) { + inMemSorter.expandPointerArray(array); + } else { + allocateInitialPointerArray(); + } + return; + } + inMemSorter.expandPointerArray(array); + } + } + + private void acquireNewPageWithPointerArrayFallback(int required) { + try { + acquireNewPageIfNecessary(required); + } catch (SparkOutOfMemoryError e) { + long minimumPointerArrayBytes = + Math.multiplyExact(inMemSorter.getInitialSizeWithUsableCapacity(), 8L); + if (!"UNABLE_TO_ACQUIRE_MEMORY".equals(e.getCondition()) || + inMemSorter.numRecords() != 0 || + inMemSorter.getMemoryUsage() <= minimumPointerArrayBytes) { + throw e; } + // A growth allocation retained after spilling can consume all memory made available by the + // spill. Since the sorter is still empty, shrink the pointer array and retry the data page + // once so that pointer growth cannot starve record storage. + inMemSorter.reset(); + allocateInitialPointerArray(); + acquireNewPageIfNecessary(required); } } @@ -447,7 +483,24 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p final int uaoSize = UnsafeAlignedOffset.getUaoSize(); // Need 4 or 8 bytes to store the record length. final int required = length + uaoSize; - acquireNewPageIfNecessary(required); + acquireNewPageWithPointerArrayFallback(required); + // Data page allocation may spill and reset the pointer array, so check its capacity again. + try { + growPointerArrayIfNecessary(); + } catch (SparkOutOfMemoryError e) { + if (!"UNABLE_TO_ACQUIRE_MEMORY".equals(e.getCondition()) || + inMemSorter.hasPointerArray() || + inMemSorter.numRecords() != 0 || + currentPage == null || + allocatedPages.size() != 1) { + throw e; + } + // The newly acquired empty data page consumed the remaining fair-share memory. Release it, + // restore the minimum pointer array first, and retry the data page once. + freeMemory(); + allocateInitialPointerArray(); + acquireNewPageIfNecessary(required); + } assert(currentPage != null); final Object base = currentPage.getBaseObject(); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index 7ab522d26c7da..49aa0434a48e2 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -76,9 +76,13 @@ public int compare(PackedRecordPointer left, PackedRecordPointer right) { } private int getUsableCapacity() { + return getUsableCapacity(array.size()); + } + + private int getUsableCapacity(long size) { // Radix sort requires same amount of used memory as buffer, Tim sort requires // half of the used memory as buffer. - return (int) (array.size() / (useRadixSort ? 2 : 1.5)); + return (int) (size / (useRadixSort ? 2 : 1.5)); } public void free() { @@ -92,33 +96,47 @@ public int numRecords() { return pos; } + public int getInitialSize() { + return initialSize; + } + + public long getInitialSizeWithUsableCapacity() { + long size = initialSize; + while (getUsableCapacity(size) == 0) { + size = Math.multiplyExact(size, 2L); + } + return size; + } + + public boolean hasPointerArray() { + return array != null; + } + public void reset() { - // Reset `pos` here so that `spill` triggered by the below `allocateArray` will be no-op. pos = 0; if (consumer != null) { - consumer.freeArray(array); - // As `array` has been released, we should set it to `null` to avoid accessing it before - // `allocateArray` returns. `usableCapacity` is also set to `0` to avoid any codes writing - // data to `ShuffleInMemorySorter` when `array` is `null` (e.g., in - // ShuffleExternalSorter.growPointerArrayIfNecessary, we may try to access - // `ShuffleInMemorySorter` when `allocateArray` throws SparkOutOfMemoryError). + if (array != null) { + consumer.freeArray(array); + } + // Allocate the replacement lazily. reset() is called while spilling, and allocating here can + // recursively trigger another spill while a partially complete allocation is still retained. array = null; usableCapacity = 0; - array = consumer.allocateArray(initialSize); - usableCapacity = getUsableCapacity(); } } public void expandPointerArray(LongArray newArray) { - assert(newArray.size() > array.size()); - Platform.copyMemory( - array.getBaseObject(), - array.getBaseOffset(), - newArray.getBaseObject(), - newArray.getBaseOffset(), - pos * 8L - ); - consumer.freeArray(array); + if (array != null) { + assert(newArray.size() > array.size()); + Platform.copyMemory( + array.getBaseObject(), + array.getBaseOffset(), + newArray.getBaseObject(), + newArray.getBaseOffset(), + pos * 8L + ); + consumer.freeArray(array); + } array = newArray; usableCapacity = getUsableCapacity(); } @@ -182,6 +200,10 @@ public void loadNext() { * Return an iterator over record pointers in sorted order. */ public ShuffleSorterIterator getSortedIterator() { + if (pos == 0) { + return new ShuffleSorterIterator(0, array, 0); + } + int offset = 0; if (useRadixSort) { offset = RadixSort.sort( diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 486bfd62bc97a..fc1bcd8bb3362 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -836,11 +836,19 @@ public boolean append(Object kbase, long koff, int klen, Object vbase, long voff * @return whether there is enough space to allocate the new page. */ private boolean acquireNewPage(long required) { + final MemoryBlock page; try { - currentPage = allocatePage(required); + page = allocatePage(required); } catch (SparkOutOfMemoryError e) { return false; } + // Retaining exact-fit minimum-retry pages would consume one page-table slot per map entry. + if (required < pageSizeBytes && page.size() == required && + isPageAllocationFromMinimumRetry(page)) { + freePage(page); + return false; + } + currentPage = page; dataPages.add(currentPage); UnsafeAlignedOffset.putSize(currentPage.getBaseObject(), currentPage.getBaseOffset(), 0); pageCursor = UnsafeAlignedOffset.getUaoSize(); diff --git a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java index 25543690b8322..6c408c0c928b5 100644 --- a/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java +++ b/core/src/test/java/org/apache/spark/memory/TaskMemoryManagerSuite.java @@ -17,16 +17,156 @@ package org.apache.spark.memory; +import java.io.IOException; +import java.util.concurrent.atomic.AtomicInteger; + import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.apache.spark.SparkConf; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; +import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; import org.apache.spark.internal.config.package$; public class TaskMemoryManagerSuite { + private static final class TestAllocator implements MemoryAllocator { + private int failuresRemaining; + private int allocationAttempts; + + TestAllocator(int failuresRemaining) { + this.failuresRemaining = failuresRemaining; + } + + @Override + public MemoryBlock allocate(long size) throws OutOfMemoryError { + allocationAttempts++; + if (failuresRemaining > 0) { + failuresRemaining--; + // checkstyle.off: RegexpSinglelineJava + throw new OutOfMemoryError("test allocator failure"); + // checkstyle.on: RegexpSinglelineJava + } + return MemoryAllocator.HEAP.allocate(size); + } + + @Override + public void free(MemoryBlock memory) { + MemoryAllocator.HEAP.free(memory); + } + } + + private static final class SizeLimitedAllocator implements MemoryAllocator { + private final long maximumAllocationSize; + private int allocationAttempts; + private long lastAllocationSize; + + SizeLimitedAllocator(long maximumAllocationSize) { + this.maximumAllocationSize = maximumAllocationSize; + } + + @Override + public MemoryBlock allocate(long size) throws OutOfMemoryError { + allocationAttempts++; + lastAllocationSize = size; + if (size > maximumAllocationSize) { + // checkstyle.off: RegexpSinglelineJava + throw new OutOfMemoryError("test allocator failure"); + // checkstyle.on: RegexpSinglelineJava + } + return MemoryAllocator.HEAP.allocate(size); + } + + @Override + public void free(MemoryBlock memory) { + MemoryAllocator.HEAP.free(memory); + } + } + + private static final class CompetingTaskAllocator implements MemoryAllocator { + private final MemoryManager memoryManager; + private int allocationAttempts; + + CompetingTaskAllocator(MemoryManager memoryManager) { + this.memoryManager = memoryManager; + } + + @Override + public MemoryBlock allocate(long size) throws OutOfMemoryError { + allocationAttempts++; + if (allocationAttempts == 2) { + long acquired = + memoryManager.acquireExecutionMemory(2048L, 1L, MemoryMode.ON_HEAP); + Assertions.assertEquals(2048L, acquired); + } + if (size > 1024L) { + // checkstyle.off: RegexpSinglelineJava + throw new OutOfMemoryError("test allocator failure"); + // checkstyle.on: RegexpSinglelineJava + } + return MemoryAllocator.HEAP.allocate(size); + } + + @Override + public void free(MemoryBlock memory) { + MemoryAllocator.HEAP.free(memory); + } + } + + private static final class PageAllocatingConsumer extends MemoryConsumer { + PageAllocatingConsumer(TaskMemoryManager manager, long pageSize) { + super(manager, pageSize, MemoryMode.ON_HEAP); + } + + MemoryBlock allocate(long required) { + return allocatePage(required); + } + + void freeAllocatedPage(MemoryBlock page) { + freePage(page); + } + + @Override + public long spill(long size, MemoryConsumer trigger) { + return 0; + } + } + + private static final class AllocatingSpillConsumer extends TestMemoryConsumer { + private MemoryBlock nestedPage; + + AllocatingSpillConsumer(TaskMemoryManager manager) { + super(manager); + } + + @Override + public long spill(long size, MemoryConsumer trigger) throws IOException { + nestedPage = taskMemoryManager.allocatePage(256, this); + long used = getUsed(); + free(used); + return used; + } + } + + private static final class NonSpillingAllocatingConsumer extends TestMemoryConsumer { + private int spillAttempts; + private MemoryBlock nestedPage; + + NonSpillingAllocatingConsumer(TaskMemoryManager manager) { + super(manager); + } + + @Override + public long spill(long size, MemoryConsumer trigger) { + spillAttempts++; + nestedPage = taskMemoryManager.allocatePage(256, this); + return 0; + } + } + @Test public void leakedPageMemoryIsDetected() { final TaskMemoryManager manager = new TaskMemoryManager( @@ -102,6 +242,429 @@ public void callingFreePageOnDirectlyAllocatedPageTriggersAssertionError() { Assertions.assertThrows(AssertionError.class, () -> manager.freePage(dataPage, c)); } + @Test + public void pageAllocationFailureWithoutSpillableMemoryReturnsNull() { + final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); + memoryManager.limit(4096); + final TestAllocator allocator = new TestAllocator(Integer.MAX_VALUE); + final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0, allocator); + final MemoryConsumer c = new TestMemoryConsumer(manager); + + Assertions.assertNull(manager.allocatePage(4096, c)); + Assertions.assertEquals(1, allocator.allocationAttempts); + Assertions.assertEquals(4096, manager.getMemoryConsumptionForThisTask()); + Assertions.assertEquals(0, manager.cleanUpAllAllocatedMemory()); + Assertions.assertEquals(0, manager.getMemoryConsumptionForThisTask()); + Assertions.assertEquals(0, manager.cleanUpAllAllocatedMemory()); + } + + @Test + public void pageAllocationFailureSpillsAndRetriesTheSameGrant() { + final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); + memoryManager.limit(5120); + final TestAllocator allocator = new TestAllocator(1); + final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0, allocator); + final TestMemoryConsumer existingConsumer = new TestMemoryConsumer(manager); + final PageAllocatingConsumer requestingConsumer = new PageAllocatingConsumer(manager, 4096); + existingConsumer.use(1024); + + final MemoryBlock page = requestingConsumer.allocate(4096); + Assertions.assertNotNull(page); + Assertions.assertEquals(2, allocator.allocationAttempts); + Assertions.assertEquals(0, existingConsumer.getUsed()); + Assertions.assertEquals(4096, requestingConsumer.getUsed()); + Assertions.assertEquals(4096, manager.getMemoryConsumptionForThisTask()); + + requestingConsumer.freeAllocatedPage(page); + Assertions.assertEquals(0, requestingConsumer.getUsed()); + Assertions.assertEquals(0, manager.cleanUpAllAllocatedMemory()); + } + + @Test + public void pageAllocationFailureCanRetryWithPartialPage() { + final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); + memoryManager.limit(5120); + final SizeLimitedAllocator allocator = new SizeLimitedAllocator(1024); + final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0, allocator); + final TestMemoryConsumer existingConsumer = new TestMemoryConsumer(manager); + final PageAllocatingConsumer requestingConsumer = new PageAllocatingConsumer(manager, 4096); + existingConsumer.use(1024); + + final MemoryBlock page = requestingConsumer.allocate(1024); + Assertions.assertEquals(1024, page.size()); + Assertions.assertEquals(3, allocator.allocationAttempts); + Assertions.assertEquals(1024, allocator.lastAllocationSize); + Assertions.assertEquals(0, existingConsumer.getUsed()); + Assertions.assertEquals(1024, requestingConsumer.getUsed()); + Assertions.assertEquals(4096, manager.getMemoryConsumptionForThisTask()); + + requestingConsumer.freeAllocatedPage(page); + Assertions.assertEquals(0, requestingConsumer.getUsed()); + Assertions.assertEquals(3072, manager.getMemoryConsumptionForThisTask()); + Assertions.assertEquals(0, manager.cleanUpAllAllocatedMemory()); + Assertions.assertEquals(0, manager.getMemoryConsumptionForThisTask()); + } + + @Test + public void pageAllocationFailureCombinesSpillWithFreeTail() { + final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); + memoryManager.limit(6144); + final SizeLimitedAllocator allocator = new SizeLimitedAllocator(2048); + final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0, allocator); + final TestMemoryConsumer existingConsumer = new TestMemoryConsumer(manager); + final PageAllocatingConsumer requestingConsumer = new PageAllocatingConsumer(manager, 4096); + existingConsumer.use(1024); + + final MemoryBlock page = requestingConsumer.allocate(2048); + Assertions.assertEquals(2048, page.size()); + Assertions.assertEquals(3, allocator.allocationAttempts); + Assertions.assertEquals(2048, allocator.lastAllocationSize); + Assertions.assertEquals(0, existingConsumer.getUsed()); + Assertions.assertEquals(2048, requestingConsumer.getUsed()); + Assertions.assertEquals(5120, manager.getMemoryConsumptionForThisTask()); + + requestingConsumer.freeAllocatedPage(page); + Assertions.assertEquals(0, requestingConsumer.getUsed()); + Assertions.assertEquals(3072, manager.getMemoryConsumptionForThisTask()); + Assertions.assertEquals(0, manager.cleanUpAllAllocatedMemory()); + Assertions.assertEquals(0, manager.getMemoryConsumptionForThisTask()); + } + + @Test + public void pageAllocationFailurePreservesSmallerGrantAfterCompetingAcquisition() { + final SparkConf conf = new SparkConf(false) + .set("spark.testing", "true") + .set("spark.testing.memory", "10240") + .set("spark.memory.fraction", "1.0") + .set("spark.memory.storageFraction", "0.5"); + final MemoryManager memoryManager = UnifiedMemoryManager$.MODULE$.apply(conf, 1); + final CompetingTaskAllocator allocator = new CompetingTaskAllocator(memoryManager); + final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0, allocator); + final TestMemoryConsumer existingConsumer = new TestMemoryConsumer(manager); + final PageAllocatingConsumer requestingConsumer = new PageAllocatingConsumer(manager, 4096); + existingConsumer.use(2048); + + final MemoryBlock page = requestingConsumer.allocate(1024); + Assertions.assertEquals(1024, page.size()); + Assertions.assertEquals(3, allocator.allocationAttempts); + Assertions.assertEquals(0, existingConsumer.getUsed()); + Assertions.assertEquals(1024, requestingConsumer.getUsed()); + Assertions.assertEquals(5120, manager.getMemoryConsumptionForThisTask()); + + requestingConsumer.freeAllocatedPage(page); + Assertions.assertEquals(4096, manager.getMemoryConsumptionForThisTask()); + Assertions.assertEquals(0, manager.cleanUpAllAllocatedMemory()); + memoryManager.releaseExecutionMemory(2048L, 1L, MemoryMode.ON_HEAP); + } + + @Test + public void pageAllocationFailureReleasesSurplusGrantAfterFullSpill() { + final SparkConf conf = new SparkConf(false) + .set("spark.testing", "true") + .set("spark.testing.memory", "10240") + .set("spark.memory.fraction", "1.0") + .set("spark.memory.storageFraction", "0.5"); + final MemoryManager memoryManager = UnifiedMemoryManager$.MODULE$.apply(conf, 1); + final CompetingTaskAllocator allocator = new CompetingTaskAllocator(memoryManager); + final AtomicInteger acquireCalls = new AtomicInteger(); + final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0, allocator) { + @Override + public long acquireExecutionMemory(long size, MemoryConsumer consumer) { + Assertions.assertEquals( + 1, acquireCalls.incrementAndGet(), "page recovery must not acquire another grant"); + return super.acquireExecutionMemory(size, consumer); + } + }; + final TestMemoryConsumer existingConsumer = new TestMemoryConsumer(manager); + final PageAllocatingConsumer requestingConsumer = new PageAllocatingConsumer(manager, 4096); + existingConsumer.use(4096); + acquireCalls.set(0); + + final MemoryBlock page = requestingConsumer.allocate(1024); + Assertions.assertEquals(1, acquireCalls.get()); + Assertions.assertEquals(1024, page.size()); + Assertions.assertEquals(3, allocator.allocationAttempts); + Assertions.assertEquals(0, existingConsumer.getUsed()); + Assertions.assertEquals(1024, requestingConsumer.getUsed()); + Assertions.assertEquals(1024, manager.getMemoryConsumptionForThisTask()); + + requestingConsumer.freeAllocatedPage(page); + Assertions.assertEquals(0, manager.getMemoryConsumptionForThisTask()); + Assertions.assertEquals(0, manager.cleanUpAllAllocatedMemory()); + memoryManager.releaseExecutionMemory(2048L, 1L, MemoryMode.ON_HEAP); + } + + @Test + public void pageAllocationFailureCanRetryWithFreeTail() { + final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); + memoryManager.limit(5120); + final SizeLimitedAllocator allocator = new SizeLimitedAllocator(1024); + final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0, allocator); + final PageAllocatingConsumer requestingConsumer = new PageAllocatingConsumer(manager, 4096); + + final MemoryBlock page = requestingConsumer.allocate(1024); + Assertions.assertNotNull(page); + Assertions.assertEquals(1024, page.size()); + Assertions.assertEquals(2, allocator.allocationAttempts); + Assertions.assertEquals(1024, allocator.lastAllocationSize); + Assertions.assertEquals(1024, requestingConsumer.getUsed()); + Assertions.assertEquals(5120, manager.getMemoryConsumptionForThisTask()); + + requestingConsumer.freeAllocatedPage(page); + Assertions.assertEquals(0, requestingConsumer.getUsed()); + Assertions.assertEquals(4096, manager.getMemoryConsumptionForThisTask()); + Assertions.assertEquals(0, manager.cleanUpAllAllocatedMemory()); + Assertions.assertEquals(0, manager.getMemoryConsumptionForThisTask()); + } + + @Test + public void pageAllocationFailureCanRetryMinimumAfterFullFreeTailGrant() { + final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); + memoryManager.limit(9216); + final SizeLimitedAllocator allocator = new SizeLimitedAllocator(1024); + final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0, allocator); + final PageAllocatingConsumer requestingConsumer = new PageAllocatingConsumer(manager, 4096); + + final MemoryBlock page = requestingConsumer.allocate(1024); + Assertions.assertEquals(1024, page.size()); + Assertions.assertEquals(3, allocator.allocationAttempts); + Assertions.assertEquals(1024, allocator.lastAllocationSize); + Assertions.assertEquals(1024, requestingConsumer.getUsed()); + Assertions.assertEquals(1024, manager.getMemoryConsumptionForThisTask()); + + requestingConsumer.freeAllocatedPage(page); + Assertions.assertEquals(0, requestingConsumer.getUsed()); + Assertions.assertEquals(0, manager.cleanUpAllAllocatedMemory()); + Assertions.assertEquals(0, manager.getMemoryConsumptionForThisTask()); + } + + @Test + public void pageAllocationFailureMinimumRetryIsBounded() { + final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); + memoryManager.limit(9216); + final SizeLimitedAllocator allocator = new SizeLimitedAllocator(0); + final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0, allocator); + final PageAllocatingConsumer requestingConsumer = new PageAllocatingConsumer(manager, 4096); + + Assertions.assertThrows( + SparkOutOfMemoryError.class, () -> requestingConsumer.allocate(1024)); + Assertions.assertEquals(3, allocator.allocationAttempts); + Assertions.assertEquals(1024, allocator.lastAllocationSize); + Assertions.assertEquals(0, requestingConsumer.getUsed()); + Assertions.assertEquals(1024, manager.getMemoryConsumptionForThisTask()); + + Assertions.assertEquals(0, manager.cleanUpAllAllocatedMemory()); + Assertions.assertEquals(0, manager.getMemoryConsumptionForThisTask()); + } + + @Test + public void bytesToBytesMapAcceptsExactFitInitialGrant() { + final long pageSize = 4096L; + final long recordPageSize = + (3L * UnsafeAlignedOffset.getUaoSize()) + (3L * Long.BYTES); + final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); + memoryManager.limit(1024L + recordPageSize); + final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0); + final BytesToBytesMap map = new BytesToBytesMap(manager, 64, pageSize); + final long[] row = new long[]{1L}; + + try { + final BytesToBytesMap.Location location = + map.lookup(row, Platform.LONG_ARRAY_OFFSET, Long.BYTES); + Assertions.assertTrue(location.append( + row, Platform.LONG_ARRAY_OFFSET, Long.BYTES, + row, Platform.LONG_ARRAY_OFFSET, Long.BYTES)); + Assertions.assertEquals( + 1024L + recordPageSize, manager.getMemoryConsumptionForThisTask()); + } finally { + map.free(); + } + Assertions.assertEquals(0, manager.cleanUpAllAllocatedMemory()); + Assertions.assertEquals(0, manager.getMemoryConsumptionForThisTask()); + } + + @Test + public void bytesToBytesMapAcceptsExactFitFreeTailRetry() { + final long pageSize = 4096L; + final long recordPageSize = + (3L * UnsafeAlignedOffset.getUaoSize()) + (3L * Long.BYTES); + final AtomicInteger exactFitAllocations = new AtomicInteger(); + final MemoryAllocator allocator = new MemoryAllocator() { + @Override + public MemoryBlock allocate(long size) throws OutOfMemoryError { + if (size == pageSize) { + // checkstyle.off: RegexpSinglelineJava + throw new OutOfMemoryError("test allocator failure"); + // checkstyle.on: RegexpSinglelineJava + } + if (size == recordPageSize) { + exactFitAllocations.incrementAndGet(); + } + return MemoryAllocator.HEAP.allocate(size); + } + + @Override + public void free(MemoryBlock memory) { + MemoryAllocator.HEAP.free(memory); + } + }; + final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); + memoryManager.limit(1024L + pageSize + recordPageSize); + final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0, allocator); + final BytesToBytesMap map = new BytesToBytesMap(manager, 64, pageSize); + final long[] row = new long[]{1L}; + + try { + final BytesToBytesMap.Location location = + map.lookup(row, Platform.LONG_ARRAY_OFFSET, Long.BYTES); + Assertions.assertTrue(location.append( + row, Platform.LONG_ARRAY_OFFSET, Long.BYTES, + row, Platform.LONG_ARRAY_OFFSET, Long.BYTES)); + Assertions.assertEquals(1, exactFitAllocations.get()); + Assertions.assertEquals( + 1024L + pageSize + recordPageSize, manager.getMemoryConsumptionForThisTask()); + } finally { + map.free(); + } + Assertions.assertEquals(0, manager.cleanUpAllAllocatedMemory()); + Assertions.assertEquals(0, manager.getMemoryConsumptionForThisTask()); + } + + @Test + public void bytesToBytesMapFallsBackFromExactFitPartialPage() { + final long pageSize = 4096L; + final long recordPageSize = + (3L * UnsafeAlignedOffset.getUaoSize()) + (3L * Long.BYTES); + final AtomicInteger exactFitAllocations = new AtomicInteger(); + final MemoryAllocator allocator = new MemoryAllocator() { + @Override + public MemoryBlock allocate(long size) throws OutOfMemoryError { + if (size == pageSize) { + // checkstyle.off: RegexpSinglelineJava + throw new OutOfMemoryError("test allocator failure"); + // checkstyle.on: RegexpSinglelineJava + } + if (size == recordPageSize) { + exactFitAllocations.incrementAndGet(); + } + return MemoryAllocator.HEAP.allocate(size); + } + + @Override + public void free(MemoryBlock memory) { + MemoryAllocator.HEAP.free(memory); + } + }; + final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); + memoryManager.limit(9216L); + final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0, allocator); + final BytesToBytesMap map = new BytesToBytesMap(manager, 64, pageSize); + final long[] row = new long[]{1L}; + + try { + final BytesToBytesMap.Location location = + map.lookup(row, Platform.LONG_ARRAY_OFFSET, Long.BYTES); + Assertions.assertFalse(location.append( + row, Platform.LONG_ARRAY_OFFSET, Long.BYTES, + row, Platform.LONG_ARRAY_OFFSET, Long.BYTES)); + Assertions.assertEquals(1, exactFitAllocations.get()); + Assertions.assertEquals(1024L, manager.getMemoryConsumptionForThisTask()); + } finally { + map.free(); + } + Assertions.assertEquals(0, manager.cleanUpAllAllocatedMemory()); + Assertions.assertEquals(0, manager.getMemoryConsumptionForThisTask()); + } + + @Test + public void bytesToBytesMapFallsBackFromExactFitPageAfterSpill() { + final long pageSize = 4096L; + final long recordPageSize = + (3L * UnsafeAlignedOffset.getUaoSize()) + (3L * Long.BYTES); + final SizeLimitedAllocator allocator = new SizeLimitedAllocator(1024L); + final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); + memoryManager.limit(1024L + (2L * pageSize)); + final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0, allocator); + final BytesToBytesMap map = new BytesToBytesMap(manager, 64, pageSize); + final TestMemoryConsumer spillableConsumer = new TestMemoryConsumer(manager); + final long[] row = new long[]{1L}; + spillableConsumer.use(pageSize); + + try { + final BytesToBytesMap.Location location = + map.lookup(row, Platform.LONG_ARRAY_OFFSET, Long.BYTES); + Assertions.assertFalse(location.append( + row, Platform.LONG_ARRAY_OFFSET, Long.BYTES, + row, Platform.LONG_ARRAY_OFFSET, Long.BYTES)); + Assertions.assertEquals(4, allocator.allocationAttempts); + Assertions.assertEquals(recordPageSize, allocator.lastAllocationSize); + Assertions.assertEquals(0, spillableConsumer.getUsed()); + Assertions.assertEquals(1024L, manager.getMemoryConsumptionForThisTask()); + } finally { + map.free(); + } + Assertions.assertEquals(0, manager.cleanUpAllAllocatedMemory()); + Assertions.assertEquals(0, manager.getMemoryConsumptionForThisTask()); + } + + @Test + public void freeTailAcquisitionDoesNotReenterPageAllocation() { + final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); + memoryManager.limit(6144); + final SizeLimitedAllocator allocator = new SizeLimitedAllocator(1024); + final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0, allocator); + final NonSpillingAllocatingConsumer existingConsumer = + new NonSpillingAllocatingConsumer(manager); + final PageAllocatingConsumer requestingConsumer = new PageAllocatingConsumer(manager, 4096); + existingConsumer.use(1024); + + final MemoryBlock page = requestingConsumer.allocate(1024); + Assertions.assertNotNull(page); + Assertions.assertEquals(2, existingConsumer.spillAttempts); + Assertions.assertNull(existingConsumer.nestedPage); + Assertions.assertEquals(1024, existingConsumer.getUsed()); + Assertions.assertEquals(1024, requestingConsumer.getUsed()); + Assertions.assertEquals(6144, manager.getMemoryConsumptionForThisTask()); + + existingConsumer.free(1024); + requestingConsumer.freeAllocatedPage(page); + Assertions.assertEquals(0, manager.cleanUpAllAllocatedMemory()); + Assertions.assertEquals(0, manager.getMemoryConsumptionForThisTask()); + } + + @Test + public void nestedPageAllocationFailureDoesNotReenterRecovery() { + final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); + memoryManager.limit(4096); + final TestAllocator allocator = new TestAllocator(Integer.MAX_VALUE); + final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0, allocator); + final AllocatingSpillConsumer existingConsumer = new AllocatingSpillConsumer(manager); + final TestMemoryConsumer requestingConsumer = new TestMemoryConsumer(manager); + existingConsumer.use(1024); + + Assertions.assertNull(manager.allocatePage(1024, requestingConsumer)); + Assertions.assertNull(existingConsumer.nestedPage); + Assertions.assertEquals(2, allocator.allocationAttempts); + Assertions.assertEquals(0, manager.cleanUpAllAllocatedMemory()); + Assertions.assertEquals(0, manager.getMemoryConsumptionForThisTask()); + } + + @Test + public void offHeapPageAllocationFailureReturnsNull() { + final SparkConf conf = new SparkConf() + .set(package$.MODULE$.MEMORY_OFFHEAP_ENABLED(), true) + .set(package$.MODULE$.MEMORY_OFFHEAP_SIZE(), 4096L); + final TestMemoryManager memoryManager = new TestMemoryManager(conf); + memoryManager.limit(4096); + final TestAllocator allocator = new TestAllocator(Integer.MAX_VALUE); + final TaskMemoryManager manager = new TaskMemoryManager(memoryManager, 0, allocator); + final MemoryConsumer c = new TestMemoryConsumer(manager, MemoryMode.OFF_HEAP); + + Assertions.assertNull(manager.allocatePage(4096, c)); + Assertions.assertEquals(1, allocator.allocationAttempts); + Assertions.assertEquals(0, manager.cleanUpAllAllocatedMemory()); + } + @Test public void cooperativeSpilling() { final TestMemoryManager memoryManager = new TestMemoryManager(new SparkConf()); diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java index aed8ccea93527..4f277d52d8384 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java @@ -32,6 +32,7 @@ import org.apache.spark.memory.TestMemoryConsumer; import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.LongArray; import org.apache.spark.unsafe.memory.MemoryBlock; public class ShuffleInMemorySorterSuite { @@ -57,6 +58,33 @@ public void testSortingEmptyInput() { Assertions.assertFalse(iter.hasNext()); } + @Test + public void testResetAllocatesPointerArrayLazily() { + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter( + consumer, 4, shouldUseRadixSort()); + Assertions.assertTrue(sorter.hasSpaceForAnotherRecord()); + + sorter.reset(); + Assertions.assertEquals(0, sorter.getMemoryUsage()); + Assertions.assertFalse(sorter.hasSpaceForAnotherRecord()); + Assertions.assertFalse(sorter.getSortedIterator().hasNext()); + + final LongArray array = consumer.allocateArray(sorter.getInitialSize()); + sorter.expandPointerArray(array); + Assertions.assertTrue(sorter.hasSpaceForAnotherRecord()); + sorter.free(); + } + + @Test + public void testInitialSizeWithUsableCapacity() { + final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter( + consumer, 1, shouldUseRadixSort()); + + Assertions.assertEquals(1, sorter.getInitialSize()); + Assertions.assertEquals(2, sorter.getInitialSizeWithUsableCapacity()); + sorter.free(); + } + @Test public void testBasicSorting() { final String[] dataToSort = new String[] { diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 26f0a86354478..3bddc8afcd5f5 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -557,6 +557,22 @@ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpillRadixOn() thro assertEquals(3, spillFilesCreated.size()); } + @Test + public void smallInitialSortBufferDoesNotSpillEveryRecordRadixOff() throws Exception { + conf.set(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE(), 1L); + conf.set(package$.MODULE$.SHUFFLE_SORT_USE_RADIXSORT(), false); + writeEnoughRecordsToTriggerSortBufferExpansionAndSpill(); + assertEquals(12, spillFilesCreated.size()); + } + + @Test + public void smallInitialSortBufferDoesNotSpillEveryRecordRadixOn() throws Exception { + conf.set(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE(), 1L); + conf.set(package$.MODULE$.SHUFFLE_SORT_USE_RADIXSORT(), true); + writeEnoughRecordsToTriggerSortBufferExpansionAndSpill(); + assertEquals(14, spillFilesCreated.size()); + } + private void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception { memoryManager.limit(DEFAULT_INITIAL_SORT_BUFFER_SIZE * 16); final UnsafeShuffleWriter writer = createWriter(false); diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala index 46b4e8b5202dc..f363afecf331d 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/ShuffleExternalSorterSuite.scala @@ -24,11 +24,11 @@ import org.scalatestplus.mockito.MockitoSugar import org.apache.spark._ import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} -import org.apache.spark.internal.config.MEMORY_FRACTION -import org.apache.spark.internal.config.MEMORY_OFFHEAP_ENABLED +import org.apache.spark.internal.config.{MEMORY_FRACTION, MEMORY_OFFHEAP_ENABLED, SHUFFLE_SORT_USE_RADIXSORT} import org.apache.spark.internal.config.Tests._ import org.apache.spark.memory._ -import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.{Platform, UnsafeAlignedOffset} +import org.apache.spark.unsafe.memory.MemoryBlock class ShuffleExternalSorterSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { @@ -56,13 +56,14 @@ class ShuffleExternalSorterSuite extends SparkFunSuite with LocalSparkContext wi memoryManager.maxHeapMemory - memoryManager.executionMemoryUsed > 400) { val acquireExecutionMemoryMethod = memoryManager.getClass.getMethods.filter(_.getName == "acquireExecutionMemory").head - acquireExecutionMemoryMethod.invoke( - memoryManager, - JLong.valueOf( - memoryManager.maxHeapMemory - memoryManager.executionMemoryUsed - 400), - JLong.valueOf(1L), // taskAttemptId - MemoryMode.ON_HEAP - ).asInstanceOf[java.lang.Long] + acquireExecutionMemoryMethod + .invoke( + memoryManager, + JLong.valueOf( + memoryManager.maxHeapMemory - memoryManager.executionMemoryUsed - 400), + JLong.valueOf(1L), // taskAttemptId + MemoryMode.ON_HEAP) + .asInstanceOf[java.lang.Long] } super.acquireExecutionMemory(required, consumer) } @@ -116,93 +117,346 @@ class ShuffleExternalSorterSuite extends SparkFunSuite with LocalSparkContext wi parameters = Map("requestedBytes" -> "800", "receivedBytes" -> "400")) } - test("cleanupResources should not NPE when reset fails to reallocate array") { - // Reproduces a bug where: - // 1. insertRecord() triggers spill -> reset() -> array = null -> allocateArray() throws OOM - // 2. OOM propagates out of insertRecord() - // 3. UnsafeShuffleWriter's finally block calls cleanupResources() - // 4. cleanupResources() -> freeMemory() -> updatePeakMemoryUsed() -> getMemoryUsage() - // -> inMemSorter.getMemoryUsage() -> NPE because inMemSorter.array is still null - // - // The root cause: reset() sets array = null, then allocateArray() fails. The sorter is left - // with inMemSorter != null but inMemSorter.array == null. cleanupResources() calls - // freeMemory() which calls getMemoryUsage() before reaching inMemSorter.free(). + test("cleanupResources should handle lazily reset pointer array") { val conf = new SparkConf() .setMaster("local[1]") .setAppName("ShuffleExternalSorterSuite") .set(IS_TESTING, true) - .set(TEST_MEMORY, 1600L) + .set(TEST_MEMORY, 10L * 1024 * 1024) .set(MEMORY_FRACTION, 0.9999) .set(MEMORY_OFFHEAP_ENABLED, false) sc = new SparkContext(conf) val memoryManager = UnifiedMemoryManager(conf, 1) + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + val taskContext = mock[TaskContext] + val taskMetrics = new TaskMetrics + when(taskContext.taskMetrics()).thenReturn(taskMetrics) + val sorter = new ShuffleExternalSorter( + taskMemoryManager, + sc.env.blockManager, + taskContext, + 100, + 1, + conf, + new ShuffleWriteMetrics) + val bytes = new Array[Byte](1) + sorter.insertRecord(bytes, Platform.BYTE_ARRAY_OFFSET, 1, 0) + sorter.spill() + + val inMemSorterField = sorter.getClass.getDeclaredField("inMemSorter") + inMemSorterField.setAccessible(true) + val arrayField = classOf[ShuffleInMemorySorter].getDeclaredField("array") + arrayField.setAccessible(true) + assert( + arrayField.get(inMemSorterField.get(sorter)) == null, + "spill should leave the pointer array unallocated until the next insert") + + sorter.cleanupResources() + } + + test("spill should report all released memory") { + val conf = new SparkConf() + .setMaster("local[1]") + .setAppName("ShuffleExternalSorterSuite") + .set(IS_TESTING, true) + .set(TEST_MEMORY, 10L * 1024 * 1024) + .set(MEMORY_FRACTION, 0.9999) + + sc = new SparkContext(conf) + val memoryManager = UnifiedMemoryManager(conf, 1) + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + val taskContext = mock[TaskContext] + val taskMetrics = new TaskMetrics + when(taskContext.taskMetrics()).thenReturn(taskMetrics) + val sorter = new ShuffleExternalSorter( + taskMemoryManager, + sc.env.blockManager, + taskContext, + 100, + 1, + conf, + new ShuffleWriteMetrics) - var shouldStealMemory = false + sorter.insertRecord(new Array[Byte](1), Platform.BYTE_ARRAY_OFFSET, 1, 0) + val usedBeforeSpill = sorter.getUsed + val spillSize = sorter.spill(Long.MaxValue, sorter) - // Override acquireExecutionMemory to steal freed memory during reset()'s allocateArray(), - // forcing the allocation to fail with OOM. + assert(spillSize === usedBeforeSpill - sorter.getUsed) + assert(taskMetrics.memoryBytesSpilled === spillSize) + sorter.cleanupResources() + } + + test("minimum pointer array should grow before the first insert") { + val conf = new SparkConf() + .setMaster("local[1]") + .setAppName("ShuffleExternalSorterSuite") + .set(IS_TESTING, true) + .set(TEST_MEMORY, 10L * 1024 * 1024) + .set(MEMORY_FRACTION, 0.9999) + + sc = new SparkContext(conf) + val memoryManager = UnifiedMemoryManager(conf, 1) + + Seq(false, true).foreach { useRadixSort => + conf.set(SHUFFLE_SORT_USE_RADIXSORT, useRadixSort) + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + val taskContext = mock[TaskContext] + when(taskContext.taskMetrics()).thenReturn(new TaskMetrics) + val sorter = new ShuffleExternalSorter( + taskMemoryManager, + sc.env.blockManager, + taskContext, + 1, + 1, + conf, + new ShuffleWriteMetrics) + + sorter.insertRecord(new Array[Byte](1), Platform.BYTE_ARRAY_OFFSET, 1, 0) + assert(sorter.closeAndGetSpills().length === 1) + sorter.cleanupResources() + } + } + + test("successful growth allocation after spill should be reused") { + val conf = new SparkConf() + .setMaster("local[1]") + .setAppName("ShuffleExternalSorterSuite") + .set(IS_TESTING, true) + .set(TEST_MEMORY, 10L * 1024 * 1024) + .set(MEMORY_FRACTION, 0.9999) + + sc = new SparkContext(conf) + val memoryManager = UnifiedMemoryManager(conf, 1) + + Seq(false, true).foreach { useRadixSort => + conf.set(SHUFFLE_SORT_USE_RADIXSORT, useRadixSort) + val initialSize = 1 + var spillOnGrowthAllocation = false + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) { + override def allocatePage(size: Long, consumer: MemoryConsumer): MemoryBlock = { + if (spillOnGrowthAllocation) { + spillOnGrowthAllocation = false + consumer.spill(size, consumer) + } + super.allocatePage(size, consumer) + } + } + val taskContext = mock[TaskContext] + when(taskContext.taskMetrics()).thenReturn(new TaskMetrics) + val sorter = new ShuffleExternalSorter( + taskMemoryManager, + sc.env.blockManager, + taskContext, + initialSize, + 1, + conf, + new ShuffleWriteMetrics) + + sorter.insertRecord(new Array[Byte](1), Platform.BYTE_ARRAY_OFFSET, 1, 0) + spillOnGrowthAllocation = true + sorter.insertRecord(new Array[Byte](1), Platform.BYTE_ARRAY_OFFSET, 1, 0) + + val inMemSorterField = sorter.getClass.getDeclaredField("inMemSorter") + inMemSorterField.setAccessible(true) + val inMemSorter = inMemSorterField.get(sorter).asInstanceOf[ShuffleInMemorySorter] + assert( + inMemSorter.getMemoryUsage === inMemSorter.getInitialSizeWithUsableCapacity * 2L * 8L) + assert(sorter.closeAndGetSpills().length === 2) + sorter.cleanupResources() + } + } + + test("pointer fallback should preserve spill failures") { + val conf = new SparkConf() + .setMaster("local[1]") + .setAppName("ShuffleExternalSorterSuite") + .set(IS_TESTING, true) + .set(TEST_MEMORY, 10L * 1024 * 1024) + .set(MEMORY_FRACTION, 0.9999) + + sc = new SparkContext(conf) + val memoryManager = UnifiedMemoryManager(conf, 1) + var spillOnGrowthAllocation = false + var failNextPageAllocation = false val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) { - override def acquireExecutionMemory(required: Long, consumer: MemoryConsumer): Long = { - if (shouldStealMemory && - memoryManager.maxHeapMemory - memoryManager.executionMemoryUsed > 400) { - val acquireExecutionMemoryMethod = - memoryManager.getClass.getMethods.filter(_.getName == "acquireExecutionMemory").head - acquireExecutionMemoryMethod.invoke( - memoryManager, - JLong.valueOf( - memoryManager.maxHeapMemory - memoryManager.executionMemoryUsed - 400), - JLong.valueOf(1L), - MemoryMode.ON_HEAP - ).asInstanceOf[java.lang.Long] + override def allocatePage(size: Long, consumer: MemoryConsumer): MemoryBlock = { + if (spillOnGrowthAllocation) { + spillOnGrowthAllocation = false + consumer.spill(size, consumer) + failNextPageAllocation = true + } else if (failNextPageAllocation) { + failNextPageAllocation = false + val parameters = new java.util.HashMap[String, String]() + parameters.put("consumerToSpill", "test") + parameters.put("message", "test failure") + // scalastyle:off throwerror + throw new SparkOutOfMemoryError("SPILL_OUT_OF_MEMORY", parameters) + // scalastyle:on throwerror } - super.acquireExecutionMemory(required, consumer) + super.allocatePage(size, consumer) } } val taskContext = mock[TaskContext] - val taskMetrics = new TaskMetrics - when(taskContext.taskMetrics()).thenReturn(taskMetrics) + when(taskContext.taskMetrics()).thenReturn(new TaskMetrics) val sorter = new ShuffleExternalSorter( taskMemoryManager, sc.env.blockManager, taskContext, - 100, // initialSize: ShuffleInMemorySorter needs 800 bytes (100 * 8) + 1, 1, conf, new ShuffleWriteMetrics) - val inMemSorter = { - val field = sorter.getClass.getDeclaredField("inMemSorter") - field.setAccessible(true) - field.get(sorter).asInstanceOf[ShuffleInMemorySorter] + + sorter.insertRecord(new Array[Byte](1), Platform.BYTE_ARRAY_OFFSET, 1, 0) + spillOnGrowthAllocation = true + val error = intercept[SparkOutOfMemoryError] { + sorter.insertRecord(new Array[Byte](1), Platform.BYTE_ARRAY_OFFSET, 1, 0) } + assert(error.getCondition === "SPILL_OUT_OF_MEMORY") + sorter.cleanupResources() + } - // Fill the pointer array until there's no space for another record. - val bytes = new Array[Byte](1) - while (inMemSorter.hasSpaceForAnotherRecord) { - sorter.insertRecord(bytes, Platform.BYTE_ARRAY_OFFSET, 1, 0) + test("pointer growth should preserve spill failures") { + val conf = new SparkConf() + .setMaster("local[1]") + .setAppName("ShuffleExternalSorterSuite") + .set(IS_TESTING, true) + .set(TEST_MEMORY, 10L * 1024 * 1024) + .set(MEMORY_FRACTION, 0.9999) + + sc = new SparkContext(conf) + val memoryManager = UnifiedMemoryManager(conf, 1) + var failGrowthAfterSpill = false + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) { + override def allocatePage(size: Long, consumer: MemoryConsumer): MemoryBlock = { + if (failGrowthAfterSpill) { + failGrowthAfterSpill = false + consumer.spill(size, consumer) + val parameters = new java.util.HashMap[String, String]() + parameters.put("consumerToSpill", "test") + parameters.put("message", "test failure") + // scalastyle:off throwerror + throw new SparkOutOfMemoryError("SPILL_OUT_OF_MEMORY", parameters) + // scalastyle:on throwerror + } + super.allocatePage(size, consumer) + } } + val taskContext = mock[TaskContext] + when(taskContext.taskMetrics()).thenReturn(new TaskMetrics) + val sorter = new ShuffleExternalSorter( + taskMemoryManager, + sc.env.blockManager, + taskContext, + 1, + 1, + conf, + new ShuffleWriteMetrics) - // Enable memory stealing so that when spill -> reset() -> allocateArray() runs, the freed - // memory is consumed before allocateArray can use it, causing OOM. - shouldStealMemory = true + sorter.insertRecord(new Array[Byte](1), Platform.BYTE_ARRAY_OFFSET, 1, 0) + failGrowthAfterSpill = true + val error = intercept[SparkOutOfMemoryError] { + sorter.insertRecord(new Array[Byte](1), Platform.BYTE_ARRAY_OFFSET, 1, 0) + } + assert(error.getCondition === "SPILL_OUT_OF_MEMORY") + sorter.cleanupResources() + } - // insertRecord triggers spill -> reset() -> array = null -> allocateArray fails -> OOM - intercept[SparkOutOfMemoryError] { - sorter.insertRecord(bytes, Platform.BYTE_ARRAY_OFFSET, 1, 0) + test("data page allocation spill should restore the pointer array") { + val conf = new SparkConf() + .setMaster("local[1]") + .setAppName("ShuffleExternalSorterSuite") + .set(IS_TESTING, true) + .set(TEST_MEMORY, 10L * 1024 * 1024) + .set(MEMORY_FRACTION, 0.9999) + + sc = new SparkContext(conf) + + val memoryManager = UnifiedMemoryManager(conf, 1) + var spillOnNextPageAllocation = false + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) { + override def allocatePage(size: Long, consumer: MemoryConsumer): MemoryBlock = { + if (spillOnNextPageAllocation) { + spillOnNextPageAllocation = false + consumer.spill(size, consumer) + } + super.allocatePage(size, consumer) + } } + val taskContext = mock[TaskContext] + when(taskContext.taskMetrics()).thenReturn(new TaskMetrics) + val sorter = new ShuffleExternalSorter( + taskMemoryManager, + sc.env.blockManager, + taskContext, + 100, + 1, + conf, + new ShuffleWriteMetrics) - // Verify the broken state: inMemSorter != null but inMemSorter.array == null - val inMemSorterField = sorter.getClass.getDeclaredField("inMemSorter") - inMemSorterField.setAccessible(true) - assert(inMemSorterField.get(sorter) != null, "inMemSorter should still be non-null") - val arrayField = classOf[ShuffleInMemorySorter].getDeclaredField("array") - arrayField.setAccessible(true) - assert(arrayField.get(inMemSorterField.get(sorter)) == null, - "inMemSorter.array should be null (reset freed it, allocateArray failed)") + sorter.insertRecord(new Array[Byte](1), Platform.BYTE_ARRAY_OFFSET, 1, 0) + spillOnNextPageAllocation = true + val bytes = + new Array[Byte]((taskMemoryManager.pageSizeBytes() - UnsafeAlignedOffset.getUaoSize).toInt) + sorter.insertRecord(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, 0) - // Without the fix, this NPEs in: - // cleanupResources -> freeMemory -> updatePeakMemoryUsed -> getMemoryUsage - // -> inMemSorter.getMemoryUsage -> array.size() -> NPE + assert(sorter.closeAndGetSpills().length === 2) sorter.cleanupResources() } + + test("data page allocation should not starve pointer restoration") { + val numCores = 4 + val conf = new SparkConf(false) + .setMaster(s"local[$numCores]") + .setAppName("ShuffleExternalSorterSuite") + .set(TEST_MEMORY, 512L * 1024 * 1024) + .set(MEMORY_FRACTION, 0.01) + + sc = new SparkContext(conf) + val memoryManager = UnifiedMemoryManager(conf, numCores) + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) + val taskContext = mock[TaskContext] + when(taskContext.taskMetrics()).thenReturn(new TaskMetrics) + val sorter = new ShuffleExternalSorter( + taskMemoryManager, + sc.env.blockManager, + taskContext, + 4096, + 1, + conf, + new ShuffleWriteMetrics) + + val pageSize = taskMemoryManager.pageSizeBytes() + assert(memoryManager.maxHeapMemory / numCores < pageSize) + val firstRecord = new Array[Byte]((pageSize - UnsafeAlignedOffset.getUaoSize).toInt) + sorter.insertRecord(firstRecord, Platform.BYTE_ARRAY_OFFSET, firstRecord.length, 0) + + val acquireExecutionMemoryMethod = + memoryManager.getClass.getMethods.filter(_.getName == "acquireExecutionMemory").head + val releaseExecutionMemoryMethod = + memoryManager.getClass.getMethods.filter(_.getName == "releaseExecutionMemory").head + (1L until numCores.toLong).foreach { taskAttemptId => + val granted = acquireExecutionMemoryMethod.invoke( + memoryManager, + JLong.valueOf(1L), + JLong.valueOf(taskAttemptId), + MemoryMode.ON_HEAP).asInstanceOf[JLong] + assert(granted === 1L) + } + + try { + sorter.insertRecord(new Array[Byte](1), Platform.BYTE_ARRAY_OFFSET, 1, 0) + } finally { + sorter.cleanupResources() + (1L until numCores.toLong).foreach { taskAttemptId => + releaseExecutionMemoryMethod.invoke( + memoryManager, + JLong.valueOf(1L), + JLong.valueOf(taskAttemptId), + MemoryMode.ON_HEAP) + } + assert(taskMemoryManager.cleanUpAllAllocatedMemory() === 0L) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/HybridQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/HybridQueue.scala index 90996c5526453..4eb0a1425b91e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/HybridQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/HybridQueue.scala @@ -110,10 +110,15 @@ abstract class HybridQueue[T, Q <: Queue[T]]( case _: SparkOutOfMemoryError => null } - val buffer = if (page != null) { - createInMemoryQueue(page) - } else { + val exactFitPartialPage = page != null && + required < memManager.pageSizeBytes() && page.size() == required + val buffer = if (page == null || exactFitPartialPage) { + if (page != null) { + freePage(page) + } createDiskQueue() + } else { + createInMemoryQueue(page) } synchronized { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala index 10d3b1429600f..3610b35816403 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala @@ -21,11 +21,11 @@ import java.io.File import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.internal.config._ -import org.apache.spark.memory.{MemoryMode, TaskMemoryManager, TestMemoryManager} +import org.apache.spark.memory.{MemoryConsumer, MemoryMode, TaskMemoryManager, TestMemoryManager} import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.unsafe.memory.{MemoryAllocator, MemoryBlock} import org.apache.spark.util.Utils class RowQueueSuite extends SparkFunSuite with EncryptionFunSuite { @@ -94,6 +94,32 @@ class RowQueueSuite extends SparkFunSuite with EncryptionFunSuite { queue.close() } + test("hybrid queue uses disk for an exact-fit partial page") { + val conf = new SparkConf(false) + val serManager = createSerializerManager(conf) + val mem = new TestMemoryManager(conf) + var pageFreed = false + val taskM = new TaskMemoryManager(mem, 0) { + override def allocatePage(size: Long, consumer: MemoryConsumer): MemoryBlock = { + MemoryAllocator.HEAP.allocate(20) + } + + override def freePage(page: MemoryBlock, consumer: MemoryConsumer): Unit = { + pageFreed = true + MemoryAllocator.HEAP.free(page) + } + } + val queue = HybridRowQueue(taskM, Utils.createTempDir().getCanonicalFile, 1, serManager) + val row = new UnsafeRow(1) + row.pointTo(new Array[Byte](16), 16) + + assert(queue.add(row) === QueueMode.DISK) + assert(pageFreed) + assert(queue.getUsed === 0) + assert(queue.remove().getSizeInBytes === 16) + queue.close() + } + Seq(true, false).foreach { isOffHeap => encryptionTest(s"hybrid queue (offHeap=$isOffHeap)") { conf => conf.set(MEMORY_OFFHEAP_ENABLED, isOffHeap)