package org.apache.flink.streaming.runtime.partitioner;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.java.tuple.Tuple;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.JobException;
import org.apache.flink.runtime.executiongraph.ExecutionEdge;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.TestingExecutionGraphBuilder;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.Test;

/**
 * Tests for {@link RescalePartitioner}.
 *
 * <p>Covers two things: the partitioner's own channel selection, and the subtask wiring that a
 * {@code rescale()} edge produces once the job graph is expanded into an execution graph.
 */
public class RescalePartitionerTest extends StreamPartitionerTest {

    /** Creates the partitioner under test and asserts that it is not a broadcast partitioner. */
    @Override
    public StreamPartitioner<Tuple> createPartitioner() {
        RescalePartitioner<Tuple> partitioner = new RescalePartitioner<>();
        Assert.assertFalse(partitioner.isBroadcast());
        return partitioner;
    }

    /**
     * With 3 output channels, consecutive selections must be 0, 1, 2 and then wrap back to 0.
     */
    @Test
    public void testSelectChannelsInterval() {
        this.streamPartitioner.setup(3);
        assertSelectedChannel(0);
        assertSelectedChannel(1);
        assertSelectedChannel(2);
        assertSelectedChannel(0);
    }

    /**
     * Builds {@code source(p=2) -rescale-> flatMap(p=4) -rescale-> sink(p=2)} and checks the
     * execution-graph wiring.
     *
     * <ul>
     *   <li>Each flatMap subtask reads from exactly one source partition, and each of the two
     *       source partitions is read by exactly two flatMap subtasks.
     *   <li>Each sink subtask reads two flatMap partitions, and no flatMap partition is read by
     *       more than one sink subtask (so all 4 are covered).
     * </ul>
     */
    @Test
    public void testExecutionGraphGeneration() throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(4);

        env.addSource(
                        new ParallelSourceFunction<String>() {
                            private static final long serialVersionUID = 7772338606389180774L;

                            @Override
                            public void run(SourceFunction.SourceContext<String> ctx)
                                    throws Exception {}

                            @Override
                            public void cancel() {}
                        })
                .setParallelism(2)
                .rescale()
                // Anonymous class (not a lambda) so the Tuple2 output type is not erased.
                .flatMap(
                        new FlatMapFunction<String, Tuple2<String, Integer>>() {
                            private static final long serialVersionUID = -5255930322161596829L;

                            @Override
                            public void flatMap(
                                    String value, Collector<Tuple2<String, Integer>> out)
                                    throws Exception {}
                        })
                .rescale()
                .print()
                .setParallelism(2);

        List<JobVertex> jobVertices =
                env.getStreamGraph().getJobGraph().getVerticesSortedTopologicallyFromSources();
        JobVertex sourceVertex = jobVertices.get(0);
        JobVertex mapVertex = jobVertices.get(1);
        JobVertex sinkVertex = jobVertices.get(2);

        Assert.assertEquals(2, sourceVertex.getParallelism());
        Assert.assertEquals(4, mapVertex.getParallelism());
        Assert.assertEquals(2, sinkVertex.getParallelism());

        ExecutionGraph executionGraph = TestingExecutionGraphBuilder.newBuilder().build();
        try {
            executionGraph.attachJobGraph(jobVertices);
        } catch (JobException e) {
            // Keep the original exception as the cause so its stack trace is not lost.
            throw new AssertionError("Building ExecutionGraph failed: " + e.getMessage(), e);
        }

        ExecutionJobVertex sourceExecution = executionGraph.getJobVertex(sourceVertex.getID());
        ExecutionJobVertex mapExecution = executionGraph.getJobVertex(mapVertex.getID());
        ExecutionJobVertex sinkExecution = executionGraph.getJobVertex(sinkVertex.getID());

        Assert.assertEquals(0, sourceExecution.getInputs().size());

        // source -> flatMap: count how many flatMap subtasks consume each source partition.
        Assert.assertEquals(1, mapExecution.getInputs().size());
        Assert.assertEquals(4, mapExecution.getParallelism());
        Map<Integer, Integer> consumersPerSourcePartition = new HashMap<>();
        for (ExecutionVertex mapSubtask : mapExecution.getTaskVertices()) {
            Assert.assertEquals(1, mapSubtask.getNumberOfInputs());
            ExecutionEdge[] inputEdges = mapSubtask.getInputEdges(0);
            Assert.assertEquals(1, inputEdges.length);

            ExecutionEdge edge = inputEdges[0];
            Assert.assertEquals(
                    sourceVertex.getID(), edge.getSource().getProducer().getJobvertexId());

            int partition = edge.getSource().getPartitionNumber();
            Integer previous = consumersPerSourcePartition.get(partition);
            consumersPerSourcePartition.put(partition, previous == null ? 1 : previous + 1);
        }
        Assert.assertEquals(2, consumersPerSourcePartition.size());
        for (int consumers : consumersPerSourcePartition.values()) {
            Assert.assertEquals(2, consumers);
        }

        // flatMap -> sink: every flatMap partition must be consumed by exactly one sink subtask.
        Assert.assertEquals(1, sinkExecution.getInputs().size());
        Assert.assertEquals(2, sinkExecution.getParallelism());
        Set<Integer> consumedMapPartitions = new HashSet<>();
        for (ExecutionVertex sinkSubtask : sinkExecution.getTaskVertices()) {
            Assert.assertEquals(1, sinkSubtask.getNumberOfInputs());
            ExecutionEdge[] inputEdges = sinkSubtask.getInputEdges(0);
            Assert.assertEquals(2, inputEdges.length);

            for (ExecutionEdge edge : inputEdges) {
                Assert.assertEquals(
                        mapVertex.getID(), edge.getSource().getProducer().getJobvertexId());
                // Set.add returns false for a duplicate, i.e. a partition already consumed.
                Assert.assertTrue(consumedMapPartitions.add(edge.getSource().getPartitionNumber()));
            }
        }
        Assert.assertEquals(4, consumedMapPartitions.size());
    }
}
