# basic_operations.py (34 KB)
  1. ################################################################################
  2. # Licensed to the Apache Software Foundation (ASF) under one
  3. # or more contributor license agreements. See the NOTICE file
  4. # distributed with this work for additional information
  5. # regarding copyright ownership. The ASF licenses this file
  6. # to you under the Apache License, Version 2.0 (the
  7. # "License"); you may not use this file except in compliance
  8. # with the License. You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. ################################################################################
  18. import json
  19. import logging
  20. import sys
  21. from pyflink.common import Row
  22. from pyflink.table import (DataTypes, TableEnvironment, EnvironmentSettings, ExplainDetail)
  23. from pyflink.table.expressions import *
  24. from pyflink.table.udf import udtf, udf, udaf, AggregateFunction, TableAggregateFunction, udtaf
  25. def basic_operations():
  26. t_env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode())
  27. # define the source
  28. table = t_env.from_elements(
  29. elements=[
  30. (1, '{"name": "Flink", "tel": 123, "addr": {"country": "Germany", "city": "Berlin"}}'),
  31. (2, '{"name": "hello", "tel": 135, "addr": {"country": "China", "city": "Shanghai"}}'),
  32. (3, '{"name": "world", "tel": 124, "addr": {"country": "USA", "city": "NewYork"}}'),
  33. (4, '{"name": "PyFlink", "tel": 32, "addr": {"country": "China", "city": "Hangzhou"}}')
  34. ],
  35. schema=['id', 'data'])
  36. right_table = t_env.from_elements(elements=[(1, 18), (2, 30), (3, 25), (4, 10)],
  37. schema=['id', 'age'])
  38. table = table.add_columns(
  39. col('data').json_value('$.name', DataTypes.STRING()).alias('name'),
  40. col('data').json_value('$.tel', DataTypes.STRING()).alias('tel'),
  41. col('data').json_value('$.addr.country', DataTypes.STRING()).alias('country')) \
  42. .drop_columns(col('data'))
  43. table.execute().print()
  44. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
  45. # | op | id | name | tel | country |
  46. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
  47. # | +I | 1 | Flink | 123 | Germany |
  48. # | +I | 2 | hello | 135 | China |
  49. # | +I | 3 | world | 124 | USA |
  50. # | +I | 4 | PyFlink | 32 | China |
  51. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
  52. # limit the number of outputs
  53. table.limit(3).execute().print()
  54. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
  55. # | op | id | name | tel | country |
  56. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
  57. # | +I | 1 | Flink | 123 | Germany |
  58. # | +I | 2 | hello | 135 | China |
  59. # | +I | 3 | world | 124 | USA |
  60. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
  61. # filter
  62. table.filter(col('id') != 3).execute().print()
  63. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
  64. # | op | id | name | tel | country |
  65. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
  66. # | +I | 1 | Flink | 123 | Germany |
  67. # | +I | 2 | hello | 135 | China |
  68. # | +I | 4 | PyFlink | 32 | China |
  69. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
  70. # aggregation
  71. table.group_by(col('country')) \
  72. .select(col('country'), col('id').count, col('tel').cast(DataTypes.BIGINT()).max) \
  73. .execute().print()
  74. # +----+--------------------------------+----------------------+----------------------+
  75. # | op | country | EXPR$0 | EXPR$1 |
  76. # +----+--------------------------------+----------------------+----------------------+
  77. # | +I | Germany | 1 | 123 |
  78. # | +I | USA | 1 | 124 |
  79. # | +I | China | 1 | 135 |
  80. # | -U | China | 1 | 135 |
  81. # | +U | China | 2 | 135 |
  82. # +----+--------------------------------+----------------------+----------------------+
  83. # distinct
  84. table.select(col('country')).distinct() \
  85. .execute().print()
  86. # +----+--------------------------------+
  87. # | op | country |
  88. # +----+--------------------------------+
  89. # | +I | Germany |
  90. # | +I | China |
  91. # | +I | USA |
  92. # +----+--------------------------------+
  93. # join
  94. # Note that it still doesn't support duplicate column names between the joined tables
  95. table.join(right_table.rename_columns(col('id').alias('r_id')), col('id') == col('r_id')) \
  96. .execute().print()
  97. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+----------------------+----------------------+
  98. # | op | id | name | tel | country | r_id | age |
  99. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+----------------------+----------------------+
  100. # | +I | 4 | PyFlink | 32 | China | 4 | 10 |
  101. # | +I | 1 | Flink | 123 | Germany | 1 | 18 |
  102. # | +I | 2 | hello | 135 | China | 2 | 30 |
  103. # | +I | 3 | world | 124 | USA | 3 | 25 |
  104. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+----------------------+----------------------+
  105. # join lateral
  106. @udtf(result_types=[DataTypes.STRING()])
  107. def split(r: Row):
  108. for s in r.name.split("i"):
  109. yield s
  110. table.join_lateral(split.alias('a')) \
  111. .execute().print()
  112. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+--------------------------------+
  113. # | op | id | name | tel | country | a |
  114. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+--------------------------------+
  115. # | +I | 1 | Flink | 123 | Germany | Fl |
  116. # | +I | 1 | Flink | 123 | Germany | nk |
  117. # | +I | 2 | hello | 135 | China | hello |
  118. # | +I | 3 | world | 124 | USA | world |
  119. # | +I | 4 | PyFlink | 32 | China | PyFl |
  120. # | +I | 4 | PyFlink | 32 | China | nk |
  121. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+--------------------------------+
  122. # show schema
  123. table.print_schema()
  124. # (
  125. # `id` BIGINT,
  126. # `name` STRING,
  127. # `tel` STRING,
  128. # `country` STRING
  129. # )
  130. # show execute plan
  131. print(table.join_lateral(split.alias('a')).explain())
  132. # == Abstract Syntax Tree ==
  133. # LogicalCorrelate(correlation=[$cor1], joinType=[inner], requiredColumns=[{}])
  134. # :- LogicalProject(id=[$0], name=[JSON_VALUE($1, _UTF-16LE'$.name', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR))], tel=[JSON_VALUE($1, _UTF-16LE'$.tel', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR))], country=[JSON_VALUE($1, _UTF-16LE'$.addr.country', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR))])
  135. # : +- LogicalTableScan(table=[[default_catalog, default_database, Unregistered_TableSource_249535355, source: [PythonInputFormatTableSource(id, data)]]])
  136. # +- LogicalTableFunctionScan(invocation=[*org.apache.flink.table.functions.python.PythonTableFunction$1f0568d1f39bef59b4c969a5d620ba46*($0, $1, $2, $3)], rowType=[RecordType(VARCHAR(2147483647) a)], elementType=[class [Ljava.lang.Object;])
  137. #
  138. # == Optimized Physical Plan ==
  139. # PythonCorrelate(invocation=[*org.apache.flink.table.functions.python.PythonTableFunction$1f0568d1f39bef59b4c969a5d620ba46*($0, $1, $2, $3)], correlate=[table(split(id,name,tel,country))], select=[id,name,tel,country,a], rowType=[RecordType(BIGINT id, VARCHAR(2147483647) name, VARCHAR(2147483647) tel, VARCHAR(2147483647) country, VARCHAR(2147483647) a)], joinType=[INNER])
  140. # +- Calc(select=[id, JSON_VALUE(data, _UTF-16LE'$.name', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR)) AS name, JSON_VALUE(data, _UTF-16LE'$.tel', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR)) AS tel, JSON_VALUE(data, _UTF-16LE'$.addr.country', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR)) AS country])
  141. # +- LegacyTableSourceScan(table=[[default_catalog, default_database, Unregistered_TableSource_249535355, source: [PythonInputFormatTableSource(id, data)]]], fields=[id, data])
  142. #
  143. # == Optimized Execution Plan ==
  144. # PythonCorrelate(invocation=[*org.apache.flink.table.functions.python.PythonTableFunction$1f0568d1f39bef59b4c969a5d620ba46*($0, $1, $2, $3)], correlate=[table(split(id,name,tel,country))], select=[id,name,tel,country,a], rowType=[RecordType(BIGINT id, VARCHAR(2147483647) name, VARCHAR(2147483647) tel, VARCHAR(2147483647) country, VARCHAR(2147483647) a)], joinType=[INNER])
  145. # +- Calc(select=[id, JSON_VALUE(data, '$.name', NULL, ON EMPTY, NULL, ON ERROR) AS name, JSON_VALUE(data, '$.tel', NULL, ON EMPTY, NULL, ON ERROR) AS tel, JSON_VALUE(data, '$.addr.country', NULL, ON EMPTY, NULL, ON ERROR) AS country])
  146. # +- LegacyTableSourceScan(table=[[default_catalog, default_database, Unregistered_TableSource_249535355, source: [PythonInputFormatTableSource(id, data)]]], fields=[id, data])
  147. # show execute plan with advice
  148. print(table.join_lateral(split.alias('a')).explain(ExplainDetail.PLAN_ADVICE))
  149. # == Abstract Syntax Tree ==
  150. # LogicalCorrelate(correlation=[$cor2], joinType=[inner], requiredColumns=[{}])
  151. # :- LogicalProject(id=[$0], name=[JSON_VALUE($1, _UTF-16LE'$.name', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR))], tel=[JSON_VALUE($1, _UTF-16LE'$.tel', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR))], country=[JSON_VALUE($1, _UTF-16LE'$.addr.country', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR))])
  152. # : +- LogicalTableScan(table=[[*anonymous_python-input-format$1*]])
  153. # +- LogicalTableFunctionScan(invocation=[*org.apache.flink.table.functions.python.PythonTableFunction$720258394f6a31d31376164d23142f53*($0, $1, $2, $3)], rowType=[RecordType(VARCHAR(2147483647) a)])
  154. #
  155. # == Optimized Physical Plan With Advice ==
  156. # PythonCorrelate(invocation=[*org.apache.flink.table.functions.python.PythonTableFunction$720258394f6a31d31376164d23142f53*($0, $1, $2, $3)], correlate=[table(*org.apache.flink.table.functions.python.PythonTableFunction$720258394f6a31d31376164d23142f53*(id,name,tel,country))], select=[id,name,tel,country,a], rowType=[RecordType(BIGINT id, VARCHAR(2147483647) name, VARCHAR(2147483647) tel, VARCHAR(2147483647) country, VARCHAR(2147483647) a)], joinType=[INNER])
  157. # +- Calc(select=[id, JSON_VALUE(data, _UTF-16LE'$.name', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR)) AS name, JSON_VALUE(data, _UTF-16LE'$.tel', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR)) AS tel, JSON_VALUE(data, _UTF-16LE'$.addr.country', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR)) AS country])
  158. # +- TableSourceScan(table=[[*anonymous_python-input-format$1*]], fields=[id, data])
  159. #
  160. # No available advice...
  161. #
  162. # == Optimized Execution Plan ==
  163. # PythonCorrelate(invocation=[*org.apache.flink.table.functions.python.PythonTableFunction$720258394f6a31d31376164d23142f53*($0, $1, $2, $3)], correlate=[table(*org.apache.flink.table.functions.python.PythonTableFunction$720258394f6a31d31376164d23142f53*(id,name,tel,country))], select=[id,name,tel,country,a], rowType=[RecordType(BIGINT id, VARCHAR(2147483647) name, VARCHAR(2147483647) tel, VARCHAR(2147483647) country, VARCHAR(2147483647) a)], joinType=[INNER])
  164. # +- Calc(select=[id, JSON_VALUE(data, '$.name', NULL, ON EMPTY, NULL, ON ERROR) AS name, JSON_VALUE(data, '$.tel', NULL, ON EMPTY, NULL, ON ERROR) AS tel, JSON_VALUE(data, '$.addr.country', NULL, ON EMPTY, NULL, ON ERROR) AS country])
  165. # +- TableSourceScan(table=[[*anonymous_python-input-format$1*]], fields=[id, data])
  166. def sql_operations():
  167. t_env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode())
  168. # define the source
  169. table = t_env.from_elements(
  170. elements=[
  171. (1, '{"name": "Flink", "tel": 123, "addr": {"country": "Germany", "city": "Berlin"}}'),
  172. (2, '{"name": "hello", "tel": 135, "addr": {"country": "China", "city": "Shanghai"}}'),
  173. (3, '{"name": "world", "tel": 124, "addr": {"country": "USA", "city": "NewYork"}}'),
  174. (4, '{"name": "PyFlink", "tel": 32, "addr": {"country": "China", "city": "Hangzhou"}}')
  175. ],
  176. schema=['id', 'data'])
  177. t_env.sql_query("SELECT * FROM %s" % table) \
  178. .execute().print()
  179. # +----+----------------------+--------------------------------+
  180. # | op | id | data |
  181. # +----+----------------------+--------------------------------+
  182. # | +I | 1 | {"name": "Flink", "tel": 12... |
  183. # | +I | 2 | {"name": "hello", "tel": 13... |
  184. # | +I | 3 | {"name": "world", "tel": 12... |
  185. # | +I | 4 | {"name": "PyFlink", "tel": ... |
  186. # +----+----------------------+--------------------------------+
  187. # execute sql statement
  188. @udtf(result_types=[DataTypes.STRING(), DataTypes.INT(), DataTypes.STRING()])
  189. def parse_data(data: str):
  190. json_data = json.loads(data)
  191. yield json_data['name'], json_data['tel'], json_data['addr']['country']
  192. t_env.create_temporary_function('parse_data', parse_data)
  193. t_env.execute_sql(
  194. """
  195. SELECT *
  196. FROM %s, LATERAL TABLE(parse_data(`data`)) t(name, tel, country)
  197. """ % table
  198. ).print()
  199. # +----+----------------------+--------------------------------+--------------------------------+-------------+--------------------------------+
  200. # | op | id | data | name | tel | country |
  201. # +----+----------------------+--------------------------------+--------------------------------+-------------+--------------------------------+
  202. # | +I | 1 | {"name": "Flink", "tel": 12... | Flink | 123 | Germany |
  203. # | +I | 2 | {"name": "hello", "tel": 13... | hello | 135 | China |
  204. # | +I | 3 | {"name": "world", "tel": 12... | world | 124 | USA |
  205. # | +I | 4 | {"name": "PyFlink", "tel": ... | PyFlink | 32 | China |
  206. # +----+----------------------+--------------------------------+--------------------------------+-------------+--------------------------------+
  207. # explain sql plan
  208. print(t_env.explain_sql(
  209. """
  210. SELECT *
  211. FROM %s, LATERAL TABLE(parse_data(`data`)) t(name, tel, country)
  212. """ % table
  213. ))
  214. # == Abstract Syntax Tree ==
  215. # LogicalProject(id=[$0], data=[$1], name=[$2], tel=[$3], country=[$4])
  216. # +- LogicalCorrelate(correlation=[$cor1], joinType=[inner], requiredColumns=[{1}])
  217. # :- LogicalTableScan(table=[[default_catalog, default_database, Unregistered_TableSource_734856049, source: [PythonInputFormatTableSource(id, data)]]])
  218. # +- LogicalTableFunctionScan(invocation=[parse_data($cor1.data)], rowType=[RecordType:peek_no_expand(VARCHAR(2147483647) f0, INTEGER f1, VARCHAR(2147483647) f2)])
  219. #
  220. # == Optimized Physical Plan ==
  221. # PythonCorrelate(invocation=[parse_data($1)], correlate=[table(parse_data(data))], select=[id,data,f0,f1,f2], rowType=[RecordType(BIGINT id, VARCHAR(2147483647) data, VARCHAR(2147483647) f0, INTEGER f1, VARCHAR(2147483647) f2)], joinType=[INNER])
  222. # +- LegacyTableSourceScan(table=[[default_catalog, default_database, Unregistered_TableSource_734856049, source: [PythonInputFormatTableSource(id, data)]]], fields=[id, data])
  223. #
  224. # == Optimized Execution Plan ==
  225. # PythonCorrelate(invocation=[parse_data($1)], correlate=[table(parse_data(data))], select=[id,data,f0,f1,f2], rowType=[RecordType(BIGINT id, VARCHAR(2147483647) data, VARCHAR(2147483647) f0, INTEGER f1, VARCHAR(2147483647) f2)], joinType=[INNER])
  226. # +- LegacyTableSourceScan(table=[[default_catalog, default_database, Unregistered_TableSource_734856049, source: [PythonInputFormatTableSource(id, data)]]], fields=[id, data])
  227. # explain sql plan with advice
  228. print(t_env.explain_sql(
  229. """
  230. SELECT *
  231. FROM %s, LATERAL TABLE(parse_data(`data`)) t(name, tel, country)
  232. """ % table, ExplainDetail.PLAN_ADVICE
  233. ))
  234. # == Abstract Syntax Tree ==
  235. # LogicalProject(id=[$0], data=[$1], name=[$2], tel=[$3], country=[$4])
  236. # +- LogicalCorrelate(correlation=[$cor1], joinType=[inner], requiredColumns=[{1}])
  237. # :- LogicalTableScan(table=[[*anonymous_python-input-format$10*]])
  238. # +- LogicalTableFunctionScan(invocation=[parse_data($cor2.data)], rowType=[RecordType:peek_no_expand(VARCHAR(2147483647) f0, INTEGER f1, VARCHAR(2147483647) f2)])
  239. #
  240. # == Optimized Physical Plan With Advice ==
  241. # PythonCorrelate(invocation=[parse_data($1)], correlate=[table(parse_data(data))], select=[id,data,f0,f1,f2], rowType=[RecordType(BIGINT id, VARCHAR(2147483647) data, VARCHAR(2147483647) f0, INTEGER f1, VARCHAR(2147483647) f2)], joinType=[INNER])
  242. # +- TableSourceScan(table=[[*anonymous_python-input-format$10*]], fields=[id, data])
  243. #
  244. # No available advice...
  245. #
  246. # == Optimized Execution Plan ==
  247. # PythonCorrelate(invocation=[parse_data($1)], correlate=[table(parse_data(data))], select=[id,data,f0,f1,f2], rowType=[RecordType(BIGINT id, VARCHAR(2147483647) data, VARCHAR(2147483647) f0, INTEGER f1, VARCHAR(2147483647) f2)], joinType=[INNER])
  248. # +- TableSourceScan(table=[[*anonymous_python-input-format$10*]], fields=[id, data])
  249. def column_operations():
  250. t_env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode())
  251. # define the source
  252. table = t_env.from_elements(
  253. elements=[
  254. (1, '{"name": "Flink", "tel": 123, "addr": {"country": "Germany", "city": "Berlin"}}'),
  255. (2, '{"name": "hello", "tel": 135, "addr": {"country": "China", "city": "Shanghai"}}'),
  256. (3, '{"name": "world", "tel": 124, "addr": {"country": "USA", "city": "NewYork"}}'),
  257. (4, '{"name": "PyFlink", "tel": 32, "addr": {"country": "China", "city": "Hangzhou"}}')
  258. ],
  259. schema=['id', 'data'])
  260. # add columns
  261. table = table.add_columns(
  262. col('data').json_value('$.name', DataTypes.STRING()).alias('name'),
  263. col('data').json_value('$.tel', DataTypes.STRING()).alias('tel'),
  264. col('data').json_value('$.addr.country', DataTypes.STRING()).alias('country'))
  265. table.execute().print()
  266. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+--------------------------------+
  267. # | op | id | data | name | tel | country |
  268. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+--------------------------------+
  269. # | +I | 1 | {"name": "Flink", "tel": 12... | Flink | 123 | Germany |
  270. # | +I | 2 | {"name": "hello", "tel": 13... | hello | 135 | China |
  271. # | +I | 3 | {"name": "world", "tel": 12... | world | 124 | USA |
  272. # | +I | 4 | {"name": "PyFlink", "tel": ... | PyFlink | 32 | China |
  273. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+--------------------------------+
  274. # drop columns
  275. table = table.drop_columns(col('data'))
  276. table.execute().print()
  277. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
  278. # | op | id | name | tel | country |
  279. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
  280. # | +I | 1 | Flink | 123 | Germany |
  281. # | +I | 2 | hello | 135 | China |
  282. # | +I | 3 | world | 124 | USA |
  283. # | +I | 4 | PyFlink | 32 | China |
  284. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
  285. # rename columns
  286. table = table.rename_columns(col('tel').alias('telephone'))
  287. table.execute().print()
  288. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
  289. # | op | id | name | telephone | country |
  290. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
  291. # | +I | 1 | Flink | 123 | Germany |
  292. # | +I | 2 | hello | 135 | China |
  293. # | +I | 3 | world | 124 | USA |
  294. # | +I | 4 | PyFlink | 32 | China |
  295. # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+
  296. # replace columns
  297. table = table.add_or_replace_columns(
  298. concat(col('id').cast(DataTypes.STRING()), '_', col('name')).alias('id'))
  299. table.execute().print()
  300. # +----+--------------------------------+--------------------------------+--------------------------------+--------------------------------+
  301. # | op | id | name | telephone | country |
  302. # +----+--------------------------------+--------------------------------+--------------------------------+--------------------------------+
  303. # | +I | 1_Flink | Flink | 123 | Germany |
  304. # | +I | 2_hello | hello | 135 | China |
  305. # | +I | 3_world | world | 124 | USA |
  306. # | +I | 4_PyFlink | PyFlink | 32 | China |
  307. # +----+--------------------------------+--------------------------------+--------------------------------+--------------------------------+
  308. def row_operations():
  309. t_env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode())
  310. # define the source
  311. table = t_env.from_elements(
  312. elements=[
  313. (1, '{"name": "Flink", "tel": 123, "addr": {"country": "Germany", "city": "Berlin"}}'),
  314. (2, '{"name": "hello", "tel": 135, "addr": {"country": "China", "city": "Shanghai"}}'),
  315. (3, '{"name": "world", "tel": 124, "addr": {"country": "China", "city": "NewYork"}}'),
  316. (4, '{"name": "PyFlink", "tel": 32, "addr": {"country": "China", "city": "Hangzhou"}}')
  317. ],
  318. schema=['id', 'data'])
  319. # map operation
  320. @udf(result_type=DataTypes.ROW([DataTypes.FIELD("id", DataTypes.BIGINT()),
  321. DataTypes.FIELD("country", DataTypes.STRING())]))
  322. def extract_country(input_row: Row):
  323. data = json.loads(input_row.data)
  324. return Row(input_row.id, data['addr']['country'])
  325. table.map(extract_country) \
  326. .execute().print()
  327. # +----+----------------------+--------------------------------+
  328. # | op | id | country |
  329. # +----+----------------------+--------------------------------+
  330. # | +I | 1 | Germany |
  331. # | +I | 2 | China |
  332. # | +I | 3 | China |
  333. # | +I | 4 | China |
  334. # +----+----------------------+--------------------------------+
  335. # flat_map operation
  336. @udtf(result_types=[DataTypes.BIGINT(), DataTypes.STRING()])
  337. def extract_city(input_row: Row):
  338. data = json.loads(input_row.data)
  339. yield input_row.id, data['addr']['city']
  340. table.flat_map(extract_city) \
  341. .execute().print()
  342. # +----+----------------------+--------------------------------+
  343. # | op | f0 | f1 |
  344. # +----+----------------------+--------------------------------+
  345. # | +I | 1 | Berlin |
  346. # | +I | 2 | Shanghai |
  347. # | +I | 3 | NewYork |
  348. # | +I | 4 | Hangzhou |
  349. # +----+----------------------+--------------------------------+
  350. # aggregate operation
  351. class CountAndSumAggregateFunction(AggregateFunction):
  352. def get_value(self, accumulator):
  353. return Row(accumulator[0], accumulator[1])
  354. def create_accumulator(self):
  355. return Row(0, 0)
  356. def accumulate(self, accumulator, input_row):
  357. accumulator[0] += 1
  358. accumulator[1] += int(input_row.tel)
  359. def retract(self, accumulator, input_row):
  360. accumulator[0] -= 1
  361. accumulator[1] -= int(input_row.tel)
  362. def merge(self, accumulator, accumulators):
  363. for other_acc in accumulators:
  364. accumulator[0] += other_acc[0]
  365. accumulator[1] += other_acc[1]
  366. def get_accumulator_type(self):
  367. return DataTypes.ROW(
  368. [DataTypes.FIELD("cnt", DataTypes.BIGINT()),
  369. DataTypes.FIELD("sum", DataTypes.BIGINT())])
  370. def get_result_type(self):
  371. return DataTypes.ROW(
  372. [DataTypes.FIELD("cnt", DataTypes.BIGINT()),
  373. DataTypes.FIELD("sum", DataTypes.BIGINT())])
  374. count_sum = udaf(CountAndSumAggregateFunction())
  375. table.add_columns(
  376. col('data').json_value('$.name', DataTypes.STRING()).alias('name'),
  377. col('data').json_value('$.tel', DataTypes.STRING()).alias('tel'),
  378. col('data').json_value('$.addr.country', DataTypes.STRING()).alias('country')) \
  379. .group_by(col('country')) \
  380. .aggregate(count_sum.alias("cnt", "sum")) \
  381. .select(col('country'), col('cnt'), col('sum')) \
  382. .execute().print()
  383. # +----+--------------------------------+----------------------+----------------------+
  384. # | op | country | cnt | sum |
  385. # +----+--------------------------------+----------------------+----------------------+
  386. # | +I | China | 3 | 291 |
  387. # | +I | Germany | 1 | 123 |
  388. # +----+--------------------------------+----------------------+----------------------+
  389. # flat_aggregate operation
  390. class Top2(TableAggregateFunction):
  391. def emit_value(self, accumulator):
  392. for v in accumulator:
  393. if v:
  394. yield Row(v)
  395. def create_accumulator(self):
  396. return [None, None]
  397. def accumulate(self, accumulator, input_row):
  398. tel = int(input_row.tel)
  399. if accumulator[0] is None or tel > accumulator[0]:
  400. accumulator[1] = accumulator[0]
  401. accumulator[0] = tel
  402. elif accumulator[1] is None or tel > accumulator[1]:
  403. accumulator[1] = tel
  404. def get_accumulator_type(self):
  405. return DataTypes.ARRAY(DataTypes.BIGINT())
  406. def get_result_type(self):
  407. return DataTypes.ROW(
  408. [DataTypes.FIELD("tel", DataTypes.BIGINT())])
  409. top2 = udtaf(Top2())
  410. table.add_columns(
  411. col('data').json_value('$.name', DataTypes.STRING()).alias('name'),
  412. col('data').json_value('$.tel', DataTypes.STRING()).alias('tel'),
  413. col('data').json_value('$.addr.country', DataTypes.STRING()).alias('country')) \
  414. .group_by(col('country')) \
  415. .flat_aggregate(top2) \
  416. .select(col('country'), col('tel')) \
  417. .execute().print()
  418. # +----+--------------------------------+----------------------+
  419. # | op | country | tel |
  420. # +----+--------------------------------+----------------------+
  421. # | +I | China | 135 |
  422. # | +I | China | 124 |
  423. # | +I | Germany | 123 |
  424. # +----+--------------------------------+----------------------+
  425. if __name__ == '__main__':
  426. logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(message)s")
  427. basic_operations()
  428. sql_operations()
  429. column_operations()
  430. row_operations()