Besides providing most of the core ANSI SQL operators, Apache Flink also lets users write their own business logic as User-Defined Functions. Three kinds are currently supported:
- UDF - User-Defined Scalar Function
- UDTF - User-Defined Table Function
- UDAF - User-Defined Aggregate Function
All of these are user-defined functions, so why does Apache Flink divide them into three categories, and on what basis? Flink classifies user-defined functions by their semantics, that is, by the shape of their input and output: a UDF maps one row to a single scalar value, a UDTF maps one row to zero or more rows, and a UDAF aggregates many rows into one value. The details are as follows:

1. UDF
a. Definition
To write a string-concatenation UDF, we only need to implement ScalarFunction#eval(), for example:
```scala
import scala.annotation.varargs
import org.apache.flink.table.functions.ScalarFunction

object MyConnect extends ScalarFunction {
  @varargs
  def eval(args: String*): String = {
    val sb = new StringBuilder
    var i = 0
    while (i < args.length) {
      if (args(i) == null) {
        // Concatenation with NULL yields NULL.
        return null
      }
      sb.append(args(i))
      i += 1
    }
    sb.toString
  }
}
```
b. Usage
```scala
...
val fun = MyConnect
tEnv.registerFunction("myConnect", fun)
val sql = "SELECT myConnect(a, b) as str FROM tab"
...
```
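The null-propagating concatenation behavior of eval() can be exercised without a Flink runtime. The following is an illustrative stand-alone sketch of the same logic (ConnectSketch is a hypothetical name, not part of Flink's API):

```scala
import scala.annotation.varargs

// Stand-alone sketch of MyConnect's eval() logic: concatenate all
// arguments, but return null as soon as any argument is null
// (mirroring SQL's NULL-propagation semantics).
object ConnectSketch {
  @varargs
  def eval(args: String*): String = {
    val sb = new StringBuilder
    var i = 0
    while (i < args.length) {
      if (args(i) == null) {
        return null
      }
      sb.append(args(i))
      i += 1
    }
    sb.toString
  }
}
```

For example, `ConnectSketch.eval("foo", "bar")` yields `"foobar"`, while `ConnectSketch.eval("foo", null)` yields `null`.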
2. UDTF
a. Definition
To write a string-splitting UDTF, we only need to implement TableFunction#eval(), for example:
```scala
import org.apache.flink.table.functions.TableFunction

class MySplit extends TableFunction[String] {
  def eval(str: String): Unit = {
    if (str.contains("#")) {
      str.split("#").foreach(collect)
    }
  }

  def eval(str: String, prefix: String): Unit = {
    if (str.contains("#")) {
      str.split("#").foreach(s => collect(prefix + s))
    }
  }
}
```
b. Usage
```scala
...
val fun = new MySplit()
tEnv.registerFunction("mySplit", fun)
val sql = "SELECT c, s FROM MyTable, LATERAL TABLE(mySplit(c)) AS T(s)"
...
```
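What makes a UDTF different from a UDF is that each call to eval() may emit zero or more rows via collect(). That row-emitting behavior can be sketched in plain Scala by replacing collect() with an output buffer (SplitSketch is an illustrative stand-in, not Flink API):

```scala
import scala.collection.mutable.ArrayBuffer

// Stand-alone sketch of MySplit's eval() logic: each "#"-separated
// piece becomes one emitted row, collected here into a buffer
// instead of Flink's collect().
class SplitSketch {
  val out = new ArrayBuffer[String]

  def eval(str: String): Unit = {
    if (str.contains("#")) {
      str.split("#").foreach(out += _)
    }
  }

  def eval(str: String, prefix: String): Unit = {
    if (str.contains("#")) {
      str.split("#").foreach(s => out += (prefix + s))
    }
  }
}
```

Calling `eval("a#b#c")` on a fresh instance fills the buffer with three rows, `a`, `b`, and `c`; an input without `#` emits nothing.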
3. UDAF
a. Definition
A UDAF requires implementing more methods than the other two kinds. Taking a simple count aggregate as an example:
```scala
import java.lang.{Iterable => JIterable, Long => JLong}
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.tuple.{Tuple1 => JTuple1}
import org.apache.flink.api.java.typeutils.TupleTypeInfo
import org.apache.flink.table.functions.AggregateFunction

/** The initial accumulator for the count aggregate function. */
class CountAccumulator extends JTuple1[Long] {
  f0 = 0L // count
}

/**
 * User-defined count aggregate function.
 */
class MyCount extends AggregateFunction[JLong, CountAccumulator] {

  // The argument-less variants exist because Calcite optimizes the
  // argument away: count(42) or count(*) is rewritten to count().
  def accumulate(acc: CountAccumulator): Unit = {
    acc.f0 += 1L
  }

  def retract(acc: CountAccumulator): Unit = {
    acc.f0 -= 1L
  }

  def accumulate(acc: CountAccumulator, value: Any): Unit = {
    if (value != null) {
      acc.f0 += 1L
    }
  }

  def retract(acc: CountAccumulator, value: Any): Unit = {
    if (value != null) {
      acc.f0 -= 1L
    }
  }

  override def getValue(acc: CountAccumulator): JLong = {
    acc.f0
  }

  def merge(acc: CountAccumulator, its: JIterable[CountAccumulator]): Unit = {
    val iter = its.iterator()
    while (iter.hasNext) {
      acc.f0 += iter.next().f0
    }
  }

  override def createAccumulator(): CountAccumulator = {
    new CountAccumulator
  }

  def resetAccumulator(acc: CountAccumulator): Unit = {
    acc.f0 = 0L
  }

  override def getAccumulatorType: TypeInformation[CountAccumulator] = {
    new TupleTypeInfo(classOf[CountAccumulator], BasicTypeInfo.LONG_TYPE_INFO)
  }

  override def getResultType: TypeInformation[JLong] =
    BasicTypeInfo.LONG_TYPE_INFO
}
```
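The accumulator life cycle (accumulate on new rows, retract on retracted rows, merge across partial accumulators) can be exercised without Flink as a plain-Scala sketch; CountSketch below is an illustrative stand-in for MyCount with a plain Long replacing CountAccumulator:

```scala
// Stand-alone sketch of the count accumulator life cycle:
// accumulate adds a row (null values are not counted, matching
// SQL COUNT(col)), retract removes one, merge folds in the
// count from another partial accumulator.
final class CountSketch {
  var count: Long = 0L

  def accumulate(value: Any): Unit = if (value != null) count += 1L
  def retract(value: Any): Unit = if (value != null) count -= 1L
  def merge(other: CountSketch): Unit = count += other.count
  def getValue: Long = count
}
```

Accumulating `"x"`, `null`, and `"y"` gives a count of 2, retracting `"x"` drops it to 1, and merging in a partial accumulator that saw one row brings it back to 2.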