/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.io;

import com.google.auto.service.AutoService;
import java.util.Arrays;
import java.util.List;
import org.apache.beam.sdk.io.Compression;
import org.apache.beam.sdk.io.TFRecordIO;
import org.apache.beam.sdk.io.TFRecordReadSchemaTransformConfiguration;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.util.Preconditions;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.checkerframework.checker.initialization.qual.Initialized;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.UnknownKeyFor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@AutoService(value={SchemaTransformProvider.class})
public class TFRecordReadSchemaTransformProvider
extends TypedSchemaTransformProvider<TFRecordReadSchemaTransformConfiguration> {
    // Identifier string returned by identifier() for this provider.
    private static final @UnknownKeyFor @NonNull @Initialized String IDENTIFIER = "beam:schematransform:org.apache.beam:tfrecord_read:v1";
    // Name under which the successfully mapped rows are placed in the returned PCollectionRowTuple.
    private static final @UnknownKeyFor @NonNull @Initialized String OUTPUT = "output";
    // Error collection name advertised by outputCollectionNames(). Note that expand() keys the
    // error collection by the configured ErrorHandling output name, not by this constant.
    private static final @UnknownKeyFor @NonNull @Initialized String ERROR = "errors";
    // Tag ErrorFn uses to emit rows that were mapped successfully.
    public static final @UnknownKeyFor @NonNull @Initialized TupleTag<@UnknownKeyFor @NonNull @Initialized Row> OUTPUT_TAG = new TupleTag<Row>(){};
    // Tag ErrorFn uses to emit error records for elements whose mapping threw.
    public static final @UnknownKeyFor @NonNull @Initialized TupleTag<@UnknownKeyFor @NonNull @Initialized Row> ERROR_TAG = new TupleTag<Row>(){};
    // Logger shared by the provider and its nested classes.
    private static final @UnknownKeyFor @NonNull @Initialized Logger LOG = LoggerFactory.getLogger(TFRecordReadSchemaTransformProvider.class);

    /**
     * Creates the schema transform for a given read configuration.
     *
     * @param configuration the TFRecord read settings the transform will hold
     * @return a new {@link TFRecordReadSchemaTransform} wrapping {@code configuration}
     */
    @Override
    protected @UnknownKeyFor @NonNull @Initialized SchemaTransform from(@UnknownKeyFor @NonNull @Initialized TFRecordReadSchemaTransformConfiguration configuration) {
        TFRecordReadSchemaTransform readTransform = new TFRecordReadSchemaTransform(configuration);
        return readTransform;
    }

    /**
     * Returns the identifier string of this provider.
     *
     * @return the value of the {@code IDENTIFIER} constant
     */
    @Override
    public @UnknownKeyFor @NonNull @Initialized String identifier() {
        return TFRecordReadSchemaTransformProvider.IDENTIFIER;
    }

    /**
     * Lists the output collection names advertised by this provider.
     *
     * @return a fixed-size list holding the {@code OUTPUT} and {@code ERROR} names, in that order
     */
    @Override
    public @UnknownKeyFor @NonNull @Initialized List<@UnknownKeyFor @NonNull @Initialized String> outputCollectionNames() {
        String[] collectionNames = {OUTPUT, ERROR};
        return Arrays.asList(collectionNames);
    }

    /**
     * Creates a function that wraps a raw byte array into a single-value {@link Row}.
     *
     * @param schema the schema every produced row is built with
     * @return a {@link SimpleFunction} that builds a row of {@code schema} whose only value is the
     *     input byte array
     */
    public static @UnknownKeyFor @NonNull @Initialized SerializableFunction<@UnknownKeyFor @NonNull @Initialized byte @UnknownKeyFor @NonNull @Initialized [], @UnknownKeyFor @NonNull @Initialized Row> getBytesToRowFn(final @UnknownKeyFor @NonNull @Initialized Schema schema) {
        SimpleFunction<byte[], Row> wrapBytesInRow = new SimpleFunction<byte[], Row>(){

            @Override
            public @UnknownKeyFor @NonNull @Initialized Row apply(@UnknownKeyFor @NonNull @Initialized byte @UnknownKeyFor @NonNull @Initialized [] input) {
                // The cast keeps the whole byte array as ONE row value in the varargs call.
                Row.Builder rowBuilder = Row.withSchema(schema).addValues((Object) input);
                return rowBuilder.build();
            }
        };
        return wrapBytesInRow;
    }

    /**
     * {@link DoFn} that maps each raw record to a {@link Row} with the supplied mapper.
     *
     * <p>Successfully mapped rows are emitted on {@code OUTPUT_TAG}. When the mapper throws and
     * error handling is enabled, an error record is emitted on {@code ERROR_TAG} and counted;
     * when error handling is disabled the exception is rethrown wrapped in a
     * {@link RuntimeException}.
     */
    public static class ErrorFn
    extends DoFn<byte[], Row> {
        private final @UnknownKeyFor @NonNull @Initialized SerializableFunction<@UnknownKeyFor @NonNull @Initialized byte @UnknownKeyFor @NonNull @Initialized [], @UnknownKeyFor @NonNull @Initialized Row> valueMapper;
        private final @UnknownKeyFor @NonNull @Initialized Counter errorCounter;
        // Failures seen since the last finish(); primitive so increments do not box/unbox.
        private @UnknownKeyFor @NonNull @Initialized long errorsInBundle = 0L;
        private final @UnknownKeyFor @NonNull @Initialized boolean handleErrors;
        private final @UnknownKeyFor @NonNull @Initialized Schema errorSchema;

        /**
         * @param name name of the error counter, registered under the enclosing provider class
         * @param valueMapper function applied to every incoming byte array
         * @param errorSchema schema used to build error records
         * @param handleErrors whether mapper failures go to the error output instead of failing
         */
        public ErrorFn(@UnknownKeyFor @NonNull @Initialized String name, @UnknownKeyFor @NonNull @Initialized SerializableFunction<@UnknownKeyFor @NonNull @Initialized byte @UnknownKeyFor @NonNull @Initialized [], @UnknownKeyFor @NonNull @Initialized Row> valueMapper, @UnknownKeyFor @NonNull @Initialized Schema errorSchema, @UnknownKeyFor @NonNull @Initialized boolean handleErrors) {
            this.errorCounter = Metrics.counter(TFRecordReadSchemaTransformProvider.class, name);
            this.valueMapper = valueMapper;
            this.handleErrors = handleErrors;
            this.errorSchema = errorSchema;
        }

        /**
         * Maps one element and routes the result to the main or the error output.
         *
         * @param msg the raw record bytes
         * @param receiver receiver used to reach the {@code OUTPUT_TAG} and {@code ERROR_TAG} outputs
         * @throws RuntimeException wrapping the mapper's exception when error handling is disabled
         */
        @DoFn.ProcessElement
        public void process(@DoFn.Element @UnknownKeyFor @NonNull @Initialized byte @UnknownKeyFor @NonNull @Initialized [] msg, @UnknownKeyFor @NonNull @Initialized MultiOutputReceiver receiver) {
            Row mappedRow;
            // Only the mapper call is guarded, so exceptions raised while emitting are not
            // mistaken for mapping failures.
            try {
                mappedRow = this.valueMapper.apply(msg);
            }
            catch (Exception e) {
                if (!this.handleErrors) {
                    throw new RuntimeException(e);
                }
                this.errorsInBundle++;
                LOG.warn("Error while parsing the element", e);
                receiver.get(ERROR_TAG).output(ErrorHandling.errorRecord(this.errorSchema, msg, e));
                return;
            }
            // A mapper that yields null produces no output for this element.
            if (mappedRow != null) {
                receiver.get(OUTPUT_TAG).output(mappedRow);
            }
        }

        /**
         * Adds the failures counted since the previous call to the error counter and resets
         * the per-bundle tally.
         *
         * @param c the finish-bundle context (unused)
         */
        @DoFn.FinishBundle
        public void finish(@UnknownKeyFor @NonNull @Initialized FinishBundleContext c) {
            this.errorCounter.inc(this.errorsInBundle);
            this.errorsInBundle = 0L;
        }
    }

    /**
     * Schema transform that reads TFRecord files as configured and exposes every record as a
     * single-field {@link Row}, optionally alongside an error collection.
     */
    static class TFRecordReadSchemaTransform
    extends SchemaTransform {
        private final @UnknownKeyFor @NonNull @Initialized TFRecordReadSchemaTransformConfiguration configuration;

        TFRecordReadSchemaTransform(@UnknownKeyFor @NonNull @Initialized TFRecordReadSchemaTransformConfiguration configuration) {
            this.configuration = configuration;
        }

        /**
         * Converts the held configuration to a {@link Row}, then applies {@code sorted()} and
         * {@code toSnakeCase()} to it.
         *
         * @return the configuration as a row
         * @throws RuntimeException if no schema is registered for the configuration class
         */
        public @UnknownKeyFor @NonNull @Initialized Row getConfigurationRow() {
            try {
                SchemaRegistry registry = SchemaRegistry.createDefault();
                SerializableFunction<TFRecordReadSchemaTransformConfiguration, Row> toRowFn = registry.getToRowFunction(TFRecordReadSchemaTransformConfiguration.class);
                Row configurationRow = toRowFn.apply(this.configuration);
                return configurationRow.sorted().toSnakeCase();
            }
            catch (NoSuchSchemaException e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Validates the configuration, reads the TFRecord files and converts each record to a row.
         *
         * @param input tuple whose pipeline the read is applied to
         * @return a tuple holding the mapped rows under {@code "output"} and, when error handling
         *     has an output configured, the error rows under the configured output name
         */
        @Override
        public @UnknownKeyFor @NonNull @Initialized PCollectionRowTuple expand(@UnknownKeyFor @NonNull @Initialized PCollectionRowTuple input) {
            this.configuration.validate();
            PCollection<byte[]> rawRecords = input.getPipeline().apply(this.createReadTransform());

            // Every record becomes a row with one BYTES field named "record".
            Schema recordSchema = Schema.of(Schema.Field.of("record", Schema.FieldType.BYTES));
            Schema errorSchema = ErrorHandling.errorSchemaBytes();
            boolean handleErrors = ErrorHandling.hasOutput(this.configuration.getErrorHandling());
            SerializableFunction<byte[], Row> toRowFn = TFRecordReadSchemaTransformProvider.getBytesToRowFn(recordSchema);

            ErrorFn mappingFn = new ErrorFn("TFRecord-read-error-counter", toRowFn, errorSchema, handleErrors);
            PCollectionTuple mapped = rawRecords.apply(ParDo.of(mappingFn).withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));

            PCollection<Row> mainRows = mapped.get(OUTPUT_TAG).setRowSchema(recordSchema);
            PCollectionRowTuple result = PCollectionRowTuple.of(TFRecordReadSchemaTransformProvider.OUTPUT, mainRows);
            // The error collection gets its schema even when it is not part of the result.
            PCollection<Row> errorRows = mapped.get(ERROR_TAG).setRowSchema(errorSchema);
            if (!handleErrors) {
                return result;
            }
            String errorOutputName = Preconditions.checkArgumentNotNull(this.configuration.getErrorHandling()).getOutput();
            return result.and(errorOutputName, errorRows);
        }

        /** Builds the TFRecord read from the configured compression, file pattern and validation flag. */
        private TFRecordIO.Read createReadTransform() {
            Compression compression = Compression.valueOf(this.configuration.getCompression());
            TFRecordIO.Read read = TFRecordIO.read().withCompression(compression);
            String filePattern = this.configuration.getFilePattern();
            if (filePattern != null) {
                read = read.from(filePattern);
            }
            if (!this.configuration.getValidate()) {
                read = read.withoutValidation();
            }
            return read;
        }
    }
}

