Detecting Fraud Transactions with Kafka Streams

I recently built a small side project to explore a transaction dataset I found on Kaggle. In this post I'll walk you through how I implemented a system that monitors transactions for fraud. All the source code and the example dataset can be found here: https://github.com/mamin11/fraud-detection

Creating a new app

Head over to Spring Initializr and create a new app, or clone the example project from https://github.com/mamin11/fraud-detection. The required dependencies are listed in the pom.xml file, and you'll find a sample dataset in the /data directory.

Base Data Structure

First, let's define all the monitoring types we want to support in an enum. If we need to support more conditions in the future, we can add a new criteria type here. The structure of the other classes, such as Transaction, can be found in the repository.

public enum MonitoringCriteriaType {
    HIGH_TRANSACTION_AMOUNT, SMALL_INTERVAL, SPECIFIC_SECTOR
}

Before we create a class for each of these criteria, note that the goal is for every criterion to have its own custom logic. To support that, we'll create an abstract class that each criterion extends with its own implementation. This is what the abstract class looks like:

import java.util.HashSet;
import java.util.Set;

import lombok.extern.slf4j.Slf4j;

@Slf4j // Lombok generates the `log` field used in isHit
public abstract class MonitoringCriteriaItem<T extends MonitoringCriteriaValue> {
    // Each criterion decides whether the transaction breaks its rule
    public abstract boolean violatesRules(Transaction transaction);

    // Evaluates this criterion and records the outcome on the transaction
    public void isHit(Transaction transaction) {
        boolean hitStatus = this.violatesRules(transaction);
        HitStatus status = hitStatus ? HitStatus.HIT : HitStatus.NO_HIT;
        log.info("MonitoringCriteriaItem: {} hitStatus: {} transaction {} ",
                this.getMonitoringCriteriaType(), status, transaction);

        // Merge with the results recorded by criteria that already ran
        Set<MonitoringCriteriaType> monitoringCriteriaType = new HashSet<>(
                transaction.getMonitoringCriteriaType() == null ? Set.of() : transaction.getMonitoringCriteriaType());
        Set<HitStatus> hitStatuses = new HashSet<>(
                transaction.getHitStatus() == null ? Set.of() : transaction.getHitStatus());
        hitStatuses.add(status);
        monitoringCriteriaType.add(this.getMonitoringCriteriaType());
        transaction.setHitStatus(hitStatuses.stream().toList());
        transaction.setMonitoringCriteriaType(monitoringCriteriaType.stream().toList());

        // Only change the action on a hit, so a later criterion that doesn't match
        // cannot reset an earlier BLOCK back to ALLOW
        if (hitStatus) {
            transaction.setMonitorAction(this.getMonitoringAction());
        } else if (transaction.getMonitorAction() == null) {
            transaction.setMonitorAction(MonitorAction.ALLOW);
        }
    }

    public abstract MonitoringCriteriaValue criteriaValue();
    public abstract MonitoringCriteriaType getMonitoringCriteriaType();
    protected abstract MonitorAction getMonitoringAction();
}

We have defined four abstract methods that every subclass must implement. The isHit method then provides the shared base logic for monitoring a transaction.

Each subclass has its own logic for deciding whether its rule has been violated, which lives in the violatesRules method. The getMonitoringAction method tells us what to do if a transaction is flagged as fraud, i.e. whether to block it, hold it, send it for review, and so on.

The isHit method uses these to update the transaction with its hit status, the monitoring action to take, and the list of criteria types the transaction has been checked against.
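To see how isHit accumulates results when several criteria run against the same transaction, here is a self-contained sketch that uses trimmed-down stand-ins rather than the repository's classes. The Txn class, its amount field, the HighAmount rule and its 1,000 threshold are all hypothetical. One deliberate choice in this sketch: the action is only overwritten on a hit, so a later criterion that doesn't match cannot reset an earlier BLOCK to ALLOW.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

enum HitStatus { HIT, NO_HIT }

enum MonitorAction { ALLOW, BLOCK }

enum CriteriaType { HIGH_TRANSACTION_AMOUNT, SPECIFIC_SECTOR }

// Hypothetical stand-in for Transaction: only the fields this sketch needs.
class Txn {
    final double amount;
    final String sector;
    List<HitStatus> hitStatus;
    List<CriteriaType> criteriaTypes;
    MonitorAction action;

    Txn(double amount, String sector) {
        this.amount = amount;
        this.sector = sector;
    }
}

// Simplified version of MonitoringCriteriaItem.
abstract class Criterion {
    abstract boolean violatesRules(Txn t);
    abstract CriteriaType type();
    abstract MonitorAction actionOnHit();

    void isHit(Txn t) {
        boolean hit = violatesRules(t);

        // Merge this criterion's result with whatever earlier criteria recorded.
        Set<HitStatus> statuses = new LinkedHashSet<>();
        if (t.hitStatus != null) statuses.addAll(t.hitStatus);
        statuses.add(hit ? HitStatus.HIT : HitStatus.NO_HIT);

        Set<CriteriaType> types = new LinkedHashSet<>();
        if (t.criteriaTypes != null) types.addAll(t.criteriaTypes);
        types.add(type());

        t.hitStatus = new ArrayList<>(statuses);
        t.criteriaTypes = new ArrayList<>(types);

        // Only a hit changes the action; otherwise default to ALLOW once.
        if (hit) {
            t.action = actionOnHit();
        } else if (t.action == null) {
            t.action = MonitorAction.ALLOW;
        }
    }
}

class HighAmount extends Criterion {
    @Override boolean violatesRules(Txn t) { return t.amount > 1_000.0; } // hypothetical threshold
    @Override CriteriaType type() { return CriteriaType.HIGH_TRANSACTION_AMOUNT; }
    @Override MonitorAction actionOnHit() { return MonitorAction.BLOCK; }
}

class Sector extends Criterion {
    @Override boolean violatesRules(Txn t) { return "Other".equals(t.sector); }
    @Override CriteriaType type() { return CriteriaType.SPECIFIC_SECTOR; }
    @Override MonitorAction actionOnHit() { return MonitorAction.BLOCK; }
}

class IsHitDemo {
    // Runs every criterion against one transaction, like the transformer's loop.
    static Txn check(double amount, String sector) {
        Txn t = new Txn(amount, sector);
        List<Criterion> criteria = List.of(new HighAmount(), new Sector());
        for (Criterion c : criteria) {
            c.isHit(t);
        }
        return t;
    }

    public static void main(String[] args) {
        Txn t = check(5_000.0, "Groceries");
        // prints [HIT, NO_HIT] [HIGH_TRANSACTION_AMOUNT, SPECIFIC_SECTOR] BLOCK
        System.out.println(t.hitStatus + " " + t.criteriaTypes + " " + t.action);
    }
}
```

Note that the merged hit-status list only says that at least one criterion hit; the pairing between a specific criterion and its result is not kept.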

Creating Criteria

Now we are ready to create our first criterion. Say we always want to monitor transactions that belong to a specific category. It could be any category, but in this example let's use "Other".

We will first create a SpecificSectorCriteria class that extends MonitoringCriteriaItem and provides the required implementations.

public class SpecificSectorCriteria extends MonitoringCriteriaItem<SpecificSectorCriteriaValue> {
    @Override
    public boolean violatesRules(Transaction transaction) {
        // Call equals on the configured value so a transaction with no sector doesn't throw
        return this.criteriaValue().getValue().equals(transaction.getTransactionSector());
    }

    @Override
    public SpecificSectorCriteriaValue criteriaValue() {
        return new SpecificSectorCriteriaValue();
    }

    @Override
    public MonitoringCriteriaType getMonitoringCriteriaType() {
        return MonitoringCriteriaType.SPECIFIC_SECTOR;
    }

    @Override
    protected MonitorAction getMonitoringAction() {
        // A match on this criterion blocks the transaction
        return MonitorAction.BLOCK;
    }
}

In getMonitoringAction, we specify that a transaction matching this criterion should be blocked.

Criteria Value

Next is the criteria value. Below, we set the value to "Other" so that transactions in this category are blocked.

public class SpecificSectorCriteriaValue extends MonitoringCriteriaValue {
    public SpecificSectorCriteriaValue() {
        // Replace "Other" with whichever sector you want to monitor;
        // SpecificSectorCriteria compares this value against each transaction's sector
        super("Other", MonitoringCriteriaType.SPECIFIC_SECTOR);
    }
}
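The MonitoringCriteriaValue base class isn't shown in this article. Judging only from the super(value, type) call and the getValue() accessor used above, it might look roughly like the sketch below. The field types, the getType() accessor, the CriteriaValueDemo helper and the ConfigurableSectorValue class are assumptions for illustration. Passing the sector through a constructor, as the hypothetical ConfigurableSectorValue does, is one way to act on the "replace with any sector" comment without editing the class.

```java
enum MonitoringCriteriaType { HIGH_TRANSACTION_AMOUNT, SMALL_INTERVAL, SPECIFIC_SECTOR }

// Hypothetical sketch of the base class: field types and getType() are assumptions.
abstract class MonitoringCriteriaValue {
    private final Object value;                 // e.g. a sector name or an amount threshold
    private final MonitoringCriteriaType type;  // which criterion this value belongs to

    protected MonitoringCriteriaValue(Object value, MonitoringCriteriaType type) {
        this.value = value;
        this.type = type;
    }

    public Object getValue() { return value; }

    public MonitoringCriteriaType getType() { return type; }
}

// Mirrors the article's class: the monitored sector is fixed to "Other".
class SpecificSectorCriteriaValue extends MonitoringCriteriaValue {
    SpecificSectorCriteriaValue() {
        super("Other", MonitoringCriteriaType.SPECIFIC_SECTOR);
    }
}

// Hypothetical variant: the sector is supplied by the caller, e.g. from configuration.
class ConfigurableSectorValue extends MonitoringCriteriaValue {
    ConfigurableSectorValue(String sector) {
        super(sector, MonitoringCriteriaType.SPECIFIC_SECTOR);
    }
}

class CriteriaValueDemo {
    // Null-safe comparison a sector rule could use: equals is called on the configured value.
    static boolean matches(MonitoringCriteriaValue criteriaValue, String transactionSector) {
        return criteriaValue.getValue().equals(transactionSector);
    }

    public static void main(String[] args) {
        System.out.println(matches(new SpecificSectorCriteriaValue(), "Other"));     // true
        System.out.println(matches(new ConfigurableSectorValue("Travel"), "Other")); // false
        System.out.println(matches(new SpecificSectorCriteriaValue(), null));        // false
    }
}
```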


Criteria Factory

We then create a factory class that instantiates the right MonitoringCriteriaItem for a given type.

public class MonitoringCriteriaFactory {
    private final MonitoringCriteriaType type;

    public MonitoringCriteriaFactory(MonitoringCriteriaType type) {
        this.type = type;
    }

    public MonitoringCriteriaItem<?> getMonitoringCriteriaItem() {
        return switch (this.type) {
            case HIGH_TRANSACTION_AMOUNT -> new HighAmountCriteria();
            case SMALL_INTERVAL -> new SmallIntervalCriteria();
            case SPECIFIC_SECTOR -> new SpecificSectorCriteria();
        };
    }
}
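HighAmountCriteria and SmallIntervalCriteria are referenced by the factory but not shown in this article. The sketch below illustrates how a high-amount rule could plug into the same switch-based factory, using trimmed-down stand-ins so it compiles on its own. The Txn record, the 1,000 threshold, the static create method and the SMALL_INTERVAL placeholder are hypothetical. Because the switch expression has no default branch, the compiler rejects it if a new MonitoringCriteriaType constant is added without a matching case, which keeps the enum and the factory in sync.

```java
enum MonitoringCriteriaType { HIGH_TRANSACTION_AMOUNT, SMALL_INTERVAL, SPECIFIC_SECTOR }

enum MonitorAction { ALLOW, BLOCK }

// Hypothetical stand-in for Transaction.
record Txn(double amount, String sector) { }

// Trimmed-down MonitoringCriteriaItem: just the abstract contract.
abstract class CriteriaItem {
    abstract boolean violatesRules(Txn transaction);
    abstract MonitoringCriteriaType getMonitoringCriteriaType();
    abstract MonitorAction getMonitoringAction();
}

// Hypothetical high-amount rule; the threshold is an assumption.
class HighAmountCriteria extends CriteriaItem {
    static final double THRESHOLD = 1_000.0;

    @Override boolean violatesRules(Txn t) { return t.amount() > THRESHOLD; }
    @Override MonitoringCriteriaType getMonitoringCriteriaType() { return MonitoringCriteriaType.HIGH_TRANSACTION_AMOUNT; }
    @Override MonitorAction getMonitoringAction() { return MonitorAction.BLOCK; }
}

class SpecificSectorCriteria extends CriteriaItem {
    @Override boolean violatesRules(Txn t) { return "Other".equals(t.sector()); }
    @Override MonitoringCriteriaType getMonitoringCriteriaType() { return MonitoringCriteriaType.SPECIFIC_SECTOR; }
    @Override MonitorAction getMonitoringAction() { return MonitorAction.BLOCK; }
}

class CriteriaFactory {
    // Exhaustive switch: adding an enum constant without a case fails to compile.
    static CriteriaItem create(MonitoringCriteriaType type) {
        return switch (type) {
            case HIGH_TRANSACTION_AMOUNT -> new HighAmountCriteria();
            case SPECIFIC_SECTOR -> new SpecificSectorCriteria();
            // Hypothetical placeholder: a small-interval rule is not sketched here.
            case SMALL_INTERVAL -> throw new UnsupportedOperationException("not sketched here");
        };
    }

    public static void main(String[] args) {
        CriteriaItem item = create(MonitoringCriteriaType.HIGH_TRANSACTION_AMOUNT);
        System.out.println(item.violatesRules(new Txn(2_500.0, "Groceries"))); // true
    }
}
```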

Putting it all together

Let's create a KStream that consumes from our input topic:

KStream<TransactionKey, Transaction> firstStream = builder
        .stream(inputTopic, Consumed.with(CustomSerdes.TransactionKey(), CustomSerdes.Transaction()));

Next, we pass the stream through a transformer that applies the rules, and finally write the results to another topic. Only transactions that have been hit are written to the output topic.

firstStream
        .peek((key, value) -> log.info("Incoming record - key: {} value: {} ", key, value))
        .transformValues(fraudDetectionTransformer())
        .filter((key, value) -> value.getHitStatus() != null && value.getHitStatus().contains(HitStatus.HIT))
        .peek((key, value) -> log.info("Outgoing record - key: {} value: {} ", key, value))
        .to(outputTopic, Produced.with(CustomSerdes.TransactionKey(), CustomSerdes.Transaction()));

In the transformer, we build a list of the criteria we want to apply and loop through it, using the factory to get each criteria item and check whether the transaction is hit.

import java.util.List;

import lombok.extern.slf4j.Slf4j;
import org.apache.kafka.streams.kstream.ValueTransformer;
import org.apache.kafka.streams.kstream.ValueTransformerSupplier;
import org.apache.kafka.streams.processor.ProcessorContext;

@Slf4j // Lombok `log` field; the nested class below can use it too
public class FraudDetectionTransformer implements ValueTransformerSupplier<Transaction, Transaction> {

    @Override
    public ValueTransformer<Transaction, Transaction> get() {
        // The supplier contract requires a new instance on every call
        return new FraudTransactionProcessor();
    }

    private static class FraudTransactionProcessor implements ValueTransformer<Transaction, Transaction> {
        @Override
        public void init(ProcessorContext processorContext) {
            // nothing to do: this transformer keeps no state
        }

        @Override
        public Transaction transform(Transaction value) {
            log.info("Incoming record for fraud detection - value: {}", value);

            // list of criteria to check (SMALL_INTERVAL is not enabled here)
            List<MonitoringCriteriaType> criteriaTypeList = List.of(
                    MonitoringCriteriaType.HIGH_TRANSACTION_AMOUNT,
                    MonitoringCriteriaType.SPECIFIC_SECTOR);

            // each criterion updates the transaction in place via isHit
            for (MonitoringCriteriaType criteriaType : criteriaTypeList) {
                MonitoringCriteriaFactory monitoringCriteriaFactory = new MonitoringCriteriaFactory(criteriaType);
                MonitoringCriteriaItem<?> monitoringCriteriaItem = monitoringCriteriaFactory.getMonitoringCriteriaItem();
                monitoringCriteriaItem.isHit(value);
            }

            log.info("Outgoing record after fraud check - value: {}", value);

            return value;
        }

        @Override
        public void close() {
            // nothing to do
        }
    }
}
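Because the transformer keeps no state, its effect on a batch of records can be reasoned about without a broker. The plain-Java sketch below mimics the topology on an in-memory list: each record goes through the same two checks (the transformValues step) and only records with a HIT status survive (the filter step). Everything here is a stand-in: the Txn record, the PipelineSketch class and the 1,000 threshold are hypothetical, and unlike the article's mutable Transaction the sketch uses an immutable record that returns an updated copy. To test the real topology, Kafka Streams provides TopologyTestDriver in the kafka-streams-test-utils artifact.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Predicate;

enum HitStatus { HIT, NO_HIT }

// Hypothetical stand-in for Transaction (immutable here for brevity).
record Txn(String id, double amount, String sector, List<HitStatus> hitStatus) {
    Txn(String id, double amount, String sector) {
        this(id, amount, sector, List.of());
    }

    // Returns a copy with the status added, skipping duplicates like the article's Set.
    Txn withStatus(HitStatus status) {
        if (hitStatus.contains(status)) {
            return this;
        }
        List<HitStatus> updated = new ArrayList<>(hitStatus);
        updated.add(status);
        return new Txn(id, amount, sector, List.copyOf(updated));
    }
}

class PipelineSketch {
    // The two criteria the transformer enables, written as predicates.
    static final List<Predicate<Txn>> CRITERIA = List.of(
            t -> t.amount() > 1_000.0,        // hypothetical high-amount threshold
            t -> "Other".equals(t.sector())); // specific-sector rule

    // Rough equivalent of FraudTransactionProcessor.transform.
    static Txn transform(Txn t) {
        Txn result = t;
        for (Predicate<Txn> rule : CRITERIA) {
            result = result.withStatus(rule.test(result) ? HitStatus.HIT : HitStatus.NO_HIT);
        }
        return result;
    }

    // transformValues(...) followed by filter(... contains HIT ...).
    static List<Txn> run(List<Txn> input) {
        return input.stream()
                .map(PipelineSketch::transform)
                .filter(t -> t.hitStatus().contains(HitStatus.HIT))
                .toList();
    }

    public static void main(String[] args) {
        List<Txn> out = run(List.of(
                new Txn("t1", 50.0, "Groceries"),
                new Txn("t2", 5_000.0, "Travel"),
                new Txn("t3", 20.0, "Other")));
        // prints "t2 [HIT, NO_HIT]" then "t3 [NO_HIT, HIT]"; t1 is filtered out
        out.forEach(t -> System.out.println(t.id() + " " + t.hitStatus()));
    }
}
```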

That's it! In Kafka, we can see output similar to the one below for a transaction that has been hit.