背景抽象:

在物流系统中,上游业务系统向消息队列发送的运单和包裹数据可能存在:

短时间内(1秒或几秒)产生万级重复数据;同时,维度关联导致笛卡尔积倍增,形成“大包裹数据”。

这会导致下游Flink作业在消费时出现数据倾斜、Checkpoint卡住(卡CP)、OOM等问题。

需要在各作业中监控热点key以实现预警功能。

卡点:

通常的做法是将每条数据按key写入MapState来统计出现次数,但包裹数据的更新量在分钟级窗口内即可达到百万甚至千万,持续运行的作业级监控难以承受如此巨量的state开销;同时,该功能需要作为插件化的通用算子来实现,不能依赖sink到Redis等外部存储的手段,因此只能考虑在state内部解决。

思路:

综上,方案需要优化数据结构,压缩state内的数据存储量;同时自定义序列化器,避免Flink回退(fallback)到Kryo通用序列化而产生额外开销。

方案原理:

采用布隆过滤器的衍生算法Count-Min Sketch(下文简称CMS)进行数据频率统计。以下简单介绍一下CMS算法:

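Count-Min Sketch通过使用多个哈希函数将元素映射到一个二维计数矩阵中,并在每个哈希位置累加计数来跟踪元素的出现频率。每次更新时,对于每个哈希函数,找到对应的计数器并增加其值。查询时,对于每个元素,使用所有哈希函数计算哈希值,并取这些哈希位置计数的最小值作为该元素频率的估计。由于哈希碰撞只会使计数偏大,估计值不会低于真实值。通过调整矩阵的宽度和深度(宽度即每一行计数器的个数,深度即行数,也就是哈希函数的个数),可以控制误差范围和空间消耗:占用的存储非常少,并且只取决于宽度和深度,与数据量大小无关。例如,要求估计值以不低于1-δ的概率不超过“真实值+εN”(N为总计数)时,可取宽度w=⌈e/ε⌉、深度d=⌈ln(1/δ)⌉;当ε=0.001、δ=0.01时,w=2719、d=5,共约1.36万个计数单元,假如每个计数单元占4B存储,总共只需约53KB。也就是说,即使是几十万的数据,也就需要几十KB左右的存储就够了。这种非精确的概率计算也适用于热点key的频率估计。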

原理如下:

以下是一个模拟的示例代码(注意:该示例演示的是按key使用ValueState计数、自定义序列化器以及State TTL的作业骨架,代码中并未包含Count-Min Sketch计数矩阵本身的实现):

import org.apache.flink.api.common.RuntimeExecutionMode;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.memory.DataInputView;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.runtime.state.hashmap.HashMapStateBackend;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.types.TypeSerializer;
import org.apache.flink.util.Collector;

import java.io.IOException;
import java.io.Serializable;
import java.util.Random;
import java.util.concurrent.TimeUnit;

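/**
 * Frequency statistics job: counts how often each key occurs per one-minute window, keeping
 * the per-key counter in keyed state written through a custom state serializer.
 *
 * <p>The job configures {@code HashMapStateBackend} (state held on the JVM heap), not RocksDB.
 */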
public class CmsFrequencyStatJob {

    // ===================== 1. 定义核心实体类(需序列化)=====================
    /**
     * Input record: a single access event for one key.
     *
     * <p>Every private field has a getter and a setter, and there is a public no-argument
     * constructor. Flink's POJO rules require both; a private field without accessors makes
     * Flink treat the whole class as a generic type (Kryo fallback), which a job that calls
     * {@code disableGenericTypes()} rejects.
     */
    public static class CmsAccessLog implements Serializable {
        private static final long serialVersionUID = 1L;

        /** Statistic dimension: the key whose frequency is counted. */
        private String pak;
        /** Optional content identifier carried along with the event. */
        private String contentId;
        /** Access timestamp in epoch milliseconds. */
        private long accessTime;

        /** No-argument constructor, needed for POJO-style instantiation. */
        public CmsAccessLog() {}

        /**
         * @param pak        key of the event
         * @param contentId  content identifier (may be {@code null})
         * @param accessTime access timestamp in epoch milliseconds
         */
        public CmsAccessLog(String pak, String contentId, long accessTime) {
            this.pak = pak;
            this.contentId = contentId;
            this.accessTime = accessTime;
        }

        // Accessor names keep the original lower-case spelling so existing callers
        // (e.g. the CmsAccessLog::getpak key selector) continue to work.
        public String getpak() { return pak; }
        public void setpak(String pak) { this.pak = pak; }

        /** @return the content identifier, or {@code null} if none was set */
        public String getContentId() { return contentId; }
        public void setContentId(String contentId) { this.contentId = contentId; }

        public long getAccessTime() { return accessTime; }
        public void setAccessTime(long accessTime) { this.accessTime = accessTime; }
    }

    /**
     * Mutable per-key counter stored in a {@code ValueState}: the key, the number of
     * accesses seen so far, and the start of the window those accesses belong to.
     */
    public static class FrequencyState implements Serializable {
        private static final long serialVersionUID = 1L;

        /** Key this counter belongs to. */
        private String pak;
        /** Accumulated number of accesses. */
        private long accessCount;
        /** Start of the statistics window in epoch milliseconds (e.g. second 0 of a minute). */
        private long windowStartTime;

        /** Creates an empty counter: no key, zero count, window start 0. */
        public FrequencyState() {
            this(null, 0L, 0L);
        }

        /** Creates a counter with all three fields set explicitly. */
        public FrequencyState(String pak, long accessCount, long windowStartTime) {
            this.pak = pak;
            this.accessCount = accessCount;
            this.windowStartTime = windowStartTime;
        }

        public String getpak() {
            return this.pak;
        }

        public void setpak(String pak) {
            this.pak = pak;
        }

        public long getAccessCount() {
            return this.accessCount;
        }

        public void setAccessCount(long accessCount) {
            this.accessCount = accessCount;
        }

        public long getWindowStartTime() {
            return this.windowStartTime;
        }

        public void setWindowStartTime(long windowStartTime) {
            this.windowStartTime = windowStartTime;
        }
    }

    // ===================== 2. 自定义序列化器(实现TypeSerializer接口)=====================
    public static class FrequencyStateSerializer extends TypeSerializer<FrequencyState> {
        private static final long serialVersionUID = 1L;

        // 序列化:将FrequencyState写入字节流
        @Override
        public void serialize(FrequencyState state, DataOutputView output) throws IOException {
            if (state == null) {
                output.writeBoolean(false); // 标记空对象
                return;
            }
            output.writeBoolean(true); // 标记非空
            output.writeUTF(state.getpak()); // 序列化String:pak
            output.writeLong(state.getAccessCount()); // 序列化long:访问次数
            output.writeLong(state.getWindowStartTime()); // 序列化long:窗口起始时间
        }

        // 反序列化:从字节流读取FrequencyState
        @Override
        public FrequencyState deserialize(DataInputView input) throws IOException {
            boolean isNotNull = input.readBoolean();
            if (!isNotNull) return null;

            String pak = input.readUTF();
            long accessCount = input.readLong();
            long windowStartTime = input.readLong();
            return new FrequencyState(pak, accessCount, windowStartTime);
        }

        // 复制对象(Flink状态管理必需)
        @Override
        public FrequencyState copy(FrequencyState from) {
            return new FrequencyState(from.getpak(), from.getAccessCount(), from.getWindowStartTime());
        }

        // 复用对象复制(性能优化)
        @Override
        public FrequencyState copy(FrequencyState from, FrequencyState reuse) {
            reuse.setpak(from.getpak());
            reuse.setAccessCount(from.getAccessCount());
            reuse.setWindowStartTime(from.getWindowStartTime());
            return reuse;
        }

        // 估算序列化长度(不确定则返回-1)
        @Override
        public int getLength() { return -1; }

        // 复制序列化器
        @Override
        public TypeSerializer<FrequencyState> duplicate() {
            return new FrequencyStateSerializer();
        }

        // 序列化器 equals/hashCode(兼容性检查)
        @Override
        public boolean equals(Object obj) {
            return obj instanceof FrequencyStateSerializer;
        }

        @Override
        public int hashCode() {
            return FrequencyStateSerializer.class.hashCode();
        }
    }

    // ===================== 3. 自定义数据源 =====================
    /**
     * Demo source: until cancelled, emits one randomly generated {@link CmsAccessLog} every
     * 200 ms, stamped with the current wall-clock time.
     */
    public static class CmsAccessLogSource extends RichParallelSourceFunction<CmsAccessLog> {
        /** Pause between two emitted records, in milliseconds. */
        private static final long EMIT_INTERVAL_MS = 200L;

        private final String[] paks = {"u1001", "u1002", "u1003", "u1004", "u1005"}; // 5 simulated keys
        private final String[] contentIds = {"c2001", "c2002", "c2003"}; // 3 simulated contents
        private final Random random = new Random();
        /** Cleared by {@link #cancel()}; volatile because cancel() is invoked from another thread. */
        private volatile boolean running = true;

        @Override
        public void run(SourceContext<CmsAccessLog> ctx) throws Exception {
            while (running) {
                ctx.collect(nextLog());
                Thread.sleep(EMIT_INTERVAL_MS);
            }
        }

        @Override
        public void cancel() {
            running = false;
        }

        /** Draws a random key, then a random content id, and stamps the record with "now". */
        private CmsAccessLog nextLog() {
            int keyIndex = random.nextInt(paks.length);
            int contentIndex = random.nextInt(contentIds.length);
            return new CmsAccessLog(paks[keyIndex], contentIds[contentIndex], System.currentTimeMillis());
        }
    }

    // ===================== 4. 核心处理函数(统计频率)=====================
    /**
     * Counts accesses per key in one-minute windows and emits one result line per window.
     *
     * <p>Windows are derived from the record's {@code accessTime}, while results are emitted
     * by processing-time timers registered at the window end. This assumes {@code accessTime}
     * is close to wall-clock time (the demo source stamps records with
     * {@code System.currentTimeMillis()}).
     *
     * <p>Because a record of the next window can be processed before the previous window's
     * timer has fired, two guards keep counts from being lost or misattributed:
     * <ul>
     *   <li>when a record belongs to a different window than the stored state, the stored
     *       window is emitted before the state is reset;</li>
     *   <li>a timer only emits and clears the state if the state's window has actually ended
     *       at the timer's timestamp.</li>
     * </ul>
     */
    public static class FrequencyStatProcessFunction extends KeyedProcessFunction<String, CmsAccessLog, String> {
        private static final long WINDOW_SIZE = 60 * 1000; // statistics window: 1 minute
        private ValueState<FrequencyState> frequencyState; // per-key counter for the open window

        /** Creates the keyed state handle with the custom serializer and a 3-minute TTL. */
        @Override
        public void open(Configuration parameters) throws Exception {
            super.open(parameters);

            // Passing the serializer explicitly means Flink never has to infer one for FrequencyState.
            ValueStateDescriptor<FrequencyState> stateDesc = new ValueStateDescriptor<>(
                "cms-frequency-state",
                new FrequencyStateSerializer()
            );

            // Expire idle keys after 3 minutes so state for keys that stop appearing is dropped.
            // UpdateType.OnReadAndWrite refreshes the TTL on both reads and writes.
            StateTtlConfig ttlConfig = StateTtlConfig.newBuilder(Time.minutes(3))
                .setUpdateType(StateTtlConfig.UpdateType.OnReadAndWrite)
                .setStateVisibility(StateTtlConfig.StateVisibility.NeverReturnExpired)
                .build();
            stateDesc.enableTimeToLive(ttlConfig);

            frequencyState = getRuntimeContext().getState(stateDesc);
        }

        /**
         * Adds one access to the counter of the window the record falls into.
         *
         * @param log the access event; its {@code accessTime} selects the window
         * @param ctx gives access to the timer service
         * @param out receives the previous window's result if the record rolls the window over
         */
        @Override
        public void processElement(CmsAccessLog log, Context ctx, Collector<String> out) throws Exception {
            String pak = log.getpak();
            long currentTime = log.getAccessTime();
            // Align down to the window start; floorMod keeps the alignment correct for negative timestamps.
            long currentWindowStart = currentTime - Math.floorMod(currentTime, WINDOW_SIZE);

            FrequencyState state = frequencyState.value();

            if (state == null) {
                // First record of this key (or state was cleared/expired): open a window.
                state = new FrequencyState(pak, 1, currentWindowStart);
                ctx.timerService().registerProcessingTimeTimer(currentWindowStart + WINDOW_SIZE);
            } else if (state.getWindowStartTime() == currentWindowStart) {
                // Same window: accumulate.
                state.setAccessCount(state.getAccessCount() + 1);
            } else {
                // Different window: the stored window's timer has not emitted it yet, so emit it
                // here; otherwise its count would be overwritten and lost.
                out.collect(formatResult(ctx.getCurrentKey(), state));
                state.setWindowStartTime(currentWindowStart);
                state.setAccessCount(1);
                ctx.timerService().registerProcessingTimeTimer(currentWindowStart + WINDOW_SIZE);
            }

            frequencyState.update(state);
        }

        /**
         * Emits the finished window's count for the current key and clears the state.
         *
         * @param timestamp the window end the timer was registered for
         */
        @Override
        public void onTimer(long timestamp, OnTimerContext ctx, Collector<String> out) throws Exception {
            super.onTimer(timestamp, ctx, out);
            FrequencyState state = frequencyState.value();
            if (state == null) {
                return;
            }
            // Stale timer: the state already belongs to a later window (the earlier one was
            // flushed on rollover). That window has its own timer, so leave the state alone.
            if (state.getWindowStartTime() + WINDOW_SIZE > timestamp) {
                return;
            }

            out.collect(formatResult(ctx.getCurrentKey(), state));
            // Clear so the next record of this key opens a fresh window.
            frequencyState.clear();
        }

        /** Renders one result line; the window label is the window start as UTC hour:minute. */
        private static String formatResult(String pak, FrequencyState state) {
            long frequency = state.getAccessCount(); // accesses within the one-minute window
            String windowTime = String.format("%d:%02d",
                (state.getWindowStartTime() / 1000 / 3600) % 24,
                (state.getWindowStartTime() / 1000 / 60) % 60);
            return String.format(
                "CMS频率统计 | key:%s | 统计窗口:%s | 访问频率:%d次/分钟",
                pak, windowTime, frequency
            );
        }
    }

    // ===================== 5. 主函数(Job入口)=====================
    /**
     * Job entry point: builds and runs the demo pipeline
     * source -> keyBy(key) -> per-window frequency counting -> print.
     *
     * @param args unused
     * @throws Exception if the job fails to build or execute
     */
    public static void main(String[] args) throws Exception {
        // 1. Execution environment.
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setRuntimeMode(RuntimeExecutionMode.STREAMING);
        env.setParallelism(1); // parallelism 1 for local testing; tune for production

        // 2. State backend: HashMapStateBackend keeps working state as objects on the JVM heap.
        //    (This is not RocksDB; switching would require the RocksDB state backend dependency.)
        env.setStateBackend(new HashMapStateBackend());
        // Where checkpoint data would be written.
        // NOTE(review): checkpointing itself is never enabled in this job, so this path is only
        // used once env.enableCheckpointing(...) is added.
        env.getCheckpointConfig().setCheckpointStorage("file:///tmp/flink-checkpoint");

        // 3. Serialization guard. The custom state serializer is supplied directly through the
        //    state descriptor, so no Kryo registration is needed (a Flink TypeSerializer is not a
        //    Kryo serializer and cannot be registered as one). disableGenericTypes() makes the job
        //    fail fast if any type in the pipeline would fall back to Kryo, so every record type
        //    must be one Flink can analyze itself (e.g. a POJO with accessors for all fields).
        env.getConfig().disableGenericTypes();

        // 4. Source.
        DataStream<CmsAccessLog> accessLogStream = env.addSource(new CmsAccessLogSource());

        // 5. Partition by key (keyed state is scoped per key), count, print.
        accessLogStream
            .keyBy(CmsAccessLog::getpak)
            .process(new FrequencyStatProcessFunction())
            .print("CMS频率统计结果:");

        // 6. Run.
        env.execute("CMS Access Frequency Statistics Job");
    }
}