/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.presto.sql.planner;

import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.planner.DependencyExtractor;
import com.facebook.presto.sql.planner.EqualityInference;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.tree.ArithmeticBinaryExpression;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.QualifiedNameReference;
import com.facebook.presto.util.ImmutableCollectors;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import java.util.Arrays;
import java.util.Collection;
import java.util.Set;
import org.testng.Assert;
import org.testng.annotations.Test;

/**
 * Tests for {@link EqualityInference}: transitive merging of equalities, rewriting of expressions
 * into a symbol scope, extraction of inference candidates from a conjunction, and partitioning of
 * the inferred equalities into in-scope, complement-scope and straddling sets.
 */
public class TestEqualityInference {
    @Test
    public void testTransitivity() throws Exception {
        EqualityInference.Builder builder = new EqualityInference.Builder();
        // Two independent equivalence classes, each fed in a different order (plus a redundant
        // reversed pair for a2/b2) so the builder has to merge them transitively.
        addEquality("a1", "b1", builder);
        addEquality("b1", "c1", builder);
        addEquality("d1", "c1", builder);

        addEquality("a2", "b2", builder);
        addEquality("b2", "a2", builder);
        addEquality("b2", "c2", builder);
        addEquality("d2", "b2", builder);
        addEquality("c2", "d2", builder);

        EqualityInference inference = builder.build();

        Assert.assertEquals(
                inference.rewriteExpression(someExpression("a1", "a2"), matchesSymbols("d1", "d2")),
                someExpression("d1", "d2"));

        Assert.assertEquals(
                inference.rewriteExpression(someExpression("a1", "c1"), matchesSymbols("b1")),
                someExpression("b1", "b1"));

        // "c3" is in scope but belongs to no equivalence class, so it must not be used.
        Assert.assertEquals(
                inference.rewriteExpression(someExpression("a1", "a2"), matchesSymbols("b1", "d2", "c3")),
                someExpression("b1", "d2"));

        // Members of the same class share one scoped canonical, whichever member is asked for.
        Expression canonical = inference.getScopedCanonical(nameReference("a2"), matchesSymbols("c2", "d2"));
        Assert.assertEquals(
                canonical,
                inference.getScopedCanonical(nameReference("b2"), matchesSymbols("c2", "d2")));

        Assert.assertEquals(
                inference.rewriteExpression(someExpression("a2", "b2"), matchesSymbols("c2", "d2")),
                someExpression(canonical, canonical));
    }

    @Test
    public void testTriviallyRewritable() throws Exception {
        // An expression already entirely within scope is returned unchanged, even with no equalities.
        EqualityInference.Builder builder = new EqualityInference.Builder();
        Expression expression = builder.build().rewriteExpression(someExpression("a1", "a2"), matchesSymbols("a1", "a2"));

        Assert.assertEquals(expression, someExpression("a1", "a2"));
    }

    @Test
    public void testUnrewritable() throws Exception {
        EqualityInference.Builder builder = new EqualityInference.Builder();
        addEquality("a1", "b1", builder);
        addEquality("a2", "b2", builder);
        EqualityInference inference = builder.build();

        // a2 has no equivalent in {b1, c1}, and c1/c2 have no equivalents at all.
        Assert.assertNull(inference.rewriteExpression(someExpression("a1", "a2"), matchesSymbols("b1", "c1")));
        Assert.assertNull(inference.rewriteExpression(someExpression("c1", "c2"), matchesSymbols("a1", "a2")));
    }

    @Test
    public void testParseEqualityExpression() throws Exception {
        EqualityInference inference = new EqualityInference.Builder()
                .addEquality(equals("a1", "b1"))
                .addEquality(equals("a1", "c1"))
                .addEquality(equals("c1", "a1"))
                .build();

        Expression expression = inference.rewriteExpression(someExpression("a1", "b1"), matchesSymbols("c1"));
        Assert.assertEquals(expression, someExpression("c1", "c1"));
    }

    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testInvalidEqualityExpression1() throws Exception {
        // Self-equality supplied as a single EQUAL comparison is rejected.
        new EqualityInference.Builder().addEquality(equals("a1", "a1"));
    }

    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testInvalidEqualityExpression2() throws Exception {
        // A non-equality comparison is rejected.
        new EqualityInference.Builder().addEquality(someExpression("a1", "b1"));
    }

    @Test(expectedExceptions = IllegalArgumentException.class)
    public void testInvalidEqualityExpression3() throws Exception {
        // Self-equality supplied as two separate operands is rejected.
        EqualityInference.Builder builder = new EqualityInference.Builder();
        addEquality("a1", "a1", builder);
    }

    @Test
    public void testExtractInferrableEqualities() throws Exception {
        EqualityInference inference = new EqualityInference.Builder()
                .extractInferenceCandidates(ExpressionUtils.and(equals("a1", "b1"), equals("b1", "c1"), someExpression("c1", "d1")))
                .build();

        // TestNG's assertEquals takes (actual, expected).
        Assert.assertEquals(inference.rewriteExpression(nameReference("a1"), matchesSymbols("c1")), nameReference("c1"));

        // The non-equality conjunct "c1 > d1" must not have been picked up as an equality.
        Assert.assertNull(inference.rewriteExpression(nameReference("a1"), matchesSymbols("d1")));
    }

    @Test
    public void testEqualityPartitionGeneration() throws Exception {
        EqualityInference.Builder builder = new EqualityInference.Builder();
        builder.addEquality(nameReference("a1"), nameReference("b1"));
        builder.addEquality(add("a1", "a1"), multiply(nameReference("a1"), number(2L)));
        builder.addEquality(nameReference("b1"), nameReference("c1"));
        builder.addEquality(add("a1", "a1"), nameReference("c1"));
        builder.addEquality(add("a1", "b1"), nameReference("c1"));
        EqualityInference inference = builder.build();

        // With an empty scope every equality lands in the scope-complement bucket.
        EqualityInference.EqualityPartition emptyScopePartition = inference.generateEqualitiesPartitionedBy(Predicates.alwaysFalse());
        Assert.assertTrue(emptyScopePartition.getScopeEqualities().isEmpty());
        Assert.assertFalse(emptyScopePartition.getScopeComplementEqualities().isEmpty());
        Assert.assertTrue(emptyScopePartition.getScopeStraddlingEqualities().isEmpty());

        assertNonTrivialPartitionRoundTrips(inference, matchesSymbols("c1"));
    }

    @Test
    public void testMultipleEqualitySetsPredicateGeneration() throws Exception {
        EqualityInference.Builder builder = new EqualityInference.Builder();
        addEquality("a1", "b1", builder);
        addEquality("b1", "c1", builder);
        addEquality("c1", "d1", builder);

        addEquality("a2", "b2", builder);
        addEquality("b2", "c2", builder);
        addEquality("c2", "d2", builder);

        EqualityInference inference = builder.build();

        assertNonTrivialPartitionRoundTrips(inference, symbolBeginsWith("a", "b"));
    }

    @Test
    public void testSubExpressionRewrites() throws Exception {
        EqualityInference.Builder builder = new EqualityInference.Builder();
        builder.addEquality(nameReference("a1"), add("b", "c")); // a1 = b + c
        builder.addEquality(nameReference("a2"), multiply(nameReference("b"), add("b", "c"))); // a2 = b * (b + c)
        builder.addEquality(nameReference("a3"), multiply(nameReference("a1"), add("b", "c"))); // a3 = a1 * (b + c)
        EqualityInference inference = builder.build();

        // Expression-level rewrite
        Assert.assertEquals(inference.rewriteExpression(add("b", "c"), symbolBeginsWith("a")), nameReference("a1"));

        // Sub-expression rewrite
        Assert.assertEquals(
                inference.rewriteExpression(multiply(nameReference("ax"), add("b", "c")), symbolBeginsWith("a")),
                multiply(nameReference("ax"), nameReference("a1")));

        // Whole-expression rewrite takes precedence over rewriting its sub-expressions
        Assert.assertEquals(
                inference.rewriteExpression(multiply(nameReference("a1"), add("b", "c")), symbolBeginsWith("a")),
                nameReference("a3"));
    }

    @Test
    public void testConstantEqualities() throws Exception {
        EqualityInference.Builder builder = new EqualityInference.Builder();
        addEquality("a1", "b1", builder);
        addEquality("b1", "c1", builder);
        builder.addEquality(nameReference("c1"), number(1L));
        EqualityInference inference = builder.build();

        // The constant is chosen over the in-scope symbol a1.
        Assert.assertEquals(inference.rewriteExpression(nameReference("a1"), matchesSymbols("a1", "b1")), number(1L));

        EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesSymbols("a1", "b1"));

        Assert.assertEquals(
                equalitiesAsSets(equalityPartition.getScopeEqualities()),
                set(set(nameReference("a1"), number(1L)), set(nameReference("b1"), number(1L))));
        Assert.assertEquals(
                equalitiesAsSets(equalityPartition.getScopeComplementEqualities()),
                set(set(nameReference("c1"), number(1L))));

        // No straddling equalities are needed because the constant bridges the two sides.
        Assert.assertTrue(equalityPartition.getScopeStraddlingEqualities().isEmpty());
    }

    /**
     * Partitions the equalities of {@code inference} by {@code symbolScope} and checks the result.
     *
     * <p>Every bucket must be non-empty and consist only of inference candidates. In-scope
     * equalities must reference only in-scope symbols, and complement equalities only out-of-scope
     * symbols. At least one straddling equality must reference symbols on both sides. Finally, an
     * inference rebuilt from all three buckets must reproduce the same partition.
     */
    private static void assertNonTrivialPartitionRoundTrips(EqualityInference inference, Predicate<Symbol> symbolScope) {
        EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(symbolScope);

        Assert.assertFalse(equalityPartition.getScopeEqualities().isEmpty());
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), matchesSymbolScope(symbolScope)));
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), EqualityInference.isInferenceCandidate()));

        Assert.assertFalse(equalityPartition.getScopeComplementEqualities().isEmpty());
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), matchesSymbolScope(Predicates.not(symbolScope))));
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), EqualityInference.isInferenceCandidate()));

        Assert.assertFalse(equalityPartition.getScopeStraddlingEqualities().isEmpty());
        Assert.assertTrue(Iterables.any(equalityPartition.getScopeStraddlingEqualities(), matchesStraddlingScope(symbolScope)));
        Assert.assertTrue(Iterables.all(equalityPartition.getScopeStraddlingEqualities(), EqualityInference.isInferenceCandidate()));

        // Feeding the generated equalities back in must yield an equivalent partition.
        EqualityInference newInference = new EqualityInference.Builder()
                .addAllEqualities(equalityPartition.getScopeEqualities())
                .addAllEqualities(equalityPartition.getScopeComplementEqualities())
                .addAllEqualities(equalityPartition.getScopeStraddlingEqualities())
                .build();

        EqualityInference.EqualityPartition newEqualityPartition = newInference.generateEqualitiesPartitionedBy(symbolScope);

        Assert.assertEquals(setCopy(equalityPartition.getScopeEqualities()), setCopy(newEqualityPartition.getScopeEqualities()));
        Assert.assertEquals(setCopy(equalityPartition.getScopeComplementEqualities()), setCopy(newEqualityPartition.getScopeComplementEqualities()));
        Assert.assertEquals(setCopy(equalityPartition.getScopeStraddlingEqualities()), setCopy(newEqualityPartition.getScopeStraddlingEqualities()));
    }

    /** Matches expressions whose referenced symbols all satisfy {@code symbolScope}. */
    private static Predicate<Expression> matchesSymbolScope(Predicate<Symbol> symbolScope) {
        return expression -> Iterables.all(DependencyExtractor.extractUnique(expression), symbolScope);
    }

    /** Matches expressions that reference at least one symbol inside and one outside {@code symbolScope}. */
    private static Predicate<Expression> matchesStraddlingScope(Predicate<Symbol> symbolScope) {
        return expression -> {
            Set<Symbol> symbols = DependencyExtractor.extractUnique(expression);
            return Iterables.any(symbols, symbolScope) && Iterables.any(symbols, Predicates.not(symbolScope));
        };
    }

    /** Registers {@code symbol1 = symbol2} (as two name references) with {@code builder}. */
    private static void addEquality(String symbol1, String symbol2, EqualityInference.Builder builder) {
        builder.addEquality(nameReference(symbol1), nameReference(symbol2));
    }

    /** Builds {@code symbol1 > symbol2}: a comparison that is not an equality. */
    private static Expression someExpression(String symbol1, String symbol2) {
        return someExpression(nameReference(symbol1), nameReference(symbol2));
    }

    /** Builds {@code expression1 > expression2}: a comparison that is not an equality. */
    private static Expression someExpression(Expression expression1, Expression expression2) {
        return new ComparisonExpression(ComparisonExpression.Type.GREATER_THAN, expression1, expression2);
    }

    /** Builds {@code symbol1 + symbol2}. */
    private static Expression add(String symbol1, String symbol2) {
        return add(nameReference(symbol1), nameReference(symbol2));
    }

    /** Builds {@code expression1 + expression2}. */
    private static Expression add(Expression expression1, Expression expression2) {
        return new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Type.ADD, expression1, expression2);
    }

    /** Builds {@code symbol1 * symbol2}. */
    private static Expression multiply(String symbol1, String symbol2) {
        return multiply(nameReference(symbol1), nameReference(symbol2));
    }

    /** Builds {@code expression1 * expression2}. */
    private static Expression multiply(Expression expression1, Expression expression2) {
        return new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Type.MULTIPLY, expression1, expression2);
    }

    /** Builds the equality comparison {@code symbol1 = symbol2}. */
    private static Expression equals(String symbol1, String symbol2) {
        return equals(nameReference(symbol1), nameReference(symbol2));
    }

    /** Builds the equality comparison {@code expression1 = expression2}. */
    private static Expression equals(Expression expression1, Expression expression2) {
        return new ComparisonExpression(ComparisonExpression.Type.EQUAL, expression1, expression2);
    }

    /** Returns a reference to the symbol named {@code symbol}. */
    private static QualifiedNameReference nameReference(String symbol) {
        return new QualifiedNameReference(new Symbol(symbol).toQualifiedName());
    }

    /** Returns a {@link LongLiteral} for {@code number}. */
    private static LongLiteral number(long number) {
        return new LongLiteral(String.valueOf(number));
    }

    /** Matches exactly the symbols with the given names. */
    private static Predicate<Symbol> matchesSymbols(String... symbols) {
        return matchesSymbols(Arrays.asList(symbols));
    }

    /** Matches exactly the symbols with the given names. */
    private static Predicate<Symbol> matchesSymbols(Collection<String> symbols) {
        Set<Symbol> symbolSet = symbols.stream()
                .map(Symbol::new)
                .collect(ImmutableCollectors.toImmutableSet());
        return Predicates.in(symbolSet);
    }

    /** Matches symbols whose name starts with any of {@code prefixes}. */
    private static Predicate<Symbol> symbolBeginsWith(String... prefixes) {
        return symbolBeginsWith(Arrays.asList(prefixes));
    }

    /** Matches symbols whose name starts with any of {@code prefixes}. */
    private static Predicate<Symbol> symbolBeginsWith(Iterable<String> prefixes) {
        return symbol -> {
            for (String prefix : prefixes) {
                if (symbol.getName().startsWith(prefix)) {
                    return true;
                }
            }
            return false;
        };
    }

    /**
     * Converts each equality into an unordered {@code {left, right}} set, so that comparisons
     * ignore operand order.
     *
     * @throws IllegalArgumentException if any expression is not an EQUAL comparison
     */
    private static Set<Set<Expression>> equalitiesAsSets(Iterable<Expression> expressions) {
        ImmutableSet.Builder<Set<Expression>> builder = ImmutableSet.builder();
        for (Expression expression : expressions) {
            builder.add(equalityAsSet(expression));
        }
        return builder.build();
    }

    /**
     * Returns the operands of an EQUAL comparison as an unordered set.
     *
     * @throws IllegalArgumentException if {@code expression} is not an EQUAL comparison
     */
    private static Set<Expression> equalityAsSet(Expression expression) {
        Preconditions.checkArgument(expression instanceof ComparisonExpression, "not a comparison: %s", expression);
        ComparisonExpression comparisonExpression = (ComparisonExpression) expression;
        Preconditions.checkArgument(comparisonExpression.getType() == ComparisonExpression.Type.EQUAL, "not an equality: %s", expression);
        return ImmutableSet.of(comparisonExpression.getLeft(), comparisonExpression.getRight());
    }

    /** Returns an immutable set of {@code elements}. */
    @SafeVarargs
    private static <E> Set<E> set(E... elements) {
        return setCopy(Arrays.asList(elements));
    }

    /** Returns an immutable copy of {@code elements}. */
    private static <E> Set<E> setCopy(Iterable<E> elements) {
        return ImmutableSet.copyOf(elements);
    }
}

