/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.kafka.common.security.auth;

import org.apache.kafka.common.config.SaslConfigs;
import org.apache.kafka.common.network.Authenticator;
import org.apache.kafka.common.network.TransportLayer;
import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder;
import org.apache.kafka.common.security.kerberos.KerberosShortNamer;
import org.apache.kafka.common.security.scram.internals.ScramMechanism;
import org.junit.Test;

import javax.net.ssl.SSLSession;
import javax.security.sasl.SaslServer;
import java.net.InetAddress;
import java.security.Principal;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
 * Unit tests for {@link DefaultKafkaPrincipalBuilder}.
 *
 * <p>Covers PLAINTEXT, SSL and SASL (SCRAM and GSSAPI) authentication contexts, and delegation to
 * a deprecated {@link PrincipalBuilder}. Every builder is opened in a try-with-resources statement
 * so it is closed even when an assertion fails. Interactions that depend on {@code close()} having
 * run are verified after the try block.
 */
public class DefaultKafkaPrincipalBuilderTest {

    /**
     * When a deprecated {@link PrincipalBuilder} is supplied, a PLAINTEXT context is resolved
     * through it (with the transport layer and authenticator given at construction time). Closing
     * the wrapping builder also closes the delegate.
     */
    @Test
    @SuppressWarnings("deprecation")
    public void testUseOldPrincipalBuilderForPlaintextIfProvided() throws Exception {
        TransportLayer transportLayer = mock(TransportLayer.class);
        Authenticator authenticator = mock(Authenticator.class);
        PrincipalBuilder oldPrincipalBuilder = mock(PrincipalBuilder.class);

        when(oldPrincipalBuilder.buildPrincipal(any(), any())).thenReturn(new DummyPrincipal("foo"));

        try (DefaultKafkaPrincipalBuilder builder = DefaultKafkaPrincipalBuilder.fromOldPrincipalBuilder(
                authenticator, transportLayer, oldPrincipalBuilder, null)) {
            KafkaPrincipal principal = builder.build(new PlaintextAuthenticationContext(
                    InetAddress.getLocalHost(), SecurityProtocol.PLAINTEXT.name()));
            assertEquals(KafkaPrincipal.USER_TYPE, principal.getPrincipalType());
            assertEquals("foo", principal.getName());
        }

        // Verified after the try block so that builder.close() has already run.
        verify(oldPrincipalBuilder).buildPrincipal(transportLayer, authenticator);
        verify(oldPrincipalBuilder).close();
    }

    /**
     * Without an old principal builder, a PLAINTEXT context yields {@link KafkaPrincipal#ANONYMOUS}.
     */
    @Test
    public void testReturnAnonymousPrincipalForPlaintext() throws Exception {
        try (DefaultKafkaPrincipalBuilder builder = new DefaultKafkaPrincipalBuilder(null)) {
            assertEquals(KafkaPrincipal.ANONYMOUS, builder.build(
                    new PlaintextAuthenticationContext(InetAddress.getLocalHost(), SecurityProtocol.PLAINTEXT.name())));
        }
    }

    /**
     * When a deprecated {@link PrincipalBuilder} is supplied, an SSL context is also resolved
     * through it rather than through the SSL session. Closing the wrapper closes the delegate.
     */
    @Test
    @SuppressWarnings("deprecation")
    public void testUseOldPrincipalBuilderForSslIfProvided() throws Exception {
        TransportLayer transportLayer = mock(TransportLayer.class);
        Authenticator authenticator = mock(Authenticator.class);
        PrincipalBuilder oldPrincipalBuilder = mock(PrincipalBuilder.class);
        SSLSession session = mock(SSLSession.class);

        when(oldPrincipalBuilder.buildPrincipal(any(), any()))
                .thenReturn(new DummyPrincipal("foo"));

        try (DefaultKafkaPrincipalBuilder builder = DefaultKafkaPrincipalBuilder.fromOldPrincipalBuilder(
                authenticator, transportLayer, oldPrincipalBuilder, null)) {
            KafkaPrincipal principal = builder.build(
                    new SslAuthenticationContext(session, InetAddress.getLocalHost(), SecurityProtocol.SSL.name()));
            assertEquals(KafkaPrincipal.USER_TYPE, principal.getPrincipalType());
            assertEquals("foo", principal.getName());
        }

        // Verified after the try block so that builder.close() has already run.
        verify(oldPrincipalBuilder).buildPrincipal(transportLayer, authenticator);
        verify(oldPrincipalBuilder).close();
    }

    /**
     * Without an old principal builder, an SSL context's principal name comes from
     * {@link SSLSession#getPeerPrincipal()} and is typed as a user.
     */
    @Test
    public void testUseSessionPeerPrincipalForSsl() throws Exception {
        SSLSession session = mock(SSLSession.class);

        when(session.getPeerPrincipal()).thenReturn(new DummyPrincipal("foo"));

        try (DefaultKafkaPrincipalBuilder builder = new DefaultKafkaPrincipalBuilder(null)) {
            KafkaPrincipal principal = builder.build(
                    new SslAuthenticationContext(session, InetAddress.getLocalHost(), SecurityProtocol.SSL.name()));
            assertEquals(KafkaPrincipal.USER_TYPE, principal.getPrincipalType());
            assertEquals("foo", principal.getName());
        }

        verify(session, atLeastOnce()).getPeerPrincipal();
    }

    /**
     * For a SCRAM SASL exchange, the principal name equals the SASL server's authorization ID
     * unchanged.
     */
    @Test
    public void testPrincipalBuilderScram() throws Exception {
        SaslServer server = mock(SaslServer.class);

        when(server.getMechanismName()).thenReturn(ScramMechanism.SCRAM_SHA_256.mechanismName());
        when(server.getAuthorizationID()).thenReturn("foo");

        try (DefaultKafkaPrincipalBuilder builder = new DefaultKafkaPrincipalBuilder(null)) {
            KafkaPrincipal principal = builder.build(new SaslAuthenticationContext(server,
                    SecurityProtocol.SASL_PLAINTEXT, InetAddress.getLocalHost(), SecurityProtocol.SASL_PLAINTEXT.name()));
            assertEquals(KafkaPrincipal.USER_TYPE, principal.getPrincipalType());
            assertEquals("foo", principal.getName());
        }

        verify(server, atLeastOnce()).getMechanismName();
        verify(server, atLeastOnce()).getAuthorizationID();
    }

    /**
     * For a GSSAPI SASL exchange, the authorization ID ({@code foo/host@REALM.COM} here) is mapped
     * through the configured {@link KerberosShortNamer}. The short name becomes the principal name.
     */
    @Test
    public void testPrincipalBuilderGssapi() throws Exception {
        SaslServer server = mock(SaslServer.class);
        KerberosShortNamer kerberosShortNamer = mock(KerberosShortNamer.class);

        when(server.getMechanismName()).thenReturn(SaslConfigs.GSSAPI_MECHANISM);
        when(server.getAuthorizationID()).thenReturn("foo/host@REALM.COM");
        when(kerberosShortNamer.shortName(any())).thenReturn("foo");

        try (DefaultKafkaPrincipalBuilder builder = new DefaultKafkaPrincipalBuilder(kerberosShortNamer)) {
            KafkaPrincipal principal = builder.build(new SaslAuthenticationContext(server,
                    SecurityProtocol.SASL_PLAINTEXT, InetAddress.getLocalHost(), SecurityProtocol.SASL_PLAINTEXT.name()));
            assertEquals(KafkaPrincipal.USER_TYPE, principal.getPrincipalType());
            assertEquals("foo", principal.getName());
        }

        verify(server, atLeastOnce()).getMechanismName();
        verify(server, atLeastOnce()).getAuthorizationID();
        verify(kerberosShortNamer, atLeastOnce()).shortName(any());
    }

    /** Minimal immutable {@link Principal} that only carries a name, used as a stubbed return value. */
    private static final class DummyPrincipal implements Principal {
        private final String name;

        private DummyPrincipal(String name) {
            this.name = name;
        }

        @Override
        public String getName() {
            return name;
        }
    }

}
