Generated code for the final aggregation in a unit test with fallback enabled
== Subtree 2 / 2 (maxMethodCodeSize:248; maxConstantPoolSize:282(0.43% used); numInnerClasses:2) ==
*(2) HashAggregateWithControlledFallback ArrayBuffer(key#57) List(avg(value#58)) List(key#57, avg(value#58)#59 AS avg(value)#60) fallbackStartsAt=(2,3)
+- Exchange hashpartitioning(key#57, 5), ENSURE_REQUIREMENTS, [id=#65]
   +- *(1) HashAggregateWithControlledFallback ArrayBuffer(key#57) List(partial_avg(value#58)) ArrayBuffer(key#57, sum#65, count#66L) fallbackStartsAt=(2,3)
      +- *(1) ColumnarToRow
         +- FileScan parquet default.agg1[key#57,value#58] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/private/var/folders/y5/hnsw8mz93vs57ngcd30y6y9c0000gn/T/warehous..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<key:int,value:int>
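
For context, a minimal sketch of the kind of job that yields this plan: a per-key average over the default.agg1 Parquet table, with hash-aggregation fallback forced early through Spark's internal spark.sql.TungstenAggregate.testFallbackStartsAt config. The "2, 3" value corresponds to fallbackStartsAt=(2,3) above: the fast hash map gives up after 2 keys and the regular hash map after 3, which is what the generated code below enforces. The session setup and table contents are assumptions, and the HashAggregateWithControlledFallback operator name comes from the test harness rather than stock Spark.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class AvgFallbackRepro {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[*]")
        .appName("avg-fallback-repro")
        .getOrCreate();

    // Internal test config (assumption: same knob the harness uses): make the
    // fast hash map fall back after 2 keys and the regular hash map after 3.
    spark.conf().set("spark.sql.TungstenAggregate.testFallbackStartsAt", "2, 3");

    // Same shape as the plan: partial avg per key, shuffle, final avg.
    // Assumes default.agg1 already exists in the warehouse.
    Dataset<Row> result = spark.sql(
        "SELECT key, AVG(value) FROM default.agg1 GROUP BY key");
    result.explain();  // prints a physical plan shaped like the one above
    result.show();
  }
}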
Generated code:
public Object generate(Object[] references) {
  return new GeneratedIteratorForCodegenStage2(references);
}

// codegenStageId=2
final class GeneratedIteratorForCodegenStage2 extends org.apache.spark.sql.execution.BufferedRowIterator {
  private Object[] references;
  private scala.collection.Iterator[] inputs;
  private boolean agg_initAgg_0;
  private boolean agg_bufIsNull_0;
  private double agg_bufValue_0;
  private boolean agg_bufIsNull_1;
  private long agg_bufValue_1;
  private agg_FastHashMap_0 agg_fastHashMap_0;
  private org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> agg_fastHashMapIter_0;
  private org.apache.spark.unsafe.KVIterator agg_mapIter_0;
  private org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap agg_hashMap_0;
  private org.apache.spark.sql.execution.UnsafeKVExternalSorter agg_sorter_0;
  private scala.collection.Iterator inputadapter_input_0;
  private int agg_fallbackCounter_0;
  private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[] agg_mutableStateArray_0 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter[2];

  public GeneratedIteratorForCodegenStage2(Object[] references) {
    this.references = references;
  }

  public void init(int index, scala.collection.Iterator[] inputs) {
    partitionIndex = index;
    this.inputs = inputs;

    inputadapter_input_0 = inputs[0];
    agg_mutableStateArray_0[0] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);
    agg_mutableStateArray_0[1] = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(2, 0);
  }

  public class agg_FastHashMap_0 {
    private org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch batch;
    private int[] buckets;
    private int capacity = 1 << 1;
    private double loadFactor = 0.5;
    private int numBuckets = (int) (capacity / loadFactor);
    private int maxSteps = 2;
    private int numRows = 0;
    private Object emptyVBase;
    private long emptyVOff;
    private int emptyVLen;
    private boolean isBatchFull = false;
    private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter;

    public agg_FastHashMap_0(
        org.apache.spark.memory.TaskMemoryManager taskMemoryManager,
        InternalRow emptyAggregationBuffer) {
      batch = org.apache.spark.sql.catalyst.expressions.RowBasedKeyValueBatch
        .allocate(((org.apache.spark.sql.types.StructType) references[1] /* keySchemaTerm */), ((org.apache.spark.sql.types.StructType) references[2] /* valueSchemaTerm */), taskMemoryManager, capacity);

      final UnsafeProjection valueProjection = UnsafeProjection.create(((org.apache.spark.sql.types.StructType) references[2] /* valueSchemaTerm */));
      final byte[] emptyBuffer = valueProjection.apply(emptyAggregationBuffer).getBytes();

      emptyVBase = emptyBuffer;
      emptyVOff = Platform.BYTE_ARRAY_OFFSET;
      emptyVLen = emptyBuffer.length;

      agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(1, 0);

      buckets = new int[numBuckets];
      java.util.Arrays.fill(buckets, -1);
    }

    public org.apache.spark.sql.catalyst.expressions.UnsafeRow findOrInsert(int agg_key_0) {
      long h = hash(agg_key_0);
      int step = 0;
      int idx = (int) h & (numBuckets - 1);
      while (step < maxSteps) {
        // Return the value row if the bucket is an empty slot (insert) or already
        // contains the key.
        if (buckets[idx] == -1) {
          if (numRows < capacity && !isBatchFull) {
            agg_rowWriter.reset();
            agg_rowWriter.zeroOutNullBytes();
            agg_rowWriter.write(0, agg_key_0);
            org.apache.spark.sql.catalyst.expressions.UnsafeRow agg_result = agg_rowWriter.getRow();
            Object kbase = agg_result.getBaseObject();
            long koff = agg_result.getBaseOffset();
            int klen = agg_result.getSizeInBytes();

            UnsafeRow vRow = batch.appendRow(kbase, koff, klen, emptyVBase, emptyVOff, emptyVLen);
            if (vRow == null) {
              isBatchFull = true;
            } else {
              buckets[idx] = numRows++;
            }
            return vRow;
          } else {
            // No more space.
            return null;
          }
        } else if (equals(idx, agg_key_0)) {
          return batch.getValueRow(buckets[idx]);
        }
        idx = (idx + 1) & (numBuckets - 1);
        step++;
      }
      // Didn't find it.
      return null;
    }

    private boolean equals(int idx, int agg_key_0) {
      UnsafeRow row = batch.getKeyRow(buckets[idx]);
      return (row.getInt(0) == agg_key_0);
    }

    private long hash(int agg_key_0) {
      long agg_hash_0 = 0;

      int agg_result_0 = agg_key_0;
      agg_hash_0 = (agg_hash_0 ^ (0x9e3779b9)) + agg_result_0 + (agg_hash_0 << 6) + (agg_hash_0 >>> 2);

      return agg_hash_0;
    }

    public org.apache.spark.unsafe.KVIterator<UnsafeRow, UnsafeRow> rowIterator() {
      return batch.rowIterator();
    }

    public void close() {
      batch.close();
    }
  }
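
  // Illustrative sketch, not generated code: agg_FastHashMap_0 above is a tiny
  // open-addressed table (capacity 2, load factor 0.5, hence 4 buckets) that
  // gives up after maxSteps = 2 linear probes; agg_doConsume_0 below then falls
  // back to the regular UnsafeFixedWidthAggregationMap. This hypothetical class
  // mirrors that probing discipline with plain int keys in place of UnsafeRow
  // key/value batches.
  static final class BoundedProbeMapSketch {
    private static final int CAPACITY = 1 << 1;                    // 2 entries, as above
    private static final int NUM_BUCKETS = (int) (CAPACITY / 0.5); // 4 buckets
    private static final int MAX_STEPS = 2;                        // probe limit
    private final int[] buckets = new int[NUM_BUCKETS];
    private final int[] keys = new int[CAPACITY];
    private int numRows = 0;

    BoundedProbeMapSketch() { java.util.Arrays.fill(buckets, -1); }

    // Returns the slot for key, inserting if absent; -1 means "fall back".
    int findOrInsert(int key) {
      int idx = (int) hash(key) & (NUM_BUCKETS - 1);
      for (int step = 0; step < MAX_STEPS; step++) {
        if (buckets[idx] == -1) {                 // empty bucket
          if (numRows < CAPACITY) {               // room left: insert
            keys[numRows] = key;
            buckets[idx] = numRows;
            return numRows++;
          }
          return -1;                              // batch full: caller falls back
        } else if (keys[buckets[idx]] == key) {   // existing key: reuse its slot
          return buckets[idx];
        }
        idx = (idx + 1) & (NUM_BUCKETS - 1);      // linear probe
      }
      return -1;                                  // too many collisions: fall back
    }

    // Same single-round mix as hash(int) above.
    private static long hash(int key) {
      long h = 0;
      return (h ^ (0x9e3779b9)) + key + (h << 6) + (h >>> 2);
    }
  }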

  private void agg_doAggregateWithKeysOutput_0(UnsafeRow agg_keyTerm_0, UnsafeRow agg_bufferTerm_0)
      throws java.io.IOException {
    ((org.apache.spark.sql.execution.metric.SQLMetric) references[6] /* numOutputRows */).add(1);

    boolean agg_isNull_10 = agg_keyTerm_0.isNullAt(0);
    int agg_value_12 = agg_isNull_10 ? -1 : (agg_keyTerm_0.getInt(0));
    boolean agg_isNull_11 = agg_bufferTerm_0.isNullAt(0);
    double agg_value_13 = agg_isNull_11 ? -1.0 : (agg_bufferTerm_0.getDouble(0));
    boolean agg_isNull_12 = agg_bufferTerm_0.isNullAt(1);
    long agg_value_14 = agg_isNull_12 ? -1L : (agg_bufferTerm_0.getLong(1));
    boolean agg_isNull_15 = agg_isNull_12;
    double agg_value_17 = -1.0;
    if (!agg_isNull_12) {
      agg_value_17 = (double) agg_value_14;
    }
    boolean agg_isNull_13 = false;
    double agg_value_15 = -1.0;
    if (agg_isNull_15 || agg_value_17 == 0) {
      agg_isNull_13 = true;
    } else {
      if (agg_isNull_11) {
        agg_isNull_13 = true;
      } else {
        agg_value_15 = (double)(agg_value_13 / agg_value_17);
      }
    }

    agg_mutableStateArray_0[1].reset();
    agg_mutableStateArray_0[1].zeroOutNullBytes();

    if (agg_isNull_10) {
      agg_mutableStateArray_0[1].setNullAt(0);
    } else {
      agg_mutableStateArray_0[1].write(0, agg_value_12);
    }

    if (agg_isNull_13) {
      agg_mutableStateArray_0[1].setNullAt(1);
    } else {
      agg_mutableStateArray_0[1].write(1, agg_value_15);
    }
    append((agg_mutableStateArray_0[1].getRow()));
  }
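
  // Illustrative sketch, not generated code: the (sum, count) -> avg projection
  // performed above, with its null semantics made explicit (a null count, a
  // count of zero, or a null sum all yield a null average).
  static Double avgFromBufferSketch(Double sum, Long count) {
    if (count == null || count == 0L || sum == null) {
      return null;                 // avg over no (non-null) rows is null
    }
    return sum / (double) count;   // avg(value) = sum / count
  }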

  private void agg_doConsume_0(InternalRow inputadapter_row_0, int agg_expr_0_0, boolean agg_exprIsNull_0_0, double agg_expr_1_0, boolean agg_exprIsNull_1_0, long agg_expr_2_0, boolean agg_exprIsNull_2_0) throws java.io.IOException {
    UnsafeRow agg_unsafeRowAggBuffer_0 = null;
    UnsafeRow agg_fastAggBuffer_0 = null;

    if (!agg_exprIsNull_0_0) {
      agg_fastAggBuffer_0 = agg_fastHashMap_0.findOrInsert(agg_expr_0_0);
    }
    // If the key is not in the fast hash map, try the regular hash map.
    if (agg_fastAggBuffer_0 == null) {
      // Generate the grouping key.
      agg_mutableStateArray_0[0].reset();
      agg_mutableStateArray_0[0].zeroOutNullBytes();

      if (agg_exprIsNull_0_0) {
        agg_mutableStateArray_0[0].setNullAt(0);
      } else {
        agg_mutableStateArray_0[0].write(0, agg_expr_0_0);
      }
      int agg_unsafeRowKeyHash_0 = (agg_mutableStateArray_0[0].getRow()).hashCode();
      if (agg_fallbackCounter_0 < 3) {
        // Try to get the buffer from the hash map.
        agg_unsafeRowAggBuffer_0 =
          agg_hashMap_0.getAggregationBufferFromUnsafeRow((agg_mutableStateArray_0[0].getRow()), agg_unsafeRowKeyHash_0);
      }
      // Can't allocate a buffer from the hash map: spill the map and fall back to
      // sort-based aggregation after processing all input rows.
      if (agg_unsafeRowAggBuffer_0 == null) {
        if (agg_sorter_0 == null) {
          agg_sorter_0 = agg_hashMap_0.destructAndCreateExternalSorter();
        } else {
          agg_sorter_0.merge(agg_hashMap_0.destructAndCreateExternalSorter());
        }
        agg_fallbackCounter_0 = 0;
        // The hash map has been spilled, so it should have enough memory now;
        // try to allocate the buffer again.
        agg_unsafeRowAggBuffer_0 = agg_hashMap_0.getAggregationBufferFromUnsafeRow(
          (agg_mutableStateArray_0[0].getRow()), agg_unsafeRowKeyHash_0);
        if (agg_unsafeRowAggBuffer_0 == null) {
          // Failed to allocate the first page.
          throw new org.apache.spark.memory.SparkOutOfMemoryError("No enough memory for aggregation");
        }
      }
    }

    agg_fallbackCounter_0 += 1;

    // Update the proper row buffer.
    if (agg_fastAggBuffer_0 != null) {
      agg_unsafeRowAggBuffer_0 = agg_fastAggBuffer_0;
    }

    // common sub-expressions

    // Evaluate the aggregate functions and update the aggregation buffers.
    agg_doAggregate_avg_0(agg_expr_2_0, agg_expr_1_0, agg_unsafeRowAggBuffer_0, agg_exprIsNull_2_0, agg_exprIsNull_1_0);
  }
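
  // Illustrative sketch, not generated code: the control flow of agg_doConsume_0
  // above, reduced to its fallback ladder. The stand-in methods at the bottom are
  // hypothetical; the real code uses agg_fastHashMap_0, agg_hashMap_0 and
  // agg_sorter_0.
  static final class FallbackLadderSketch {
    private static final int FALLBACK_STARTS_AT = 3; // from fallbackStartsAt=(2,3)
    private int fallbackCounter = 0;

    long[] consume(int key) {
      long[] buf = fastMapFindOrInsert(key);          // 1. try the fast map
      if (buf == null) {
        if (fallbackCounter < FALLBACK_STARTS_AT) {   // 2. regular map, but only
          buf = regularMapGetBuffer(key);             //    until the counter trips
        }
        if (buf == null) {                            // 3. spill, then retry once
          spillRegularMapToSorter();
          fallbackCounter = 0;
          buf = regularMapGetBuffer(key);
          if (buf == null) {                          // real code throws SparkOutOfMemoryError
            throw new IllegalStateException("No enough memory for aggregation");
          }
        }
      }
      fallbackCounter += 1;                           // counts every consumed row
      return buf;                                     // caller updates (sum, count)
    }

    // Hypothetical stand-ins so the sketch is self-contained.
    private long[] fastMapFindOrInsert(int key) { return null; }
    private long[] regularMapGetBuffer(int key) { return new long[2]; }
    private void spillRegularMapToSorter() { }
  }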

  private void agg_doAggregate_avg_0(long agg_expr_2_0, double agg_expr_1_0, org.apache.spark.sql.catalyst.InternalRow agg_unsafeRowAggBuffer_0, boolean agg_exprIsNull_2_0, boolean agg_exprIsNull_1_0) throws java.io.IOException {
    boolean agg_isNull_4 = true;
    double agg_value_6 = -1.0;
    boolean agg_isNull_5 = agg_unsafeRowAggBuffer_0.isNullAt(0);
    double agg_value_7 = agg_isNull_5 ? -1.0 : (agg_unsafeRowAggBuffer_0.getDouble(0));
    if (!agg_isNull_5) {
      if (!agg_exprIsNull_1_0) {
        agg_isNull_4 = false; // resultCode could change nullability.
        agg_value_6 = agg_value_7 + agg_expr_1_0;
      }
    }
    boolean agg_isNull_7 = true;
    long agg_value_9 = -1L;
    boolean agg_isNull_8 = agg_unsafeRowAggBuffer_0.isNullAt(1);
    long agg_value_10 = agg_isNull_8 ? -1L : (agg_unsafeRowAggBuffer_0.getLong(1));
    if (!agg_isNull_8) {
      if (!agg_exprIsNull_2_0) {
        agg_isNull_7 = false; // resultCode could change nullability.
        agg_value_9 = agg_value_10 + agg_expr_2_0;
      }
    }

    if (!agg_isNull_4) {
      agg_unsafeRowAggBuffer_0.setDouble(0, agg_value_6);
    } else {
      agg_unsafeRowAggBuffer_0.setNullAt(0);
    }

    if (!agg_isNull_7) {
      agg_unsafeRowAggBuffer_0.setLong(1, agg_value_9);
    } else {
      agg_unsafeRowAggBuffer_0.setNullAt(1);
    }
  }

  private void agg_doAggregateWithKeys_0() throws java.io.IOException {
    while (inputadapter_input_0.hasNext()) {
      InternalRow inputadapter_row_0 = (InternalRow) inputadapter_input_0.next();

      boolean inputadapter_isNull_0 = inputadapter_row_0.isNullAt(0);
      int inputadapter_value_0 = inputadapter_isNull_0 ? -1 : (inputadapter_row_0.getInt(0));
      boolean inputadapter_isNull_1 = inputadapter_row_0.isNullAt(1);
      double inputadapter_value_1 = inputadapter_isNull_1 ? -1.0 : (inputadapter_row_0.getDouble(1));
      boolean inputadapter_isNull_2 = inputadapter_row_0.isNullAt(2);
      long inputadapter_value_2 = inputadapter_isNull_2 ? -1L : (inputadapter_row_0.getLong(2));

      agg_doConsume_0(inputadapter_row_0, inputadapter_value_0, inputadapter_isNull_0, inputadapter_value_1, inputadapter_isNull_1, inputadapter_value_2, inputadapter_isNull_2);
      // shouldStop check is eliminated
    }

    agg_fastHashMapIter_0 = agg_fastHashMap_0.rowIterator();
    agg_mapIter_0 = ((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).finishAggregate(agg_hashMap_0, agg_sorter_0, ((org.apache.spark.sql.execution.metric.SQLMetric) references[3] /* peakMemory */), ((org.apache.spark.sql.execution.metric.SQLMetric) references[4] /* spillSize */), ((org.apache.spark.sql.execution.metric.SQLMetric) references[5] /* avgHashProbe */));
  }

  protected void processNext() throws java.io.IOException {
    if (!agg_initAgg_0) {
      agg_initAgg_0 = true;
      agg_fastHashMap_0 = new agg_FastHashMap_0(((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).getTaskContext().taskMemoryManager(), ((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).getEmptyAggregationBuffer());

      ((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).getTaskContext().addTaskCompletionListener(
        new org.apache.spark.util.TaskCompletionListener() {
          @Override
          public void onTaskCompletion(org.apache.spark.TaskContext context) {
            agg_fastHashMap_0.close();
          }
        });

      agg_hashMap_0 = ((org.apache.spark.sql.execution.aggregate.HashAggregateExec) references[0] /* plan */).createHashMap();
      long wholestagecodegen_beforeAgg_0 = System.nanoTime();
      agg_doAggregateWithKeys_0();
      ((org.apache.spark.sql.execution.metric.SQLMetric) references[7] /* aggTime */).add((System.nanoTime() - wholestagecodegen_beforeAgg_0) / 1000000);
    }
    // Output the result.

    while (agg_fastHashMapIter_0.next()) {
      UnsafeRow agg_aggKey_0 = (UnsafeRow) agg_fastHashMapIter_0.getKey();
      UnsafeRow agg_aggBuffer_0 = (UnsafeRow) agg_fastHashMapIter_0.getValue();
      agg_doAggregateWithKeysOutput_0(agg_aggKey_0, agg_aggBuffer_0);

      if (shouldStop()) return;
    }
    agg_fastHashMap_0.close();

    while (agg_mapIter_0.next()) {
      UnsafeRow agg_aggKey_0 = (UnsafeRow) agg_mapIter_0.getKey();
      UnsafeRow agg_aggBuffer_0 = (UnsafeRow) agg_mapIter_0.getValue();
      agg_doAggregateWithKeysOutput_0(agg_aggKey_0, agg_aggBuffer_0);
      if (shouldStop()) return;
    }
    agg_mapIter_0.close();
    if (agg_sorter_0 == null) {
      agg_hashMap_0.free();
    }
  }
}
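
To reproduce a dump in this shape (one header per whole-stage-codegen subtree, followed by its generated Java source), Spark's EXPLAIN CODEGEN command works from plain SQL. A minimal sketch, assuming the default.agg1 table from the plan above:

import org.apache.spark.sql.SparkSession;

public class DumpGeneratedCode {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[*]")
        .appName("dump-generated-code")
        .getOrCreate();

    // EXPLAIN CODEGEN returns a single row whose one string column holds every
    // codegen subtree together with its generated source.
    String dump = spark.sql(
        "EXPLAIN CODEGEN SELECT key, AVG(value) FROM default.agg1 GROUP BY key")
        .first().getString(0);
    System.out.println(dump);
  }
}

In Scala, df.queryExecution.debug.codegen() (via import org.apache.spark.sql.execution.debug._) prints the same output.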